summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.bazelrc2
-rw-r--r--.buildkite/hooks/post-command24
-rw-r--r--.buildkite/hooks/pre-command11
-rw-r--r--.buildkite/pipeline.yaml107
-rw-r--r--.github/pull_request_template.md5
-rw-r--r--.github/workflows/build.yml13
-rw-r--r--.github/workflows/go.yml46
-rw-r--r--.github/workflows/issue_reviver.yml2
-rw-r--r--.github/workflows/labeler.yml1
-rw-r--r--.github/workflows/stale.yml2
-rw-r--r--.gitignore7
-rw-r--r--BUILD11
-rw-r--r--Makefile430
-rw-r--r--WORKSPACE14
-rw-r--r--g3doc/proposals/runtime_dedicate_os_thread.md188
-rw-r--r--go.mod4
-rw-r--r--go.sum6
-rw-r--r--images/BUILD10
-rw-r--r--images/Makefile107
-rw-r--r--images/agent/Dockerfile12
-rw-r--r--images/agent/README.md7
-rw-r--r--images/basic/ping4test/Dockerfile7
-rw-r--r--[-rwxr-xr-x]images/basic/ping4test/ping4.sh (renamed from tools/vm/zone.sh)10
-rw-r--r--images/basic/ping6test/Dockerfile7
-rw-r--r--images/basic/ping6test/ping6.sh32
-rw-r--r--images/benchmarks/absl/Dockerfile.x86_64 (renamed from images/benchmarks/absl/Dockerfile)1
-rw-r--r--images/benchmarks/hey/Dockerfile13
-rw-r--r--images/benchmarks/runsc/Dockerfile.x86_64 (renamed from images/benchmarks/runsc/Dockerfile)1
-rw-r--r--images/default/Dockerfile8
-rw-r--r--images/runtimes/go1.12/Dockerfile.x86_64 (renamed from images/runtimes/go1.12/Dockerfile)0
-rw-r--r--nogo.yaml141
-rw-r--r--pkg/abi/linux/fcntl.go2
-rw-r--r--pkg/abi/linux/fuse.go8
-rw-r--r--pkg/abi/linux/sem.go35
-rw-r--r--pkg/abi/linux/socket.go14
-rw-r--r--pkg/coverage/coverage.go130
-rw-r--r--pkg/cpuid/cpuid.go11
-rw-r--r--pkg/cpuid/cpuid_x86.go11
-rw-r--r--pkg/crypto/BUILD12
-rw-r--r--pkg/crypto/crypto.go (renamed from pkg/sleep/empty.s)5
-rw-r--r--pkg/crypto/crypto_stdlib.go32
-rw-r--r--pkg/flipcall/BUILD3
-rw-r--r--pkg/flipcall/packet_window_mmap_amd64.go (renamed from pkg/flipcall/packet_window_mmap.go)0
-rw-r--r--pkg/flipcall/packet_window_mmap_arm64.go (renamed from pkg/syncevent/waiter_asm_unsafe.go)13
-rw-r--r--pkg/goid/BUILD1
-rw-r--r--pkg/log/json.go14
-rw-r--r--pkg/log/log.go2
-rw-r--r--pkg/merkletree/merkletree.go10
-rw-r--r--pkg/p9/client.go4
-rw-r--r--pkg/p9/client_file.go25
-rw-r--r--pkg/p9/handlers.go81
-rw-r--r--pkg/p9/p9test/client_test.go15
-rw-r--r--pkg/p9/server.go14
-rw-r--r--pkg/p9/transport_test.go16
-rw-r--r--pkg/pool/pool.go1
-rw-r--r--pkg/refsvfs2/BUILD2
-rw-r--r--pkg/refsvfs2/refs_template.go9
-rw-r--r--pkg/safemem/block_unsafe.go20
-rw-r--r--pkg/seccomp/seccomp.go2
-rw-r--r--pkg/segment/test/set_functions.go1
-rw-r--r--pkg/sentry/arch/arch.go15
-rw-r--r--pkg/sentry/arch/arch_state_x86.go17
-rw-r--r--pkg/sentry/arch/signal.go39
-rw-r--r--pkg/sentry/control/pprof.go2
-rw-r--r--pkg/sentry/control/state.go1
-rw-r--r--pkg/sentry/fdimport/fdimport.go1
-rw-r--r--pkg/sentry/fs/BUILD2
-rw-r--r--pkg/sentry/fs/copy_up.go13
-rw-r--r--pkg/sentry/fs/copy_up_test.go2
-rw-r--r--pkg/sentry/fs/file.go47
-rw-r--r--pkg/sentry/fs/filetest/filetest.go4
-rw-r--r--pkg/sentry/fs/fs.go2
-rw-r--r--pkg/sentry/fs/gofer/BUILD2
-rw-r--r--pkg/sentry/fs/gofer/attr.go2
-rw-r--r--pkg/sentry/fs/gofer/file.go31
-rw-r--r--pkg/sentry/fs/gofer/inode.go2
-rw-r--r--pkg/sentry/fs/inode.go6
-rw-r--r--pkg/sentry/fs/proc/sys.go1
-rw-r--r--pkg/sentry/fs/tmpfs/BUILD2
-rw-r--r--pkg/sentry/fs/tmpfs/inode_file.go26
-rw-r--r--pkg/sentry/fsimpl/fuse/connection_control.go6
-rw-r--r--pkg/sentry/fsimpl/fuse/connection_test.go10
-rw-r--r--pkg/sentry/fsimpl/fuse/dev.go2
-rw-r--r--pkg/sentry/fsimpl/fuse/dev_test.go5
-rw-r--r--pkg/sentry/fsimpl/fuse/directory.go6
-rw-r--r--pkg/sentry/fsimpl/fuse/file.go8
-rw-r--r--pkg/sentry/fsimpl/fuse/fusefs.go42
-rw-r--r--pkg/sentry/fsimpl/fuse/read_write.go12
-rw-r--r--pkg/sentry/fsimpl/fuse/request_response.go5
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD1
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go26
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go74
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go43
-rw-r--r--pkg/sentry/fsimpl/gofer/special_file.go24
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go2
-rw-r--r--pkg/sentry/fsimpl/overlay/copy_up.go23
-rw-r--r--pkg/sentry/fsimpl/overlay/regular_file.go4
-rw-r--r--pkg/sentry/fsimpl/proc/task_net.go6
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go1
-rw-r--r--pkg/sentry/fsimpl/signalfd/signalfd.go5
-rw-r--r--pkg/sentry/fsimpl/tmpfs/BUILD1
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go8
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go5
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go67
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go178
-rw-r--r--pkg/sentry/fsimpl/verity/verity_test.go494
-rw-r--r--pkg/sentry/fsmetric/BUILD10
-rw-r--r--pkg/sentry/fsmetric/fsmetric.go83
-rw-r--r--pkg/sentry/kernel/epoll/epoll.go7
-rw-r--r--pkg/sentry/kernel/fasync/BUILD2
-rw-r--r--pkg/sentry/kernel/fasync/fasync.go96
-rw-r--r--pkg/sentry/kernel/fd_table_unsafe.go11
-rw-r--r--pkg/sentry/kernel/kernel.go50
-rw-r--r--pkg/sentry/kernel/ptrace.go4
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore.go115
-rw-r--r--pkg/sentry/kernel/shm/BUILD3
-rw-r--r--pkg/sentry/kernel/signal.go4
-rw-r--r--pkg/sentry/kernel/signalfd/signalfd.go4
-rw-r--r--pkg/sentry/kernel/syslog.go6
-rw-r--r--pkg/sentry/kernel/task_exit.go8
-rw-r--r--pkg/sentry/kernel/task_signals.go16
-rw-r--r--pkg/sentry/memmap/memmap.go2
-rw-r--r--pkg/sentry/mm/aio_context.go79
-rw-r--r--pkg/sentry/mm/aio_context_state.go4
-rw-r--r--pkg/sentry/mm/lifecycle.go2
-rw-r--r--pkg/sentry/mm/mm_test.go43
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go62
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go14
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.go7
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go40
-rw-r--r--pkg/sentry/platform/kvm/bluepill_unsafe.go5
-rw-r--r--pkg/sentry/platform/kvm/kvm_arm64.go9
-rw-r--r--pkg/sentry/platform/kvm/kvm_const.go1
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go7
-rw-r--r--pkg/sentry/platform/ptrace/ptrace.go4
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go2
-rw-r--r--pkg/sentry/platform/ring0/BUILD11
-rw-r--r--pkg/sentry/platform/ring0/aarch64.go1
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s196
-rw-r--r--pkg/sentry/platform/ring0/gen_offsets/BUILD3
-rw-r--r--pkg/sentry/platform/ring0/kernel_arm64.go12
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.go14
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.s110
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64_unsafe.go108
-rw-r--r--pkg/sentry/platform/ring0/offsets_arm64.go1
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables.go14
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go10
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go10
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go16
-rw-r--r--pkg/sentry/platform/ring0/pagetables/walker_arm64.go2
-rw-r--r--pkg/sentry/socket/BUILD1
-rw-r--r--pkg/sentry/socket/control/BUILD1
-rw-r--r--pkg/sentry/socket/control/control.go78
-rw-r--r--pkg/sentry/socket/hostinet/socket.go160
-rw-r--r--pkg/sentry/socket/netlink/socket.go12
-rw-r--r--pkg/sentry/socket/netstack/netstack.go705
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go6
-rw-r--r--pkg/sentry/socket/netstack/provider.go2
-rw-r--r--pkg/sentry/socket/netstack/provider_vfs2.go2
-rw-r--r--pkg/sentry/socket/netstack/stack.go30
-rw-r--r--pkg/sentry/socket/socket.go249
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go27
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go1
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go78
-rw-r--r--pkg/sentry/socket/unix/unix.go15
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go2
-rw-r--r--pkg/sentry/strace/BUILD2
-rw-r--r--pkg/sentry/strace/socket.go4
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_aio.go5
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go16
-rw-r--r--pkg/sentry/syscalls/linux/sys_sem.go45
-rw-r--r--pkg/sentry/syscalls/linux/sys_signal.go16
-rw-r--r--pkg/sentry/syscalls/linux/sys_thread.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/aio.go5
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go18
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/ioctl.go2
-rw-r--r--pkg/sentry/vfs/BUILD1
-rw-r--r--pkg/sentry/vfs/epoll.go10
-rw-r--r--pkg/sentry/vfs/file_description.go86
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go36
-rw-r--r--pkg/sentry/vfs/save_restore.go25
-rw-r--r--pkg/sentry/vfs/vfs.go3
-rw-r--r--pkg/sentry/watchdog/watchdog.go1
-rw-r--r--pkg/shim/v1/proc/process.go1
-rw-r--r--pkg/shim/v1/shim/BUILD1
-rw-r--r--pkg/shim/v1/shim/shim.go (renamed from pkg/sleep/commit_asm.go)13
-rw-r--r--pkg/shim/v1/utils/utils.go1
-rw-r--r--pkg/shim/v2/BUILD1
-rw-r--r--pkg/shim/v2/service.go86
-rw-r--r--pkg/sleep/BUILD4
-rw-r--r--pkg/sleep/commit_amd64.s35
-rw-r--r--pkg/sleep/commit_arm64.s38
-rw-r--r--pkg/sleep/commit_noasm.go33
-rw-r--r--pkg/sleep/sleep_unsafe.go27
-rw-r--r--pkg/state/tests/integer_test.go26
-rw-r--r--pkg/sync/BUILD35
-rw-r--r--pkg/sync/atomicptrmaptest/BUILD57
-rw-r--r--pkg/sync/atomicptrmaptest/atomicptrmap.go (renamed from tools/vm/test.cc)15
-rw-r--r--pkg/sync/atomicptrmaptest/atomicptrmap_test.go635
-rw-r--r--pkg/sync/checklocks_off_unsafe.go18
-rw-r--r--pkg/sync/checklocks_on_unsafe.go108
-rw-r--r--pkg/sync/generic_atomicptr_unsafe.go (renamed from pkg/sync/atomicptr_unsafe.go)4
-rw-r--r--pkg/sync/generic_atomicptrmap_unsafe.go503
-rw-r--r--pkg/sync/generic_seqatomic_unsafe.go (renamed from pkg/sync/seqatomic_unsafe.go)21
-rw-r--r--pkg/sync/goyield_go113_unsafe.go18
-rw-r--r--pkg/sync/goyield_unsafe.go (renamed from pkg/sync/spin_unsafe.go)8
-rw-r--r--pkg/sync/memmove_unsafe.go28
-rw-r--r--pkg/sync/mutex_test.go4
-rw-r--r--pkg/sync/mutex_unsafe.go51
-rw-r--r--pkg/sync/norace_unsafe.go11
-rw-r--r--pkg/sync/race_amd64.s (renamed from pkg/syncevent/waiter_amd64.s)17
-rw-r--r--pkg/sync/race_arm64.s (renamed from pkg/syncevent/waiter_arm64.s)17
-rw-r--r--pkg/sync/race_unsafe.go6
-rw-r--r--pkg/sync/runtime_unsafe.go129
-rw-r--r--pkg/sync/rwmutex_test.go2
-rw-r--r--pkg/sync/rwmutex_unsafe.go152
-rw-r--r--pkg/sync/seqcount.go34
-rw-r--r--pkg/sync/seqcount_test.go53
-rw-r--r--pkg/syncevent/BUILD4
-rw-r--r--pkg/syncevent/waiter_noasm_unsafe.go39
-rw-r--r--pkg/syncevent/waiter_unsafe.go59
-rw-r--r--pkg/syserr/host_linux.go2
-rw-r--r--pkg/syserr/netstack.go93
-rw-r--r--pkg/tcpip/checker/checker.go390
-rw-r--r--pkg/tcpip/header/BUILD5
-rw-r--r--pkg/tcpip/header/icmpv6.go13
-rw-r--r--pkg/tcpip/header/igmp.go181
-rw-r--r--pkg/tcpip/header/igmp_test.go110
-rw-r--r--pkg/tcpip/header/ipv4.go209
-rw-r--r--pkg/tcpip/header/ipv4_test.go179
-rw-r--r--pkg/tcpip/header/ipv6.go15
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers.go345
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers_test.go356
-rw-r--r--pkg/tcpip/header/ipv6_fragment.go42
-rw-r--r--pkg/tcpip/header/mld.go103
-rw-r--r--pkg/tcpip/header/mld_test.go61
-rw-r--r--pkg/tcpip/header/ndp_options.go2
-rw-r--r--pkg/tcpip/header/udp.go29
-rw-r--r--pkg/tcpip/link/channel/BUILD1
-rw-r--r--pkg/tcpip/link/channel/channel.go18
-rw-r--r--pkg/tcpip/link/ethernet/ethernet.go4
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go18
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go11
-rw-r--r--pkg/tcpip/link/loopback/loopback.go17
-rw-r--r--pkg/tcpip/link/muxed/BUILD1
-rw-r--r--pkg/tcpip/link/muxed/injectable.go8
-rw-r--r--pkg/tcpip/link/nested/BUILD1
-rw-r--r--pkg/tcpip/link/nested/nested.go6
-rw-r--r--pkg/tcpip/link/packetsocket/endpoint.go4
-rw-r--r--pkg/tcpip/link/pipe/pipe.go7
-rw-r--r--pkg/tcpip/link/qdisc/fifo/BUILD1
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go12
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go17
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go29
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go64
-rw-r--r--pkg/tcpip/link/tun/device.go4
-rw-r--r--pkg/tcpip/link/waitable/BUILD2
-rw-r--r--pkg/tcpip/link/waitable/waitable.go12
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go6
-rw-r--r--pkg/tcpip/network/BUILD2
-rw-r--r--pkg/tcpip/network/arp/arp_test.go12
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD2
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap.go77
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap_test.go126
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go10
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go148
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go225
-rw-r--r--pkg/tcpip/network/ip/BUILD26
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol.go676
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol_test.go877
-rw-r--r--pkg/tcpip/network/ip_test.go129
-rw-r--r--pkg/tcpip/network/ipv4/BUILD7
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go6
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go345
-rw-r--r--pkg/tcpip/network/ipv4/igmp_test.go215
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go255
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go277
-rw-r--r--pkg/tcpip/network/ipv6/BUILD18
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go57
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go194
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go285
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go310
-rw-r--r--pkg/tcpip/network/ipv6/mld.go262
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go297
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go351
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go140
-rw-r--r--pkg/tcpip/network/multicast_group_test.go1261
-rw-r--r--pkg/tcpip/network/testutil/testutil.go15
-rw-r--r--pkg/tcpip/socketops.go341
-rw-r--r--pkg/tcpip/stack/BUILD6
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go179
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go22
-rw-r--r--pkg/tcpip/stack/forwarding_test.go16
-rw-r--r--pkg/tcpip/stack/ndp_test.go259
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go16
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go7
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go27
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go144
-rw-r--r--pkg/tcpip/stack/nic.go103
-rw-r--r--pkg/tcpip/stack/registration.go25
-rw-r--r--pkg/tcpip/stack/route.go200
-rw-r--r--pkg/tcpip/stack/stack.go147
-rw-r--r--pkg/tcpip/stack/stack_test.go217
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go14
-rw-r--r--pkg/tcpip/stack/transport_test.go34
-rw-r--r--pkg/tcpip/tcpip.go214
-rw-r--r--pkg/tcpip/tcpip_test.go44
-rw-r--r--pkg/tcpip/tests/integration/BUILD1
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go193
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go25
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go82
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go45
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go94
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go8
-rw-r--r--pkg/tcpip/transport/tcp/BUILD6
-rw-r--r--pkg/tcpip/transport/tcp/accept.go6
-rw-r--r--pkg/tcpip/transport/tcp/connect.go42
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go18
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go288
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go10
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go2
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go50
-rw-r--r--pkg/tcpip/transport/tcp/reno_recovery.go67
-rw-r--r--pkg/tcpip/transport/tcp/sack_recovery.go120
-rw-r--r--pkg/tcpip/transport/tcp/segment.go2
-rw-r--r--pkg/tcpip/transport/tcp/segment_unsafe.go3
-rw-r--r--pkg/tcpip/transport/tcp/snd.go321
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go42
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go162
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go14
-rw-r--r--pkg/tcpip/transport/udp/BUILD2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go375
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go9
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go326
-rw-r--r--pkg/test/criutil/criutil.go12
-rw-r--r--pkg/test/dockerutil/container.go6
-rw-r--r--pkg/test/dockerutil/exec.go5
-rw-r--r--pkg/test/testutil/testutil.go58
-rw-r--r--pkg/usermem/usermem.go2
-rw-r--r--pkg/waiter/waiter.go11
-rw-r--r--pkg/waiter/waiter_test.go20
-rw-r--r--runsc/boot/compat.go2
-rw-r--r--runsc/boot/controller.go72
-rw-r--r--runsc/boot/filter/config.go60
-rw-r--r--runsc/boot/fs.go18
-rw-r--r--runsc/boot/loader.go87
-rw-r--r--runsc/boot/network.go33
-rw-r--r--runsc/boot/vfs.go70
-rw-r--r--runsc/cgroup/cgroup.go32
-rw-r--r--runsc/cgroup/cgroup_test.go5
-rw-r--r--runsc/cli/main.go4
-rw-r--r--runsc/cmd/BUILD2
-rw-r--r--runsc/cmd/do.go2
-rw-r--r--runsc/cmd/exec.go2
-rw-r--r--runsc/cmd/symbolize.go91
-rw-r--r--runsc/console/console.go4
-rw-r--r--runsc/container/BUILD5
-rw-r--r--runsc/container/console_test.go205
-rw-r--r--runsc/container/container.go37
-rw-r--r--runsc/container/multi_container_test.go53
-rw-r--r--runsc/fsgofer/BUILD3
-rw-r--r--runsc/fsgofer/fsgofer.go79
-rw-r--r--runsc/fsgofer/fsgofer_test.go217
-rw-r--r--runsc/sandbox/network.go27
-rw-r--r--runsc/sandbox/sandbox.go30
-rw-r--r--test/benchmarks/BUILD11
-rw-r--r--test/benchmarks/base/BUILD9
-rw-r--r--test/benchmarks/database/BUILD13
-rw-r--r--test/benchmarks/database/database.go15
-rw-r--r--test/benchmarks/database/redis_test.go33
-rw-r--r--test/benchmarks/defs.bzl14
-rw-r--r--test/benchmarks/fs/BUILD6
-rw-r--r--test/benchmarks/fs/bazel_test.go18
-rw-r--r--test/benchmarks/harness/harness.go2
-rw-r--r--test/benchmarks/media/BUILD10
-rw-r--r--test/benchmarks/media/ffmpeg_test.go12
-rw-r--r--test/benchmarks/media/media.go15
-rw-r--r--test/benchmarks/ml/BUILD10
-rw-r--r--test/benchmarks/ml/ml.go15
-rw-r--r--test/benchmarks/ml/tensorflow_test.go12
-rw-r--r--test/benchmarks/network/BUILD74
-rw-r--r--test/benchmarks/network/httpd_test.go71
-rw-r--r--test/benchmarks/network/iperf_test.go27
-rw-r--r--test/benchmarks/network/network.go69
-rw-r--r--test/benchmarks/network/nginx_test.go75
-rw-r--r--test/benchmarks/network/node_test.go16
-rw-r--r--test/benchmarks/network/ruby_test.go16
-rw-r--r--test/benchmarks/network/static_server.go87
-rw-r--r--test/benchmarks/tools/iperf.go2
-rw-r--r--test/benchmarks/tools/redis.go6
-rw-r--r--test/cmd/test_app/fds.go5
-rw-r--r--test/e2e/integration_test.go49
-rw-r--r--test/fuse/BUILD5
-rw-r--r--test/fuse/linux/BUILD12
-rw-r--r--test/fuse/linux/mount_test.cc41
-rw-r--r--test/iptables/filter_output.go2
-rw-r--r--test/iptables/iptables_test.go4
-rw-r--r--test/packetdrill/BUILD9
-rw-r--r--test/packetdrill/defs.bzl6
-rwxr-xr-xtest/packetdrill/packetdrill_test.sh13
-rw-r--r--test/packetimpact/runner/BUILD1
-rw-r--r--test/packetimpact/runner/defs.bzl22
-rw-r--r--test/packetimpact/runner/dut.go447
-rw-r--r--test/packetimpact/runner/packetimpact_test.go2
-rw-r--r--test/packetimpact/testbench/BUILD1
-rw-r--r--test/packetimpact/testbench/connections.go135
-rw-r--r--test/packetimpact/testbench/dut.go36
-rw-r--r--test/packetimpact/testbench/layers.go49
-rw-r--r--test/packetimpact/testbench/rawsockets.go33
-rw-r--r--test/packetimpact/testbench/testbench.go147
-rw-r--r--test/packetimpact/tests/BUILD20
-rw-r--r--test/packetimpact/tests/fin_wait2_timeout_test.go5
-rw-r--r--test/packetimpact/tests/icmpv6_param_problem_test.go5
-rw-r--r--test/packetimpact/tests/ipv4_fragment_reassembly_test.go70
-rw-r--r--test/packetimpact/tests/ipv4_id_uniqueness_test.go6
-rw-r--r--test/packetimpact/tests/ipv6_fragment_icmp_error_test.go27
-rw-r--r--test/packetimpact/tests/ipv6_fragment_reassembly_test.go62
-rw-r--r--test/packetimpact/tests/ipv6_unknown_options_action_test.go5
-rw-r--r--test/packetimpact/tests/tcp_cork_mss_test.go5
-rw-r--r--test/packetimpact/tests/tcp_handshake_window_size_test.go5
-rw-r--r--test/packetimpact/tests/tcp_linger_test.go6
-rw-r--r--test/packetimpact/tests/tcp_network_unreachable_test.go19
-rw-r--r--test/packetimpact/tests/tcp_noaccept_close_rst_test.go5
-rw-r--r--test/packetimpact/tests/tcp_outside_the_window_test.go5
-rw-r--r--test/packetimpact/tests/tcp_paws_mechanism_test.go5
-rw-r--r--test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go8
-rw-r--r--test/packetimpact/tests/tcp_queue_send_in_syn_sent_test.go8
-rw-r--r--test/packetimpact/tests/tcp_rcv_buf_space_test.go5
-rw-r--r--test/packetimpact/tests/tcp_reordering_test.go45
-rw-r--r--test/packetimpact/tests/tcp_retransmits_test.go5
-rw-r--r--test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go8
-rw-r--r--test/packetimpact/tests/tcp_synrcvd_reset_test.go5
-rw-r--r--test/packetimpact/tests/tcp_synsent_reset_test.go39
-rw-r--r--test/packetimpact/tests/tcp_timewait_reset_test.go5
-rw-r--r--test/packetimpact/tests/tcp_unacc_seq_ack_test.go11
-rw-r--r--test/packetimpact/tests/tcp_user_timeout_test.go5
-rw-r--r--test/packetimpact/tests/tcp_window_shrink_test.go5
-rw-r--r--test/packetimpact/tests/tcp_zero_receive_window_test.go125
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go5
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_test.go5
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go5
-rw-r--r--test/packetimpact/tests/udp_any_addr_recv_unicast_test.go7
-rw-r--r--test/packetimpact/tests/udp_discard_mcast_source_addr_test.go12
-rw-r--r--test/packetimpact/tests/udp_icmp_error_propagation_test.go10
-rw-r--r--test/packetimpact/tests/udp_recv_mcast_bcast_test.go31
-rw-r--r--test/packetimpact/tests/udp_send_recv_dgram_test.go17
-rw-r--r--test/perf/BUILD3
-rw-r--r--test/root/BUILD14
-rw-r--r--test/root/crictl_test.go4
-rw-r--r--test/runner/defs.bzl45
-rw-r--r--test/runtimes/BUILD12
-rw-r--r--test/runtimes/runner/lib/lib.go26
-rw-r--r--test/runtimes/runner/main.go14
-rw-r--r--test/syscalls/BUILD67
-rw-r--r--test/syscalls/linux/BUILD48
-rw-r--r--test/syscalls/linux/chown.cc13
-rw-r--r--test/syscalls/linux/exceptions.cc10
-rw-r--r--test/syscalls/linux/fcntl.cc486
-rw-r--r--test/syscalls/linux/kill.cc6
-rw-r--r--test/syscalls/linux/open_create.cc8
-rw-r--r--test/syscalls/linux/proc.cc29
-rw-r--r--test/syscalls/linux/raw_socket.cc45
-rw-r--r--test/syscalls/linux/semaphore.cc162
-rw-r--r--test/syscalls/linux/signalfd.cc2
-rw-r--r--test/syscalls/linux/socket_generic.cc67
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc60
-rw-r--r--test/syscalls/linux/socket_ip_udp_generic.cc55
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound.cc84
-rw-r--r--test/syscalls/linux/socket_ipv6_udp_unbound.cc131
-rw-r--r--test/syscalls/linux/socket_ipv6_udp_unbound.h29
-rw-r--r--test/syscalls/linux/socket_ipv6_udp_unbound_loopback.cc32
-rw-r--r--test/syscalls/linux/tcp_socket.cc160
-rw-r--r--test/syscalls/linux/udp_socket.cc63
-rw-r--r--tools/bazel.mk260
-rw-r--r--tools/bazel_gazelle.patch24
-rw-r--r--tools/bazeldefs/BUILD18
-rw-r--r--tools/bazeldefs/defs.bzl43
-rw-r--r--tools/bazeldefs/go.bzl4
-rw-r--r--tools/bigquery/bigquery.go7
-rw-r--r--tools/checkescape/test1/test1.go13
-rw-r--r--tools/defs.bzl6
-rwxr-xr-xtools/go_branch.sh4
-rw-r--r--tools/go_generics/defs.bzl4
-rw-r--r--tools/go_generics/generics.go2
-rw-r--r--tools/go_marshal/gomarshal/generator.go50
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go4
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go4
-rw-r--r--tools/images.mk169
-rw-r--r--tools/installers/BUILD10
-rwxr-xr-xtools/installers/containerd.sh17
-rw-r--r--tools/nogo/filter/main.go9
-rw-r--r--tools/parsers/go_parser_test.go22
-rw-r--r--tools/vm/BUILD63
-rw-r--r--tools/vm/README.md48
-rwxr-xr-xtools/vm/build.sh117
-rw-r--r--tools/vm/defs.bzl202
-rwxr-xr-xtools/vm/execute.sh160
-rwxr-xr-xtools/vm/ubuntu1604/10_core.sh43
-rwxr-xr-xtools/vm/ubuntu1604/15_gcloud.sh50
-rwxr-xr-xtools/vm/ubuntu1604/20_bazel.sh38
-rwxr-xr-xtools/vm/ubuntu1604/30_docker.sh64
-rwxr-xr-xtools/vm/ubuntu1604/40_kokoro.sh72
-rw-r--r--tools/vm/ubuntu1604/BUILD7
-rw-r--r--tools/vm/ubuntu1804/BUILD7
-rw-r--r--website/blog/README.md62
-rw-r--r--website/blog/index.html5
507 files changed, 19752 insertions, 8520 deletions
diff --git a/.bazelrc b/.bazelrc
index e2848ef07..47c26843d 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -42,5 +42,5 @@ build:remote --extra_toolchains=//tools/bazeldefs:cc-toolchain-clang-x86_64-defa
build:remote --extra_execution_platforms=//tools/bazeldefs:rbe_ubuntu1604
build:remote --platforms=//tools/bazeldefs:rbe_ubuntu1604
build:remote --crosstool_top=@rbe_default//cc:toolchain
-build:remote --jobs=300
+build:remote --jobs=100
build:remote --remote_timeout=3600
diff --git a/.buildkite/hooks/post-command b/.buildkite/hooks/post-command
new file mode 100644
index 000000000..ce3111f3c
--- /dev/null
+++ b/.buildkite/hooks/post-command
@@ -0,0 +1,24 @@
+# Upload test logs on failure, if there are any.
+if [[ "${BUILDKITE_COMMAND_EXIT_STATUS}" -ne "0" ]]; then
+ declare log_count=0
+ for log in $(make testlogs 2>/dev/null | sort | uniq); do
+ buildkite-agent artifact upload "${log}"
+ log_count=$((${log_count}+1))
+ # N.B. If *all* tests fail due to some common cause, then we will
+ # end up spending way too much time uploading logs. Instead, we just
+ # upload the first 100 and stop. That is hopefully enough to debug.
+ if [[ "${log_count}" -ge 100 ]]; then
+ echo "Only uploaded first 100 failures; skipping the rest."
+ break
+ fi
+ done
+ # Attempt to clear the cache and shut down.
+ make clean || echo "make clean failed with code $?"
+ make bazel-shutdown || echo "make bazel-shutdown failed with code $?"
+fi
+
+# Kill any running containers (clear state).
+CONTAINERS="$(docker ps -q)"
+if ! [[ -z "${CONTAINERS}" ]]; then
+ docker container kill ${CONTAINERS} 2>/dev/null || true
+fi \ No newline at end of file
diff --git a/.buildkite/hooks/pre-command b/.buildkite/hooks/pre-command
new file mode 100644
index 000000000..7d277202b
--- /dev/null
+++ b/.buildkite/hooks/pre-command
@@ -0,0 +1,11 @@
+# Setup for parallelization with PARTITION and TOTAL_PARTITIONS.
+export PARTITION=${BUILDKITE_PARALLEL_JOB:-0}
+PARTITION=$((${PARTITION}+1)) # 1-indexed, but PARALLEL_JOB is 0-indexed.
+export TOTAL_PARTITIONS=${BUILDKITE_PARALLEL_JOB_COUNT:-1}
+
+# Ensure Docker has experimental enabled.
+EXPERIMENTAL=$(sudo docker version --format='{{.Server.Experimental}}')
+if test "${EXPERIMENTAL}" != "true"; then
+ make sudo TARGETS=//runsc:runsc ARGS="install --experimental=true"
+ sudo systemctl restart docker
+fi
diff --git a/.buildkite/pipeline.yaml b/.buildkite/pipeline.yaml
new file mode 100644
index 000000000..79a80d9c8
--- /dev/null
+++ b/.buildkite/pipeline.yaml
@@ -0,0 +1,107 @@
+_templates:
+ common: &common
+ timeout_in_minutes: 30
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 10
+ - exit_status: "*"
+ limit: 2
+
+steps:
+ # Run basic smoke tests before preceding to other tests.
+ - <<: *common
+ label: ":fire: Smoke tests"
+ command: make smoke-tests
+ - wait
+
+ # Check that the Go branch builds.
+ - <<: *common
+ label: ":golang: Go branch"
+ commands:
+ - make go
+ - git checkout go && git clean -f
+ - go build ./...
+
+ # Release workflow.
+ - <<: *common
+ label: ":ship: Release tests"
+ commands: make release
+
+ # Basic unit tests.
+ - <<: *common
+ label: ":test_tube: Unit tests"
+ command: make unit-tests
+
+ # All system call tests.
+ - <<: *common
+ label: ":toolbox: System call tests"
+ command: make syscall-tests
+ parallelism: 20
+
+ # Integration tests.
+ - <<: *common
+ label: ":parachute: FUSE tests"
+ command: make fuse-tests
+ - <<: *common
+ label: ":docker: Docker tests"
+ command: make docker-tests
+ - <<: *common
+ label: ":goggles: Overlay tests"
+ command: make overlay-tests
+ - <<: *common
+ label: ":safety_pin: Host network tests"
+ command: make hostnet-tests
+ - <<: *common
+ label: ":satellite: SWGSO tests"
+ command: make swgso-tests
+ - <<: *common
+ label: ":coffee: Do tests"
+ command: make do-tests
+ - <<: *common
+ label: ":person_in_lotus_position: KVM tests"
+ command: make kvm-tests
+ - <<: *common
+ label: ":docker: Containerd 1.3.9 tests"
+ command: make containerd-test-1.3.9
+ - <<: *common
+ label: ":docker: Containerd 1.4.3 tests"
+ command: make containerd-test-1.4.3
+
+ # Check the website builds.
+ - <<: *common
+ label: ":earth_americas: Website tests"
+ command: make website-build
+
+ # Networking tests.
+ - <<: *common
+ label: ":table_tennis_paddle_and_ball: IPTables tests"
+ command: make iptables-tests
+ - <<: *common
+ label: ":construction_worker: Packetdrill tests"
+ command: make packetdrill-tests
+ - <<: *common
+ label: ":hammer: Packetimpact tests"
+ command: make packetimpact-tests
+
+ # Runtime tests.
+ - <<: *common
+ label: ":php: PHP runtime tests"
+ command: make php7.3.6-runtime-tests_vfs2
+ parallelism: 10
+ - <<: *common
+ label: ":java: Java runtime tests"
+ command: make java11-runtime-tests_vfs2
+ parallelism: 40
+ - <<: *common
+ label: ":golang: Go runtime tests"
+ command: make go1.12-runtime-tests_vfs2
+ parallelism: 10
+ - <<: *common
+ label: ":node: NodeJS runtime tests"
+ command: make nodejs12.4.0-runtime-tests_vfs2
+ parallelism: 10
+ - <<: *common
+ label: ":python: Python runtime tests"
+ command: make python3.7.3-runtime-tests_vfs2
+ parallelism: 10
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
deleted file mode 100644
index 264b4e9fa..000000000
--- a/.github/pull_request_template.md
+++ /dev/null
@@ -1,5 +0,0 @@
-* [ ] Have you followed the guidelines in [CONTRIBUTING.md](../blob/master/CONTRIBUTING.md)?
-* [ ] Have you formatted and linted your code?
-* [ ] Have you added relevant tests?
-* [ ] Have you added appropriate Fixes & Updates references?
-* [ ] If yes, please erase all these lines!
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index e28e46352..3be10b9bb 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -1,13 +1,18 @@
+# This workflow builds the source code, extracts nogo annotations and
+# posts them to GitHub, if applicable. This leverages the fact that the
+# workflow token has appropriate permissions to do so, and attempts to
+# leverage the GitHub workflow caches.
+#
+# This workflow also generates the build badge that is referred to by
+# the main README.
name: "Build"
on:
push:
branches:
- master
- - feature/**
pull_request:
branches:
- - master
- - feature/**
+ - "**"
jobs:
default:
@@ -22,7 +27,7 @@ jobs:
${{ runner.os }}-bazel-
- run: make
- run: make build OPTIONS="--build_tag_filters nogo" TARGETS="//..."
- - run: make run TARGETS="//tools/github" ARGS="-path=bazel-bin/ nogo"
+ - run: make run TARGETS="//tools/github" ARGS="-path=bazel-bin/ -path=bazel-out/ nogo"
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_REPOSITORY: ${{ github.repository }}
diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml
index 3a6a592d1..c87ab22ef 100644
--- a/.github/workflows/go.yml
+++ b/.github/workflows/go.yml
@@ -1,12 +1,12 @@
+# This workflow generates the Go branch. Note that this does not test the Go
+# branch, as this is rolled into the main continuous integration pipeline. This
+# workflow simply generates and pushes the branch, as long as appropriate
+# permissions are available.
name: "Go"
on:
push:
branches:
- master
- pull_request:
- branches:
- - master
- - feature/**
jobs:
generate:
@@ -19,20 +19,13 @@ jobs:
else
echo ::set-output name=has_token::false
fi
- - run: |
- jq -nc '{"state": "pending", "context": "go tests"}' | \
- curl -sL -X POST -d @- \
- -H "Content-Type: application/json" \
- -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
- "${{ github.event.pull_request.statuses_url }}"
- if: github.event_name == 'pull_request'
- uses: actions/checkout@v2
- if: github.event_name == 'push' && steps.setup.outputs.has_token == 'true'
+ if: steps.setup.outputs.has_token == 'true'
with:
fetch-depth: 0
token: '${{ secrets.GO_TOKEN }}'
- uses: actions/checkout@v2
- if: github.event_name == 'pull_request' || steps.setup.outputs.has_token != 'true'
+ if: steps.setup.outputs.has_token != 'true'
with:
fetch-depth: 0
- uses: actions/setup-go@v2
@@ -50,32 +43,7 @@ jobs:
key: ${{ runner.os }}-bazel-${{ hashFiles('WORKSPACE') }}
restore-keys: |
${{ runner.os }}-bazel-
- # Create gopath to merge the changes. The first execution will create
- # symlinks to the cache, e.g. bazel-bin. Once the cache is setup, delete
- # old gopath files that may exist from previous runs (and could contain
- # files that are now deleted). Then run gopath again for good.
- - run: |
- make build TARGETS="//:gopath"
- rm -rf bazel-bin/gopath
- make build TARGETS="//:gopath"
- - run: tools/go_branch.sh
- - run: git checkout go && git clean -f
- - run: go build ./...
- - if: github.event_name == 'push'
+ - run: make go
run: |
git remote add upstream "https://github.com/${{ github.repository }}"
git push upstream go:go
- - if: ${{ success() && github.event_name == 'pull_request' }}
- run: |
- jq -nc '{"state": "success", "context": "go tests"}' | \
- curl -sL -X POST -d @- \
- -H "Content-Type: application/json" \
- -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
- "${{ github.event.pull_request.statuses_url }}"
- - if: ${{ failure() && github.event_name == 'pull_request' }}
- run: |
- jq -nc '{"state": "failure", "context": "go tests"}' | \
- curl -sL -X POST -d @- \
- -H "Content-Type: application/json" \
- -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
- "${{ github.event.pull_request.statuses_url }}"
diff --git a/.github/workflows/issue_reviver.yml b/.github/workflows/issue_reviver.yml
index c53185620..f03b814c9 100644
--- a/.github/workflows/issue_reviver.yml
+++ b/.github/workflows/issue_reviver.yml
@@ -1,3 +1,5 @@
+# This workflow revives issues that are still referenced in the code, and may
+# have been accidentally closed or marked stale.
name: "Issue reviver"
on:
schedule:
diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml
index c09f7eb36..a53fdb3e9 100644
--- a/.github/workflows/labeler.yml
+++ b/.github/workflows/labeler.yml
@@ -1,3 +1,4 @@
+# Labeler labels incoming pull requests.
name: "Labeler"
on:
- pull_request
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
index 0b31fecf5..be10c5bc4 100644
--- a/.github/workflows/stale.yml
+++ b/.github/workflows/stale.yml
@@ -1,3 +1,5 @@
+# The stale workflow closes stale issues and pull requests, unless specific
+# tags have been applied in order to keep them open.
name: "Close stale issues"
on:
schedule:
diff --git a/.gitignore b/.gitignore
index a56f6ebcd..a2a3fd508 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,7 @@
# Generated bazel symlinks.
-/bazel-* \ No newline at end of file
+/bazel-*
+# Generated build event file.
+/.build_events.json
+# Generated repository.
+/repo
+/repo.key \ No newline at end of file
diff --git a/BUILD b/BUILD
index 0791f9fb4..7cabede3c 100644
--- a/BUILD
+++ b/BUILD
@@ -67,12 +67,15 @@ build_test(
"//test/benchmarks/base:startup_test",
"//test/benchmarks/base:size_test",
"//test/benchmarks/base:sysbench_test",
- "//test/benchmarks/database:database_test",
+ "//test/benchmarks/database:redis_test",
"//test/benchmarks/fs:bazel_test",
"//test/benchmarks/fs:fio_test",
- "//test/benchmarks/media:media_test",
- "//test/benchmarks/ml:ml_test",
- "//test/benchmarks/network:network_test",
+ "//test/benchmarks/media:ffmpeg_test",
+ "//test/benchmarks/ml:tensorflow_test",
+ "//test/benchmarks/network:httpd_test",
+ "//test/benchmarks/network:nginx_test",
+ "//test/benchmarks/network:node_test",
+ "//test/benchmarks/network:ruby_test",
],
)
diff --git a/Makefile b/Makefile
index 79d8fd791..8565102b3 100644
--- a/Makefile
+++ b/Makefile
@@ -14,19 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# Helpful pretty-printer.
-MAKEBANNER := \033[1;34mmake\033[0m
-submake = echo -e '$(MAKEBANNER) $1' >&2; $(MAKE) $1
-
-# Described below.
-OPTIONS :=
-STARTUP_OPTIONS :=
-TARGETS := //runsc
-ARGS :=
-
default: runsc
.PHONY: default
+# Header for debugging (used by other macros).
+header = echo --- $(1) >&2
+
+# Make hacks.
+EMPTY :=
+SPACE := $(EMPTY) $(EMPTY)
+
## usage: make <target>
## or
## make <build|test|copy|run|sudo> STARTUP_OPTIONS="..." OPTIONS="..." TARGETS="..." ARGS="..."
@@ -38,7 +35,6 @@ default: runsc
## requirements.
##
## There are common arguments that may be passed to targets. These are:
-## STARTUP_OPTIONS - Bazel startup options.
## OPTIONS - Build or test options.
## TARGETS - The bazel targets.
## ARGS - Arguments for run or sudo.
@@ -49,7 +45,7 @@ default: runsc
## make build OPTIONS="" TARGETS="//runsc"'
##
help: ## Shows all targets and help from the Makefile (this message).
- @grep --no-filename -E '^([a-z.A-Z_-]+:.*?|)##' $(MAKEFILE_LIST) | \
+ @grep --no-filename -E '^([a-z.A-Z_%-]+:.*?|)##' $(MAKEFILE_LIST) | \
awk 'BEGIN {FS = "(:.*?|)## ?"}; { \
if (length($$1) > 0) { \
printf " \033[36m%-20s\033[0m %s\n", $$1, $$2; \
@@ -57,17 +53,34 @@ help: ## Shows all targets and help from the Makefile (this message).
printf "%s\n", $$2; \
} \
}'
+
build: ## Builds the given $(TARGETS) with the given $(OPTIONS). E.g. make build TARGETS=runsc
-test: ## Tests the given $(TARGETS) with the given $(OPTIONS). E.g. make test TARGETS=pkg/buffer:buffer_test
-copy: ## Copies the given $(TARGETS) to the given $(DESTINATION). E.g. make copy TARGETS=runsc DESTINATION=/tmp
-run: ## Runs the given $(TARGETS), built with $(OPTIONS), using $(ARGS). E.g. make run TARGETS=runsc ARGS=-version
-sudo: ## Runs the given $(TARGETS) as per run, but using "sudo -E". E.g. make sudo TARGETS=test/root:root_test ARGS=-test.v
-.PHONY: help build test copy run sudo
+ @$(call build,$(OPTIONS) $(TARGETS))
+.PHONY: build
+
+test: ## Tests the given $(TARGETS) with the given $(OPTIONS). E.g. make test TARGETS=pkg/buffer:buffer_test
+ @$(call build,$(OPTIONS) $(TARGETS))
+.PHONY: test
+
+copy: ## Copies the given $(TARGETS) to the given $(DESTINATION). E.g. make copy TARGETS=runsc DESTINATION=/tmp
+ @$(call copy,$(TARGETS),$(DESTINATION))
+.PHONY: copy
+
+run: ## Runs the given $(TARGETS), built with $(OPTIONS), using $(ARGS). E.g. make run TARGETS=runsc ARGS=-version
+ @$(call run,$(TARGETS),$(ARGS))
+.PHONY: run
+
+sudo: ## Runs the given $(TARGETS) as per run, but using "sudo -E". E.g. make sudo TARGETS=test/root:root_test ARGS=-test.v
+ @$(call sudo,$(TARGETS),$(ARGS))
+.PHONY: sudo
+
+# Load image helpers.
+include tools/images.mk
# Load all bazel wrappers.
#
# This file should define the basic "build", "test", "run" and "sudo" rules, in
-# addition to the $(BRANCH_NAME) variable.
+# addition to the $(BRANCH_NAME) and $(BUILD_ROOTS) variables.
ifneq (,$(wildcard tools/google.mk))
include tools/google.mk
else
@@ -75,32 +88,74 @@ include tools/bazel.mk
endif
##
-## Docker image targets.
-##
-## Images used by the tests must also be built and available locally.
-## The canonical test targets defined below will automatically load
-## relevant images. These can be loaded or built manually via these
-## targets.
+## Development helpers and tooling.
##
-## (*) Note that you may provide an ARCH parameter in order to build
-## and load images from an alternate archiecture (using qemu). When
-## bazel is run as a server, this has the effect of running an full
-## cross-architecture chain, and can produce cross-compiled binaries.
+## These targets faciliate local development by automatically
+## installing and configuring a runtime. Several variables may
+## be used here to tweak the installation:
+## RUNTIME - The name of the installed runtime (default: branch).
+## RUNTIME_DIR - Where the runtime will be installed (default: temporary directory with the $RUNTIME).
+## RUNTIME_BIN - The runtime binary (default: $RUNTIME_DIR/runsc).
+## RUNTIME_LOG_DIR - The logs directory (default: $RUNTIME_DIR/logs).
+## RUNTIME_LOGS - The log pattern (default: $RUNTIME_LOG_DIR/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%).
##
-define images
-$(1)-%: ## Image tool: $(1) a given image (also may use 'all-images').
- @$(call submake,-C images $$@)
-endef
-rebuild-...: ## Rebuild the given image. Also may use 'rebuild-all-images'.
-$(eval $(call images,rebuild))
-push-...: ## Push the given image. Also may use 'push-all-images'.
-$(eval $(call images,push))
-pull-...: ## Pull the given image. Also may use 'pull-all-images'.
-$(eval $(call images,pull))
-load-...: ## Load (pull or rebuild) the given image. Also may use 'load-all-images'.
-$(eval $(call images,load))
-list-images: ## List all available images.
- @$(call submake, -C images $$@)
+ifeq (,$(BRANCH_NAME))
+RUNTIME := runsc
+RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(RUNTIME)
+else
+RUNTIME := $(BRANCH_NAME)
+RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(RUNTIME)
+endif
+RUNTIME_BIN := $(RUNTIME_DIR)/runsc
+RUNTIME_LOG_DIR := $(RUNTIME_DIR)/logs
+RUNTIME_LOGS := $(RUNTIME_LOG_DIR)/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%
+
+$(RUNTIME_BIN): # See below.
+ @mkdir -p "$(RUNTIME_DIR)"
+ @$(call copy,//runsc,$(RUNTIME_BIN))
+.PHONY: $(RUNTIME_BIN) # Real file, but force rebuild.
+
+# Configure helpers for below.
+configure_noreload = \
+ $(call header,CONFIGURE $(1) → $(RUNTIME_BIN) $(2)); \
+ sudo $(RUNTIME_BIN) install --experimental=true --runtime="$(1)" -- --debug-log "$(RUNTIME_LOGS)" $(2) && \
+ sudo rm -rf "$(RUNTIME_LOG_DIR)" && mkdir -p "$(RUNTIME_LOG_DIR)"
+reload_docker = \
+ sudo systemctl reload docker && \
+ if test -f /etc/docker/daemon.json; then \
+ sudo chmod 0755 /etc/docker && \
+ sudo chmod 0644 /etc/docker/daemon.json; \
+ fi
+configure = $(call configure_noreload,$(1),$(2)) && $(reload_docker)
+
+# Helpers for above. Requires $(RUNTIME_BIN) dependency.
+install_runtime = $(call configure,$(RUNTIME),$(1) --TESTONLY-test-name-env=RUNSC_TEST_NAME)
+test_runtime = $(call test,--test_arg=--runtime=$(RUNTIME) $(PARTITIONS) $(1))
+
+refresh: $(RUNTIME_BIN) ## Updates the runtime binary.
+.PHONY: refresh
+
+dev: $(RUNTIME_BIN) ## Installs a set of local runtimes. Requires sudo.
+ @$(call configure_noreload,$(RUNTIME),--net-raw)
+ @$(call configure_noreload,$(RUNTIME)-d,--net-raw --debug --strace --log-packets)
+ @$(call configure_noreload,$(RUNTIME)-p,--net-raw --profile)
+ @$(call configure_noreload,$(RUNTIME)-vfs2-d,--net-raw --debug --strace --log-packets --vfs2)
+ @$(call reload_docker)
+.PHONY: dev
+
+nogo: ## Surfaces all nogo findings.
+ @$(call build,--build_tag_filters nogo //...)
+ @$(call run,//tools/github $(foreach dir,$(BUILD_ROOTS),-path=$(CURDIR)/$(dir)) -dry-run nogo)
+.PHONY: nogo
+
+go: ## Builds the Go branch.
+ @$(call clean)
+ @$(call build,//:gopath)
+ @tools/go_branch.sh
+
+gazelle: ## Runs gazelle to update WORKSPACE.
+ @$(call run,//:gazelle update-repos -from_file=go.mod -prune)
+.PHONY: gazelle
##
## Canonical build and test targets.
@@ -109,24 +164,32 @@ list-images: ## List all available images.
## convenient entrypoints for testing changes. If you're adding a
## new subsystem or workflow, consider adding a new target here.
##
+## Some targets support a PARTITION (1-indexed) and TOTAL_PARTITIONS
+## environment variables for high-level test sharding. Unlike most
+## other variables, these are sourced from the environment.
+##
+PARTITION ?= 1
+TOTAL_PARTITIONS ?= 1
+PARTITIONS := --test_arg=--partition=$(PARTITION) --test_arg=--total_partitions=$(TOTAL_PARTITIONS)
+
runsc: ## Builds the runsc binary.
- @$(call submake,build OPTIONS="-c opt" TARGETS="//runsc")
+ @$(call build,-c opt //runsc)
.PHONY: runsc
debian: ## Builds the debian packages.
- @$(call submake,build OPTIONS="-c opt" TARGETS="//debian:debian")
+ @$(call build,-c opt //debian:debian)
.PHONY: debian
smoke-tests: ## Runs a simple smoke test after build runsc.
- @$(call submake,run DOCKER_PRIVILEGED="" ARGS="--alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do true")
+ @$(call run,//runsc,--alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do true)
.PHONY: smoke-tests
fuse-tests:
- @$(call submake,test OPTIONS="--test_tag_filters fuse" TARGETS="test/fuse/...")
+ @$(call test,--test_tag_filters=fuse $(PARTITIONS) test/fuse/...)
.PHONY: fuse-tests
unit-tests: ## Local package unit tests in pkg/..., runsc/, tools/.., etc.
- @$(call submake,test TARGETS="pkg/... runsc/... tools/...")
+ @$(call test,pkg/... runsc/... tools/...)
.PHONY: unit-tests
tests: ## Runs all unit tests and syscall tests.
@@ -135,120 +198,99 @@ tests: unit-tests syscall-tests
integration-tests: ## Run all standard integration tests.
integration-tests: docker-tests overlay-tests hostnet-tests swgso-tests
-integration-tests: do-tests kvm-tests containerd-test-1.3.4
+integration-tests: do-tests kvm-tests containerd-test-1.3.9
.PHONY: integration-tests
network-tests: ## Run all networking integration tests.
network-tests: iptables-tests packetdrill-tests packetimpact-tests
.PHONY: network-tests
-# Standard integration targets.
-INTEGRATION_TARGETS := //test/image:image_test //test/e2e:integration_test
-
syscall-%-tests:
- @$(call submake,test OPTIONS="--test_tag_filters runsc_$*" TARGETS="test/syscalls/...")
+ @$(call test,--test_tag_filters=runsc_$* $(PARTITIONS) test/syscalls/...)
syscall-native-tests:
- @$(call submake,test OPTIONS="--test_tag_filters native" TARGETS="test/syscalls/...")
+ @$(call test,--test_tag_filters=native $(PARTITIONS) test/syscalls/...)
.PHONY: syscall-native-tests
syscall-tests: ## Run all system call tests.
- @$(call submake,test TARGETS="test/syscalls/...")
+ @$(call test,$(PARTITIONS) test/syscalls/...)
-%-runtime-tests: load-runtimes_%
-ifeq ($(PARTITION),)
- @$(eval PARTITION := 1)
-endif
-ifeq ($(TOTAL_PARTITIONS),)
- @$(eval TOTAL_PARTITIONS := 1)
-endif
- @$(call submake,install-runtime)
- @$(call submake,test-runtime OPTIONS="--test_timeout=10800 --test_arg=--partition=$(PARTITION) --test_arg=--total_partitions=$(TOTAL_PARTITIONS)" TARGETS="//test/runtimes:$*")
+%-runtime-tests: load-runtimes_% $(RUNTIME_BIN)
+ @$(call install_runtime,) # Ensure flags are cleared.
+ @$(call test_runtime,--test_timeout=10800 //test/runtimes:$*)
-%-runtime-tests_vfs2: load-runtimes_%
-ifeq ($(PARTITION),)
- @$(eval PARTITION := 1)
-endif
-ifeq ($(TOTAL_PARTITIONS),)
- @$(eval TOTAL_PARTITIONS := 1)
-endif
- @$(call submake,install-runtime RUNTIME="vfs2" ARGS="--vfs2")
- @$(call submake,test-runtime RUNTIME="vfs2" OPTIONS="--test_timeout=10800 --test_arg=--partition=$(PARTITION) --test_arg=--total_partitions=$(TOTAL_PARTITIONS)" TARGETS="//test/runtimes:$*")
+%-runtime-tests_vfs2: load-runtimes_% $(RUNTIME_BIN)
+ @$(call install_runtime,--vfs2)
+ @$(call test_runtime,--test_timeout=10800 //test/runtimes:$*)
-do-tests: runsc
- @$(call submake,run TARGETS="//runsc" ARGS="--rootless do true")
- @$(call submake,run TARGETS="//runsc" ARGS="--rootless -network=none do true")
- @$(call submake,sudo TARGETS="//runsc" ARGS="do true")
+do-tests:
+ @$(call run,//runsc,--rootless do true)
+ @$(call run,//runsc,--rootless -network=none do true)
+ @$(call sudo,//runsc,do true)
.PHONY: do-tests
simple-tests: unit-tests # Compatibility target.
.PHONY: simple-tests
-docker-tests: load-basic-images
- @$(call submake,install-runtime RUNTIME="vfs1")
- @$(call submake,test-runtime RUNTIME="vfs1" TARGETS="$(INTEGRATION_TARGETS)")
- @$(call submake,install-runtime RUNTIME="vfs2" ARGS="--vfs2")
- @$(call submake,test-runtime RUNTIME="vfs2" TARGETS="$(INTEGRATION_TARGETS)")
+# Standard integration targets.
+INTEGRATION_TARGETS := //test/image:image_test //test/e2e:integration_test
+
+docker-tests: load-basic $(RUNTIME_BIN)
+ @$(call install_runtime,) # Clear flags.
+ @$(call test_runtime,$(INTEGRATION_TARGETS))
+ @$(call install_runtime,--vfs2)
+ @$(call test_runtime,$(INTEGRATION_TARGETS))
.PHONY: docker-tests
-overlay-tests: load-basic-images
- @$(call submake,install-runtime RUNTIME="overlay" ARGS="--overlay")
- @$(call submake,test-runtime RUNTIME="overlay" TARGETS="$(INTEGRATION_TARGETS)")
+overlay-tests: load-basic $(RUNTIME_BIN)
+ @$(call install_runtime,--overlay)
+ @$(call test_runtime,$(INTEGRATION_TARGETS))
.PHONY: overlay-tests
-swgso-tests: load-basic-images
- @$(call submake,install-runtime RUNTIME="swgso" ARGS="--software-gso=true --gso=false")
- @$(call submake,test-runtime RUNTIME="swgso" TARGETS="$(INTEGRATION_TARGETS)")
+swgso-tests: load-basic $(RUNTIME_BIN)
+ @$(call install_runtime,--software-gso=true --gso=false)
+ @$(call test_runtime,$(INTEGRATION_TARGETS))
.PHONY: swgso-tests
-hostnet-tests: load-basic-images
- @$(call submake,install-runtime RUNTIME="hostnet" ARGS="--network=host")
- @$(call submake,test-runtime RUNTIME="hostnet" OPTIONS="--test_arg=-checkpoint=false" TARGETS="$(INTEGRATION_TARGETS)")
+hostnet-tests: load-basic $(RUNTIME_BIN)
+ @$(call install_runtime,--network=host)
+ @$(call test_runtime,--test_arg=-checkpoint=false --test_arg=-hostnet=true $(INTEGRATION_TARGETS))
.PHONY: hostnet-tests
-kvm-tests: load-basic-images
+kvm-tests: load-basic $(RUNTIME_BIN)
@(lsmod | grep -E '^(kvm_intel|kvm_amd)') || sudo modprobe kvm
- @if ! [[ -w /dev/kvm ]]; then sudo chmod a+rw /dev/kvm; fi
- @$(call submake,test TARGETS="//pkg/sentry/platform/kvm:kvm_test")
- @$(call submake,install-runtime RUNTIME="kvm" ARGS="--platform=kvm")
- @$(call submake,test-runtime RUNTIME="kvm" TARGETS="$(INTEGRATION_TARGETS)")
+ @if ! test -w /dev/kvm; then sudo chmod a+rw /dev/kvm; fi
+ @$(call test,//pkg/sentry/platform/kvm:kvm_test)
+ @$(call install_runtime,--platform=kvm)
+ @$(call test_runtime,$(INTEGRATION_TARGETS))
.PHONY: kvm-tests
-iptables-tests: load-iptables
+iptables-tests: load-iptables $(RUNTIME_BIN)
@sudo modprobe iptable_filter
@sudo modprobe ip6table_filter
- @$(call submake,test-runtime RUNTIME="runc" TARGETS="//test/iptables:iptables_test")
- @$(call submake,install-runtime RUNTIME="iptables" ARGS="--net-raw")
- @$(call submake,test-runtime RUNTIME="iptables" TARGETS="//test/iptables:iptables_test")
+ @$(call test,--test_arg=-runtime=runc $(PARTITIONS) //test/iptables:iptables_test)
+ @$(call install_runtime,--net-raw)
+ @$(call test_runtime,//test/iptables:iptables_test)
.PHONY: iptables-tests
-# Run the iptables tests with runsc only. Useful for developing to skip runc
-# testing.
-iptables-runsc-tests: load-iptables
- @sudo modprobe iptable_filter
- @sudo modprobe ip6table_filter
- @$(call submake,install-runtime RUNTIME="iptables" ARGS="--net-raw")
- @$(call submake,test-runtime RUNTIME="iptables" TARGETS="//test/iptables:iptables_test")
-.PHONY: iptables-runsc-tests
-
-packetdrill-tests: load-packetdrill
- @$(call submake,install-runtime RUNTIME="packetdrill")
- @$(call submake,test-runtime RUNTIME="packetdrill" TARGETS="$(shell $(MAKE) query TARGETS='attr(tags, packetdrill, tests(//...))')")
+packetdrill-tests: load-packetdrill $(RUNTIME_BIN)
+ @$(call install_runtime,) # Clear flags.
+ @$(call test_runtime,//test/packetdrill:all_tests)
.PHONY: packetdrill-tests
-packetimpact-tests: load-packetimpact
+packetimpact-tests: load-packetimpact $(RUNTIME_BIN)
@sudo modprobe iptable_filter
@sudo modprobe ip6table_filter
- @$(call submake,install-runtime RUNTIME="packetimpact")
- @$(call submake,test-runtime OPTIONS="--jobs=HOST_CPUS*3 --local_test_jobs=HOST_CPUS*3" RUNTIME="packetimpact" TARGETS="$(shell $(MAKE) query TARGETS='attr(tags, packetimpact, tests(//...))')")
+ @$(call install_runtime,) # Clear flags.
+ @$(call test_runtime,--jobs=HOST_CPUS*3 --local_test_jobs=HOST_CPUS*3 //test/packetimpact/tests:all_tests)
.PHONY: packetimpact-tests
# Specific containerd version tests.
-containerd-test-%: load-basic_alpine load-basic_python load-basic_busybox load-basic_resolv load-basic_httpd load-basic_ubuntu
- @$(call submake,install-runtime RUNTIME="root")
- @CONTAINERD_VERSION=$* $(MAKE) sudo TARGETS="tools/installers:containerd"
- @$(MAKE) sudo TARGETS="tools/installers:shim"
- @$(MAKE) sudo TARGETS="test/root:root_test" ARGS="--runtime=root -test.v"
+containerd-test-%: load-basic_alpine load-basic_python load-basic_busybox load-basic_resolv load-basic_httpd load-basic_ubuntu $(RUNTIME_BIN)
+ @$(call install_runtime,) # Clear flags.
+ @$(call sudo,tools/installers:containerd,$*)
+ @$(call sudo,tools/installers:shim)
+ @$(call sudo,test/root:root_test,--runtime=$(RUNTIME) -test.v)
# Note that we can't run containerd-test-1.1.8 tests here.
#
@@ -257,8 +299,8 @@ containerd-test-%: load-basic_alpine load-basic_python load-basic_busybox load-b
# actually drive the tests. The v1 API is tested exclusively through 1.2.13.
containerd-tests: ## Runs all supported containerd version tests.
containerd-tests: containerd-test-1.2.13
-containerd-tests: containerd-test-1.3.4
-containerd-tests: containerd-test-1.4.0-beta.0
+containerd-tests: containerd-test-1.3.9
+containerd-tests: containerd-test-1.4.3
##
## Benchmarks.
@@ -284,35 +326,35 @@ BENCHMARKS_UPLOAD := false
BENCHMARKS_OFFICIAL := false
BENCHMARKS_PLATFORMS := ptrace
BENCHMARKS_TARGETS := //test/benchmarks/base:startup_test
-BENCHMARKS_ARGS := -test.bench=.
+BENCHMARKS_ARGS := -test.bench=. -pprof-cpu -pprof-heap -pprof-heap -pprof-block
-init-benchmark-table: ## Initializes a BigQuery table with the benchmark schema
-## (see //tools/bigquery/bigquery.go). If the table alread exists, this is a noop.
- $(call submake, run TARGETS=//tools/parsers:parser ARGS="init --project=$(BENCHMARKS_PROJECT) \
- --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE)")
+init-benchmark-table: ## Initializes a BigQuery table with the benchmark schema.
+ @$(call run,//tools/parsers:parser,init --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE))
.PHONY: init-benchmark-table
-benchmark-platforms: load-benchmarks-images ## Runs benchmarks for runc and all given platforms in BENCHMARK_PLATFORMS.
- $(call submake, run-benchmark RUNTIME="runc")
- $(foreach PLATFORM,$(BENCHMARKS_PLATFORMS), \
- $(call submake,install-runtime RUNTIME="$(PLATFORM)" ARGS="--platform=$(PLATFORM) --vfs2") && \
- $(call submake,run-benchmark RUNTIME="$(PLATFORM)") && \
- $(call submake,install-runtime RUNTIME="$(PLATFORM)_vfs1" ARGS="--platform=$(PLATFORM)") && \
- $(call submake,run-benchmark RUNTIME="$(PLATFORM)_vfs1") && \
+# $(1) is the runtime name, $(2) are the arguments.
+run_benchmark = \
+ $(call header,BENCHMARK $(1) $(2)); \
+ if test "$(1)" != "runc"; then $(call install_runtime,--profile $(2)); fi \
+ @T=$$(mktemp --tmpdir logs.$(RUNTIME).XXXXXX); \
+ $(call sudo,$(BENCHMARKS_TARGETS) --runtime=$(RUNTIME) $(BENCHMARKS_ARGS) | tee $$T); \
+ rc=$$?; \
+ if test $$rc -eq 0 && test "$(BENCHMARKS_UPLOAD)" == "true"; then \
+ $(call run,tools/parsers:parser parse --debug --file=$$T --runtime=$(RUNTIME) --suite_name=$(BENCHMARKS_SUITE) --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE) --official=$(BENCHMARKS_OFFICIAL)); \
+ fi; \
+ rm -rf $$T; \
+ exit $$rc
+
+benchmark-platforms: load-benchmarks ## Runs benchmarks for runc and all given platforms in BENCHMARK_PLATFORMS.
+ @$(foreach PLATFORM,$(BENCHMARKS_PLATFORMS), \
+ $(call run_benchmark,$(RUNTIME)+vfs2,$(BENCHMARK_ARGS) --platform=$(PLATFORM) --vfs2) && \
+ $(call run_benchmark,$(RUNTIME),$(BENCHMARK_ARGS) --platform=$(PLATFORM)) && \
) \
- true
+ $(call run-benchmark,runc)
.PHONY: benchmark-platforms
-run-benchmark: ## Runs single benchmark and optionally sends data to BigQuery.
- @set -xeuo pipefail; T=$$(mktemp --tmpdir logs.$(RUNTIME).XXXXXX); \
- $(call submake,sudo TARGETS="$(BENCHMARKS_TARGETS)" ARGS="--runtime=$(RUNTIME) $(BENCHMARKS_ARGS)" | tee $$T); \
- if [[ "$(BENCHMARKS_UPLOAD)" == "true" ]]; then \
- $(call submake,run TARGETS=tools/parsers:parser ARGS="parse --debug --file=$$T \
- --runtime=$(RUNTIME) --suite_name=$(BENCHMARKS_SUITE) \
- --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) \
- --table=$(BENCHMARKS_TABLE) --official=$(BENCHMARKS_OFFICIAL)"); \
- fi; \
- rm -rf $$T
+run-benchmark: load-benchmarks ## Runs single benchmark and optionally sends data to BigQuery.
+ @$(call run_benchmark,$(RUNTIME),$(BENCHMARK_ARGS))
.PHONY: run-benchmark
##
@@ -332,7 +374,7 @@ WEBSITE_PROJECT := gvisordev
WEBSITE_REGION := us-central1
website-build: load-jekyll ## Build the site image locally.
- @$(call submake,run TARGETS="//website:website" ARGS="$(WEBSITE_IMAGE)")
+ @$(call run,//website:website,$(WEBSITE_IMAGE))
.PHONY: website-build
website-server: website-build ## Run a local server for development.
@@ -358,17 +400,17 @@ website-deploy: website-push ## Deploy a new version of the website.
## RELEASE_NAME - The name of the release in the proper format (needed for tag).
## RELEASE_NOTES - The file containing release notes (needed for tag).
##
-RELEASE_ROOT := $(CURDIR)/repo
-RELEASE_KEY := repo.key
-RELEASE_NIGHTLY := false
-RELEASE_COMMIT :=
-RELEASE_NAME :=
-RELEASE_NOTES :=
-
+RELEASE_ROOT := $(CURDIR)/repo
+RELEASE_KEY := repo.key
+RELEASE_NIGHTLY := false
+RELEASE_COMMIT :=
+RELEASE_NAME :=
+RELEASE_NOTES :=
GPG_TEST_OPTIONS := $(shell if gpg --pinentry-mode loopback --version >/dev/null 2>&1; then echo --pinentry-mode loopback; fi)
+
$(RELEASE_KEY):
@echo "WARNING: Generating a key for testing ($@); don't use this."
- T=$$(mktemp --tmpdir keyring.XXXXXX); \
+ @T=$$(mktemp --tmpdir keyring.XXXXXX); \
C=$$(mktemp --tmpdir config.XXXXXX); \
echo Key-Type: DSA >> $$C && \
echo Key-Length: 1024 >> $$C && \
@@ -382,11 +424,11 @@ $(RELEASE_KEY):
release: $(RELEASE_KEY) ## Builds a release.
@mkdir -p $(RELEASE_ROOT)
- @T=$$(mktemp -d --tmpdir release.XXXXXX); \
- $(call submake,copy TARGETS="//runsc:runsc" DESTINATION=$$T) && \
- $(call submake,copy TARGETS="//shim/v1:gvisor-containerd-shim" DESTINATION=$$T) && \
- $(call submake,copy TARGETS="//shim/v2:containerd-shim-runsc-v1" DESTINATION=$$T) && \
- $(call submake,copy TARGETS="//debian:debian" DESTINATION=$$T) && \
+ @export T=$$(mktemp -d --tmpdir release.XXXXXX); \
+ $(call copy,//runsc:runsc,$$T) && \
+ $(call copy,//shim/v1:gvisor-containerd-shim,$$T) && \
+ $(call copy,//shim/v2:containerd-shim-runsc-v1,$$T) && \
+ $(call copy,//debian:debian,$$T) && \
NIGHTLY=$(RELEASE_NIGHTLY) tools/make_release.sh $(RELEASE_KEY) $(RELEASE_ROOT) $$T/*; \
rc=$$?; rm -rf $$T; exit $$rc
.PHONY: release
@@ -394,75 +436,3 @@ release: $(RELEASE_KEY) ## Builds a release.
tag: ## Creates and pushes a release tag.
@tools/tag_release.sh "$(RELEASE_COMMIT)" "$(RELEASE_NAME)" "$(RELEASE_NOTES)"
.PHONY: tag
-
-##
-## Development helpers and tooling.
-##
-## These targets faciliate local development by automatically
-## installing and configuring a runtime. Several variables may
-## be used here to tweak the installation:
-## RUNTIME - The name of the installed runtime (default: branch).
-## RUNTIME_DIR - Where the runtime will be installed (default: temporary directory with the $RUNTIME).
-## RUNTIME_BIN - The runtime binary (default: $RUNTIME_DIR/runsc).
-## RUNTIME_LOG_DIR - The logs directory (default: $RUNTIME_DIR/logs).
-## RUNTIME_LOGS - The log pattern (default: $RUNTIME_LOG_DIR/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%).
-##
-ifeq (,$(BRANCH_NAME))
-RUNTIME := runsc
-RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(RUNTIME)
-else
-RUNTIME := $(BRANCH_NAME)
-RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(RUNTIME)
-endif
-RUNTIME_BIN := $(RUNTIME_DIR)/runsc
-RUNTIME_LOG_DIR := $(RUNTIME_DIR)/logs
-RUNTIME_LOGS := $(RUNTIME_LOG_DIR)/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%
-
-dev: ## Installs a set of local runtimes. Requires sudo.
- @$(call submake,refresh ARGS="--net-raw")
- @$(call submake,configure RUNTIME_NAME="$(RUNTIME)" ARGS="--net-raw")
- @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-d" ARGS="--net-raw --debug --strace --log-packets")
- @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-p" ARGS="--net-raw --profile")
- @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-vfs2-d" ARGS="--net-raw --debug --strace --log-packets --vfs2")
- @sudo systemctl restart docker
-.PHONY: dev
-
-refresh: ## Refreshes the runtime binary (for development only). Must have called 'dev' or 'install-runtime' first.
- @mkdir -p "$(RUNTIME_DIR)"
- @$(call submake,copy TARGETS=runsc DESTINATION="$(RUNTIME_BIN)")
-.PHONY: refresh
-
-install-runtime: ## Installs the runtime for testing. Requires sudo.
- @$(call submake,refresh ARGS="--net-raw --TESTONLY-test-name-env=RUNSC_TEST_NAME $(ARGS)")
- @$(call submake,configure RUNTIME_NAME=runsc)
- @$(call submake,configure RUNTIME_NAME="$(RUNTIME)")
- @sudo systemctl restart docker
- @if [[ -f /etc/docker/daemon.json ]]; then \
- sudo chmod 0755 /etc/docker && \
- sudo chmod 0644 /etc/docker/daemon.json; \
- fi
-.PHONY: install-runtime
-
-install-debug-runtime: ## Installs the runtime for debugging. Requires sudo.
- @$(call submake,install-runtime ARGS="--debug --strace --log-packets $(ARGS)")
-.PHONY: install-debug-runtime
-
-configure: ## Configures a single runtime. Requires sudo. Typically called from dev or install-runtime.
- @sudo sudo "$(RUNTIME_BIN)" install --experimental=true --runtime="$(RUNTIME_NAME)" -- --debug-log "$(RUNTIME_LOGS)" $(ARGS)
- @echo -e "$(INFO) Installed runtime \"$(RUNTIME)\" @ $(RUNTIME_BIN)"
- @echo -e "$(INFO) Logs are in: $(RUNTIME_LOG_DIR)"
- @sudo rm -rf "$(RUNTIME_LOG_DIR)" && mkdir -p "$(RUNTIME_LOG_DIR)"
-.PHONY: configure
-
-test-runtime: ## A convenient wrapper around test that provides the runtime argument. Target must still be provided.
- @$(call submake,test OPTIONS="$(OPTIONS) --test_arg=--runtime=$(RUNTIME)")
-.PHONY: test-runtime
-
-nogo: ## Surfaces all nogo findings.
- @$(call submake,build OPTIONS="--build_tag_filters nogo" TARGETS="//...")
- @$(call submake,run TARGETS="//tools/github" ARGS="$(foreach dir,$(BUILD_ROOTS),-path=$(CURDIR)/$(dir)) -dry-run nogo")
-.PHONY: nogo
-
-gazelle: ## Runs gazelle to update WORKSPACE.
- @$(call submake,run TARGETS="//:gazelle" ARGS="update-repos -from_file=go.mod -prune")
-.PHONY: gazelle
diff --git a/WORKSPACE b/WORKSPACE
index 2f3408709..933c1ff19 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -38,6 +38,12 @@ http_archive(
http_archive(
name = "bazel_gazelle",
+ patch_args = ["-p1"],
+ patches = [
+ # False positive output complaining about Go logrus versions spam the
+ # logs. Strip this message in this case. Does not affect control flow.
+ "//tools:bazel_gazelle.patch",
+ ],
sha256 = "b85f48fa105c4403326e9525ad2b2cc437babaa6e15a3fc0b1dbab0ab064bc7c",
urls = [
"https://mirror.bazel.build/github.com/bazelbuild/bazel-gazelle/releases/download/v0.22.2/bazel-gazelle-v0.22.2.tar.gz",
@@ -187,8 +193,8 @@ go_repository(
name = "com_github_containerd_containerd",
build_file_proto_mode = "disable",
importpath = "github.com/containerd/containerd",
- sum = "h1:3o0smo5SKY7H6AJCmJhsnCjR2/V2T8VmiHt7seN2/kI=",
- version = "v1.3.4",
+ sum = "h1:K2U/F4jGAMBqeUssfgJRbFuomLcS2Fxo1vR3UM/Mbh8=",
+ version = "v1.3.9",
)
go_repository(
@@ -518,8 +524,8 @@ go_repository(
name = "com_github_containerd_cgroups",
build_file_proto_mode = "disable",
importpath = "github.com/containerd/cgroups",
- sum = "h1:5yg0k8gqOssNLsjjCtXIADoPbAtUtQZJfC8hQ4r2oFY=",
- version = "v0.0.0-20181219155423-39b18af02c41",
+ sum = "h1:7grrpcfCtbZLsjtB0DgMuzs1umsJmpzaHMZ6cO6iAWw=",
+ version = "v0.0.0-20201119153540-4cbc285b3327",
)
go_repository(
diff --git a/g3doc/proposals/runtime_dedicate_os_thread.md b/g3doc/proposals/runtime_dedicate_os_thread.md
new file mode 100644
index 000000000..dc70055b0
--- /dev/null
+++ b/g3doc/proposals/runtime_dedicate_os_thread.md
@@ -0,0 +1,188 @@
+# `runtime.DedicateOSThread`
+
+Status as of 2020-09-18: Deprioritized; initial studies in #2180 suggest that
+this may be difficult to support in the Go runtime due to issues with GC.
+
+## Summary
+
+Allow goroutines to bind to kernel threads in a way that allows their scheduling
+to be kernel-managed rather than runtime-managed.
+
+## Objectives
+
+* Reduce Go runtime overhead in the gVisor sentry (#2184).
+
+* Minimize intrusiveness of changes to the Go runtime.
+
+## Background
+
+In Go, execution contexts are referred to as goroutines, which the runtime calls
+Gs. The Go runtime maintains a variably-sized pool of threads (called Ms by the
+runtime) on which Gs are executed, as well as a pool of "virtual processors"
+(called Ps by the runtime) of size equal to `runtime.GOMAXPROCS()`. Usually,
+each M requires a P in order to execute Gs, limiting the number of concurrently
+executing goroutines to `runtime.GOMAXPROCS()`.
+
+The `runtime.LockOSThread` function temporarily locks the invoking goroutine to
+its current thread. It is primarily useful for interacting with OS or non-Go
+library facilities that are per-thread. It does not reduce interactions with the
+Go runtime scheduler: locked Ms relinquish their P when they become blocked, and
+only continue execution after another M "chooses" their locked G to run and
+donates their P to the locked M instead.
+
+## Problems
+
+### Context Switch Overhead
+
+Most goroutines in the gVisor sentry are task goroutines, which back application
+threads. Task goroutines spend large amounts of time blocked on syscalls that
+execute untrusted application code. When invoking said syscall (which varies by
+gVisor platform), the task goroutine may interact with the Go runtime in one of
+three ways:
+
+* It can invoke the syscall without informing the runtime. In this case, the
+ task goroutine will continue to hold its P during the syscall, limiting the
+ number of application threads that can run concurrently to
+ `runtime.GOMAXPROCS()`. This is problematic because the Go runtime scheduler
+ is known to scale poorly with `GOMAXPROCS`; see #1942 and
+ https://github.com/golang/go/issues/28808. It also means that preemption of
+ application threads must be driven by sentry or runtime code, which is
+ strictly slower than kernel-driven preemption (since the sentry must invoke
+ another syscall to preempt the application thread).
+
+* It can call `runtime.entersyscallblock` before invoking the syscall, and
+ `runtime.exitsyscall` after the syscall returns. In this case, the task
+ goroutine will release its P while the syscall is executing. This allows the
+ number of threads concurrently executing application code to exceed
+ `GOMAXPROCS`. However, this incurs additional latency on syscall entry (to
+ hand off the released P to another M, often requiring a `futex(FUTEX_WAKE)`
+ syscall) and on syscall exit (to acquire a new P). It also drastically
+ increases the number of threads that concurrently interact with the runtime
+ scheduler, which is also problematic for performance (both in terms of CPU
+ utilization and in terms of context switch latency); see #205.
+
+- It can call `runtime.entersyscall` before invoking the syscall, and
+ `runtime.exitsyscall` after the syscall returns. In this case, the task
+ goroutine "lazily releases" its P, allowing the runtime's "sysmon" thread to
+ steal it on behalf of another M after a 20us delay. This mitigates the
+ context switch latency problem when there are few task goroutines and the
+ interval between switches to application code (i.e. the interval between
+ application syscalls, page faults, or signal delivery) is short. (Cynically,
+ this means that it's most effective in microbenchmarks). However, the delay
+ before a P is stolen can also be problematic for performance when there are
+ both many task goroutines switching to application code (lazily releasing
+ their Ps) *and* many task goroutines switching to sentry code (contending
+ for Ps), which is likely in larger heterogeneous workloads.
+
+### Blocking Overhead
+
+Task goroutines block on behalf of application syscalls like `futex` and
+`epoll_wait` by receiving from a Go channel. (Future work may convert task
+goroutine blocking to use the `syncevent` package to avoid overhead associated
+with channels and `select`, but this does not change how blocking interacts with
+the Go runtime scheduler.)
+
+If `runtime.LockOSThread()` is not in effect when a task goroutine blocks, then
+when the task goroutine is unblocked (by e.g. an application `FUTEX_WAKE`,
+signal delivery, or a timeout) by sending to the blocked channel,
+`runtime.ready` migrates the unblocked G to the unblocking P. In most cases,
+this implies that every application thread block/unblock cycle results in a
+migration of the thread between Ps, and therefore Ms, and therefore cores,
+resulting in reduced application performance due to loss of CPU caches.
+Furthermore, in most cases, the unblocking P cannot immediately switch to the
+unblocked G (instead resuming execution of its current application thread after
+completing the application's `futex(FUTEX_WAKE)`, `tgkill`, etc. syscall), often
+requiring that another P steal the unblocked G before it can resume execution.
+
+If `runtime.LockOSThread()` is in effect when a task goroutine blocks, then the
+G will remain locked to its M, avoiding the core migration described above;
+however, wakeup latency is significantly increased since, as described in
+"Background", the G still needs to be selected by the scheduler before it can
+run, and the M that selects the G then needs to transfer its P to the locked M,
+incurring an additional `FUTEX_WAKE` syscall and round of kernel scheduling.
+
+## Proposal
+
+We propose to add a function, tentatively called `DedicateOSThread`, to the Go
+`runtime` package, documented as follows:
+
+```go
+// DedicateOSThread wires the calling goroutine to its current operating system
+// thread, and exempts it from counting against GOMAXPROCS. The calling
+// goroutine will always execute in that thread, and no other goroutine will
+// execute in it, until the calling goroutine has made as many calls to
+// UndedicateOSThread as to DedicateOSThread. If the calling goroutine exits
+// without unlocking the thread, the thread will be terminated.
+//
+// DedicateOSThread should only be used by long-lived goroutines that usually
+// block due to blocking system calls, rather than interaction with other
+// goroutines.
+func DedicateOSThread()
+```
+
+Mechanically, `DedicateOSThread` implies `LockOSThread` (i.e. it locks the
+invoking G to a M), but additionally locks the invoking M to a P. Ps locked by
+`DedicateOSThread` are not counted against `GOMAXPROCS`; that is, the actual
+number of Ps in the system (`len(runtime.allp)`) is `GOMAXPROCS` plus the number
+of bound Ps (plus some slack to avoid frequent changes to `runtime.allp`).
+Corollaries:
+
+* If `runtime.ready` observes that a readied G is locked to a M locked to a P,
+ it immediately wakes the locked M without migrating the G to the readying P
+ or waiting for a future call to `runtime.schedule` to select the readied G
+ in `runtime.findrunnable`.
+
+* `runtime.stoplockedm` and `runtime.reentersyscall` skip the release of
+ locked Ps; the latter also skips sysmon wakeup. `runtime.stoplockedm` and
+ `runtime.exitsyscall` skip re-acquisition of Ps if one is locked.
+
+* sysmon does not attempt to preempt Gs that are locked to Ps, avoiding
+ fruitless overhead from `tgkill` syscalls and signal delivery.
+
+* `runtime.findrunnable`'s work stealing skips locked Ps (suggesting that
+ unlocked Ps be tracked in a separate array). `runtime.findrunnable` on
+ locked Ps skip the global run queue, work stealing, and possibly netpoll.
+
+* New goroutines created by goroutines with locked Ps are enqueued on the
+ global run queue rather than the invoking P's local run queue.
+
+While gVisor's use case does not strictly require that the association is
+reversible (with `runtime.UndedicateOSThread`), such a feature is required to
+allow reuse of locked Ms, which is likely to be critical for performance.
+
+## Alternatives Considered
+
+* Make the runtime scale well with `GOMAXPROCS`. While we are also
+ concurrently investigating this problem, this would not address the issues
+ of increased preemption cost or blocking overhead.
+
+* Make the runtime scale well with number of Ms. It is unclear if this is
+ actually feasible, and would not address blocking overhead.
+
+* Make P-locking part of `LockOSThread`'s behavior. This would likely
+ introduce performance regressions in existing uses of `LockOSThread` that do
+ not fit this usage pattern. In particular, since `DedicateOSThread`
+ transitions the invoker's P from "counted against `GOMAXPROCS`" to "not
+ counted against `GOMAXPROCS`", it may need to wake another M to run a new P
+ (that is counted against `GOMAXPROCS`), and the converse applies to
+ `UndedicateOSThread`.
+
+* Rewrite the gVisor sentry in a language that does not force userspace
+ scheduling. This is a last resort due to the amount of code involved.
+
+## Related Issues
+
+The proposed functionality is directly analogous to `spawn_blocking` in Rust
+async runtimes
+[`async_std`](https://docs.rs/async-std/1.8.0/async_std/task/fn.spawn_blocking.html)
+and [`tokio`](https://docs.rs/tokio/0.3.5/tokio/task/fn.spawn_blocking.html).
+
+Outside of gVisor:
+
+* https://github.com/golang/go/issues/21827#issuecomment-595152452 describes a
+ use case for this feature in go-delve, where the goroutine that would use
+ this feature spends much of its time blocked in `ptrace` syscalls.
+
+* This feature may improve performance in the use case described in
+ https://github.com/golang/go/issues/18237, given the prominence of
+ syscall.Syscall in the profile given in that bug report.
diff --git a/go.mod b/go.mod
index 144543169..823c3596d 100644
--- a/go.mod
+++ b/go.mod
@@ -10,8 +10,8 @@ require (
github.com/Microsoft/hcsshim v0.8.6 // indirect
github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422 // indirect
github.com/cilium/ebpf v0.0.0-20200110133405-4032b1d8aae3 // indirect
- github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41 // indirect
- github.com/containerd/containerd v1.3.4 // indirect
+ github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327
+ github.com/containerd/containerd v1.3.9 // indirect
github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a // indirect
github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00 // indirect
github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328 // indirect
diff --git a/go.sum b/go.sum
index 060d5596a..70514ea14 100644
--- a/go.sum
+++ b/go.sum
@@ -51,12 +51,12 @@ github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41 h1:5yg0k8gqOssN
github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41/go.mod h1:X9rLEHIqSf/wfK8NsPqxJmeZgW4pcfzdXITDrUSJ6uI=
github.com/containerd/cgroups v0.0.0-20200531161412-0dbf7f05ba59 h1:qWj4qVYZ95vLWwqyNJCQg7rDsG5wPdze0UaPolH7DUk=
github.com/containerd/cgroups v0.0.0-20200531161412-0dbf7f05ba59/go.mod h1:pA0z1pT8KYB3TCXK/ocprsh7MAkoW8bZVzPdih9snmM=
+github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327 h1:7grrpcfCtbZLsjtB0DgMuzs1umsJmpzaHMZ6cO6iAWw=
+github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327/go.mod h1:ZJeTFisyysqgcCdecO57Dj79RfL0LNeGiFUqLYQRYLE=
github.com/containerd/console v0.0.0-20180822173158-c12b1e7919c1/go.mod h1:Tj/on1eG8kiEhd0+fhSDzsPAFESxzBBvdyEgyryXffw=
github.com/containerd/console v0.0.0-20191206165004-02ecf6a7291e h1:GdiIYd8ZDOrT++e1NjhSD4rGt9zaJukHm4rt5F4mRQc=
github.com/containerd/console v0.0.0-20191206165004-02ecf6a7291e/go.mod h1:8Pf4gM6VEbTNRIT26AyyU7hxdQU3MvAvxVI0sc00XBE=
-github.com/containerd/containerd v1.3.2/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA=
-github.com/containerd/containerd v1.3.4 h1:3o0smo5SKY7H6AJCmJhsnCjR2/V2T8VmiHt7seN2/kI=
-github.com/containerd/containerd v1.3.4/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA=
+github.com/containerd/containerd v1.3.9/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA=
github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y=
github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a h1:jEIoR0aA5GogXZ8pP3DUzE+zrhaF6/1rYZy+7KkYEWM=
github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a/go.mod h1:W0qIOTD7mp2He++YVq+kgfXezRYqzP1uDuMVH1bITDY=
diff --git a/images/BUILD b/images/BUILD
index a50f388e9..34b950644 100644
--- a/images/BUILD
+++ b/images/BUILD
@@ -1,11 +1 @@
package(licenses = ["notice"])
-
-# The images filegroup is definitely not a hermetic target, and requires Make
-# to do anything meaningful with. However, this will be slurped up and used by
-# the tools/installer/images.sh installer, which will ensure that all required
-# images are available locally when running vm_tests.
-filegroup(
- name = "images",
- srcs = glob(["**"]),
- visibility = ["//tools/installers:__pkg__"],
-)
diff --git a/images/Makefile b/images/Makefile
deleted file mode 100644
index 12927c509..000000000
--- a/images/Makefile
+++ /dev/null
@@ -1,107 +0,0 @@
-#!/usr/bin/make -f
-
-# Copyright 2018 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# ARCH is the architecture used for the build. This may be overriden at the
-# command line in order to perform a cross-build (in a limited capacity).
-ARCH := $(shell uname -m)
-
-# Note that the image prefixes used here must match the image mangling in
-# runsc/testutil.MangleImage. Names are mangled in this way to ensure that all
-# tests are using locally-defined images (that are consistent and idempotent).
-REMOTE_IMAGE_PREFIX ?= gcr.io/gvisor-presubmit
-LOCAL_IMAGE_PREFIX ?= gvisor.dev/images
-ALL_IMAGES := $(subst /,_,$(subst ./,,$(shell find . -name Dockerfile -o -name Dockerfile.$(ARCH) | xargs -n 1 dirname | uniq)))
-ifneq ($(ARCH),$(shell uname -m))
-DOCKER_PLATFORM_ARGS := --platform=$(ARCH)
-else
-DOCKER_PLATFORM_ARGS :=
-endif
-
-list-all-images:
- @for image in $(ALL_IMAGES); do echo $${image}; done
-.PHONY: list-build-images
-
-# Handy wrapper to allow load-all-images, push-all-images, etc.
-%-all-images:
- @$(MAKE) $(patsubst %,$*-%,$(ALL_IMAGES))
-load-all-images:
- @$(MAKE) $(patsubst %,load-%,$(ALL_IMAGES))
-
-# Handy wrapper to load specified "groups", e.g. load-basic-images, etc.
-load-%-images:
- @$(MAKE) $(patsubst %,load-%,$(subst /,_,$(subst ./,,$(shell find ./$* -name Dockerfile -exec dirname {} \;))))
-
-# tag is a function that returns the tag name, given an image.
-#
-# The tag constructed is used to memoize the image generated (see README.md).
-# This scheme is used to enable aggressive caching in a central repository, but
-# ensuring that images will always be sourced using the local files if there
-# are changes.
-path = $(subst _,/,$(1))
-dockerfile = $$(if [ -f "$(call path,$(1))/Dockerfile.$(ARCH)" ]; then echo Dockerfile.$(ARCH); else echo Dockerfile; fi)
-tag = $(shell find $(call path,$(1)) -type f -print | sort | xargs -n 1 sha256sum | sha256sum - | cut -c 1-16)
-remote_image = $(REMOTE_IMAGE_PREFIX)/$(subst _,/,$(1))_$(ARCH):$(call tag,$(1))
-local_image = $(LOCAL_IMAGE_PREFIX)/$(subst _,/,$(1))
-
-# rebuild builds the image locally. Only the "remote" tag will be applied. Note
-# we need to explicitly repull the base layer in order to ensure that the
-# architecture is correct. Note that we use the term "rebuild" here to avoid
-# conflicting with the bazel "build" terminology, which is used elsewhere.
-rebuild-%: FROM=$(shell grep FROM "$(call path,$*)/$(call dockerfile,$*)" | cut -d' ' -f2)
-rebuild-%: register-cross
- @if ! [ -f "$(call path,$*)/$(call dockerfile,$*)" ]; then \
- (echo "ERROR: Dockerfile for $* not found (is it available for $(ARCH)?)." >&2 && exit 1); \
- fi
- $(foreach IMAGE,$(FROM),docker pull $(DOCKER_PLATFORM_ARGS) $(IMAGE) &&) \
- T=$$(mktemp -d) && cp -a $(call path,$*)/* $$T && \
- docker build $(DOCKER_PLATFORM_ARGS) \
- -f "$$T/$(call dockerfile,$*)" \
- -t "$(call remote_image,$*)" \
- $$T && \
- rm -rf $$T
-
-# pull will check the "remote" image and pull if necessary. If the remote image
-# must be pulled, then it will tag with the latest local target. Note that pull
-# may fail if the remote image is not available.
-pull-%:
- docker pull $(DOCKER_PLATFORM_ARGS) $(call remote_image,$*)
-
-# load will either pull the "remote" or build it locally. This is the preferred
-# entrypoint, as it should never fail. The local tag should always be set after
-# this returns (either by the pull or the build).
-load-%:
- $(MAKE) pull-$* || $(MAKE) rebuild-$*
- docker tag $(call remote_image,$*) $(call local_image,$*)
-
-# push pushes the remote image, after either pulling (to validate that the tag
-# already exists) or building manually.
-push-%: load-%
- docker push $(call remote_image,$*)
-
-# register-cross registers the necessary qemu binaries for cross-compilation.
-# This may be used by any target that may execute containers that are not the
-# native format.
-register-cross:
-ifneq ($(ARCH),$(shell uname -m))
-ifeq (,$(wildcard /proc/sys/fs/binfmt_misc/qemu-*))
- docker run --rm --privileged multiarch/qemu-user-static --reset --persistent yes
-else
- @true # Already registered.
-endif
-else
- @true # No cross required.
-endif
-.PHONY: register-cross
diff --git a/images/agent/Dockerfile b/images/agent/Dockerfile
new file mode 100644
index 000000000..1d8979390
--- /dev/null
+++ b/images/agent/Dockerfile
@@ -0,0 +1,12 @@
+FROM golang:1.15 as build-agent
+RUN git clone --depth=1 --branch=v3.25.0 https://github.com/buildkite/agent
+RUN cd agent && go build -i -o /buildkite-agent .
+
+FROM golang:1.15 as build-agent-metrics
+RUN git clone --depth=1 --branch=v5.2.0 https://github.com/buildkite/buildkite-agent-metrics
+RUN cd buildkite-agent-metrics && go build -i -o /buildkite-agent-metrics .
+
+FROM gcr.io/distroless/base-debian10
+COPY --from=build-agent /buildkite-agent /
+COPY --from=build-agent-metrics /buildkite-agent-metrics /
+CMD ["/buildkite-agent"]
diff --git a/images/agent/README.md b/images/agent/README.md
new file mode 100644
index 000000000..acb57bd2f
--- /dev/null
+++ b/images/agent/README.md
@@ -0,0 +1,7 @@
+# Build Agent
+
+This is the image used by the build agent. It is built and bundled via a
+separate packaging mechanism in order to provide local caching and to ensure
+that there is better build provenance. Note that continuous integration system
+will generally deploy new agents from the primary branch, and will only deploy
+as instances are recycled. Updates to this image should be made carefully.
diff --git a/images/basic/ping4test/Dockerfile b/images/basic/ping4test/Dockerfile
new file mode 100644
index 000000000..1536be376
--- /dev/null
+++ b/images/basic/ping4test/Dockerfile
@@ -0,0 +1,7 @@
+FROM ubuntu:bionic
+
+WORKDIR /root
+COPY ping4.sh .
+RUN chmod +x ping4.sh
+
+RUN apt-get update && apt-get install -y iputils-ping
diff --git a/tools/vm/zone.sh b/images/basic/ping4test/ping4.sh
index 79569fb19..2a343712a 100755..100644
--- a/tools/vm/zone.sh
+++ b/images/basic/ping4test/ping4.sh
@@ -14,4 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-exec gcloud config get-value compute/zone
+set -euo pipefail
+
+# The docker API doesn't provide for starting a container, running a command,
+# and getting the exit status of the command in one go. The most straightforward
+# way to do this is to verify the output of the command, so we output nothing on
+# success and an error message on failure.
+if ! out=$(ping -c 10 127.0.0.1); then
+ echo "$out"
+fi
diff --git a/images/basic/ping6test/Dockerfile b/images/basic/ping6test/Dockerfile
new file mode 100644
index 000000000..cb740bd60
--- /dev/null
+++ b/images/basic/ping6test/Dockerfile
@@ -0,0 +1,7 @@
+FROM ubuntu:bionic
+
+WORKDIR /root
+COPY ping6.sh .
+RUN chmod +x ping6.sh
+
+RUN apt-get update && apt-get install -y iputils-ping iproute2
diff --git a/images/basic/ping6test/ping6.sh b/images/basic/ping6test/ping6.sh
new file mode 100644
index 000000000..4268951d0
--- /dev/null
+++ b/images/basic/ping6test/ping6.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -euo pipefail
+
+# Enable ipv6 on loopback if it's not already enabled. Runsc doesn't enable ipv6
+# loopback unless an ipv6 address was assigned to the container, which docker
+# does not do by default.
+if ! [[ $(ip -6 addr show dev lo) ]]; then
+ ip addr add ::1 dev lo
+fi
+
+# The docker API doesn't provide for starting a container, running a command,
+# and getting the exit status of the command in one go. The most straightforward
+# way to do this is to verify the output of the command, so we output nothing on
+# success and an error message on failure.
+if ! out=$(/bin/ping6 -c 10 ::1); then
+ echo "$out"
+fi
diff --git a/images/benchmarks/absl/Dockerfile b/images/benchmarks/absl/Dockerfile.x86_64
index b0dd97695..810c9ef5e 100644
--- a/images/benchmarks/absl/Dockerfile
+++ b/images/benchmarks/absl/Dockerfile.x86_64
@@ -12,6 +12,7 @@ RUN set -x \
unzip \
python3 \
&& rm -rf /var/lib/apt/lists/*
+
RUN wget https://github.com/bazelbuild/bazel/releases/download/0.27.0/bazel-0.27.0-installer-linux-x86_64.sh
RUN chmod +x bazel-0.27.0-installer-linux-x86_64.sh
RUN ./bazel-0.27.0-installer-linux-x86_64.sh
diff --git a/images/benchmarks/hey/Dockerfile b/images/benchmarks/hey/Dockerfile
index f586978b6..4b6a0f849 100644
--- a/images/benchmarks/hey/Dockerfile
+++ b/images/benchmarks/hey/Dockerfile
@@ -1,12 +1,13 @@
-FROM ubuntu:18.04
+FROM golang:1.15 as build
+RUN go get github.com/rakyll/hey
+WORKDIR /go/src/github.com/rakyll/hey
+RUN go mod download
+RUN CGO_ENABLED=0 go build -o /hey hey.go
+FROM ubuntu:18.04
RUN set -x \
&& apt-get update \
&& apt-get install -y \
wget \
&& rm -rf /var/lib/apt/lists/*
-
-RUN wget https://storage.googleapis.com/hey-release/hey_linux_amd64 \
- && chmod 777 hey_linux_amd64 \
- && cp hey_linux_amd64 /bin/hey \
- && rm hey_linux_amd64
+COPY --from=build /hey /bin/hey
diff --git a/images/benchmarks/runsc/Dockerfile b/images/benchmarks/runsc/Dockerfile.x86_64
index 6c3aafa57..28ae64816 100644
--- a/images/benchmarks/runsc/Dockerfile
+++ b/images/benchmarks/runsc/Dockerfile.x86_64
@@ -14,6 +14,7 @@ RUN set -x \
python3 \
python3-pip \
&& rm -rf /var/lib/apt/lists/*
+
RUN wget https://github.com/bazelbuild/bazel/releases/download/3.4.1/bazel-3.4.1-installer-linux-x86_64.sh
RUN chmod +x bazel-3.4.1-installer-linux-x86_64.sh
RUN ./bazel-3.4.1-installer-linux-x86_64.sh
diff --git a/images/default/Dockerfile b/images/default/Dockerfile
index d058b83cb..224469267 100644
--- a/images/default/Dockerfile
+++ b/images/default/Dockerfile
@@ -1,16 +1,20 @@
FROM fedora:31
+
# Install bazel.
RUN dnf install -y dnf-plugins-core && dnf copr enable -y vbatts/bazel
RUN dnf install -y git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static libstdc++-static patch diffutils
RUN pip install --no-cache-dir pycparser
RUN dnf install -y bazel3
-# Install gcloud.
+
+# Install gcloud. Note that while this is "x86_64", it doesn't actually matter.
RUN curl https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-sdk-289.0.0-linux-x86_64.tar.gz | \
- tar zxvf - google-cloud-sdk && \
+ tar zxf - google-cloud-sdk && \
google-cloud-sdk/install.sh && \
ln -s /google-cloud-sdk/bin/gcloud /usr/bin/gcloud
+
# Install Docker client for the website build.
RUN dnf config-manager --add-repo https://download.docker.com/linux/fedora/docker-ce.repo
RUN dnf install -y docker-ce-cli
+
WORKDIR /workspace
ENTRYPOINT ["/usr/bin/bazel"]
diff --git a/images/runtimes/go1.12/Dockerfile b/images/runtimes/go1.12/Dockerfile.x86_64
index cb2944062..cb2944062 100644
--- a/images/runtimes/go1.12/Dockerfile
+++ b/images/runtimes/go1.12/Dockerfile.x86_64
diff --git a/nogo.yaml b/nogo.yaml
index 5c1737f59..0a5ca78dc 100644
--- a/nogo.yaml
+++ b/nogo.yaml
@@ -56,124 +56,8 @@ global:
- "should not use ALL_CAPS in Go names"
- "should not use underscores in Go names"
exclude:
- # A variety of staticcheck and stylecheck
- # rules apply here. These should be fixed
- # and removed from here, and the global
- # rules should be used sparingly.
- - pkg/abi/linux/fuse.go:22
- - pkg/abi/linux/fuse.go:25
- - pkg/abi/linux/socket.go:113
- - pkg/abi/linux/tty.go:73
- - pkg/cpuid/cpuid_x86.go:675
- - pkg/gohacks/gohacks_unsafe.go:33
- - pkg/log/json.go:30
- - pkg/log/log.go:359
- - pkg/metric/metric_test.go:20
- - pkg/p9/p9test/client_test.go:687
- - pkg/p9/transport_test.go:196
- - pkg/pool/pool.go:15
- - pkg/refs/refcounter.go:510
- - pkg/refs/refcounter_test.go:169
- - pkg/refs_vfs2/refs.go:16
- - pkg/safemem/block_unsafe.go:89
- - pkg/seccomp/seccomp.go:82
- - pkg/segment/test/set_functions.go:15
- - pkg/sentry/arch/signal.go:166
- - pkg/sentry/arch/signal.go:171
- - pkg/sentry/control/pprof.go:196
- - pkg/sentry/devices/memdev/full.go:58
- - pkg/sentry/devices/memdev/null.go:59
- - pkg/sentry/devices/memdev/random.go:68
- - pkg/sentry/devices/memdev/zero.go:86
- - pkg/sentry/fdimport/fdimport.go:15
- - pkg/sentry/fs/attr.go:257
- - pkg/sentry/fsbridge/fs.go:116
- - pkg/sentry/fsbridge/vfs.go:124
- - pkg/sentry/fsbridge/vfs.go:70
- - pkg/sentry/fs/copy_up.go:365
- - pkg/sentry/fs/copy_up_test.go:65
- - pkg/sentry/fs/dev/net_tun.go:161
- - pkg/sentry/fs/dev/net_tun.go:63
- - pkg/sentry/fs/dev/null.go:97
- - pkg/sentry/fs/dirent_cache.go:64
- - pkg/sentry/fs/fdpipe/pipe_opener_test.go:366
- - pkg/sentry/fs/file_overlay.go:327
- - pkg/sentry/fs/file_overlay.go:524
- - pkg/sentry/fs/filetest/filetest.go:55
- - pkg/sentry/fs/filetest/filetest.go:60
- - pkg/sentry/fs/fs.go:77
- - pkg/sentry/fs/fsutil/file.go:290
- - pkg/sentry/fs/fsutil/file.go:346
- - pkg/sentry/fs/fsutil/host_file_mapper.go:105
- - pkg/sentry/fs/fsutil/inode_cached.go:676
- - pkg/sentry/fs/fsutil/inode_cached.go:772
- - pkg/sentry/fs/gofer/attr.go:120
- - pkg/sentry/fs/gofer/fifo.go:33
- - pkg/sentry/fs/gofer/inode.go:410
- - pkg/sentry/fsimpl/ext/disklayout/superblock_64.go:97
- - pkg/sentry/fsimpl/ext/disklayout/superblock_old.go:92
- - pkg/sentry/fsimpl/ext/disklayout/block_group_32.go:44
- - pkg/sentry/fsimpl/ext/disklayout/inode_new.go:91
- - pkg/sentry/fsimpl/ext/disklayout/inode_old.go:93
- - pkg/sentry/fsimpl/ext/disklayout/superblock_32.go:66
- - pkg/sentry/fsimpl/ext/disklayout/block_group_64.go:53
- - pkg/sentry/fsimpl/fuse/request_response.go:71
- - pkg/sentry/fsimpl/signalfd/signalfd.go:15
- - pkg/sentry/memmap/memmap.go:103
- - pkg/sentry/memmap/memmap.go:163
- - pkg/sentry/mm/aio_context.go:208
- - pkg/sentry/mm/pma.go:683
- - pkg/sentry/usage/cpu.go:42
- - pkg/shim/runsc/runsc.go:16
- - pkg/shim/runsc/utils.go:16
- - pkg/shim/v1/proc/deleted_state.go:16
- - pkg/shim/v1/proc/exec.go:16
- - pkg/shim/v1/proc/exec_state.go:16
- - pkg/shim/v1/proc/init.go:16
- - pkg/shim/v1/proc/init_state.go:16
- - pkg/shim/v1/proc/io.go:16
- - pkg/shim/v1/proc/process.go:16
- - pkg/shim/v1/proc/types.go:16
- - pkg/shim/v1/proc/utils.go:16
- - pkg/shim/v1/shim/api.go:16
- - pkg/shim/v1/shim/platform.go:16
- - pkg/shim/v1/shim/service.go:16
- - pkg/shim/v1/utils/annotations.go:15
- - pkg/shim/v1/utils/utils.go:15
- - pkg/shim/v1/utils/volumes.go:15
- - pkg/shim/v2/api.go:16
- - pkg/shim/v2/epoll.go:18
- - pkg/shim/v2/options/options.go:15
- - pkg/shim/v2/options/options.go:24
- - pkg/shim/v2/options/options.go:26
- - pkg/shim/v2/runtimeoptions/runtimeoptions.go:16
- - pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go # Generated: exempt all.
- - pkg/shim/v2/runtimeoptions/runtimeoptions_test.go:22
- - pkg/shim/v2/service.go:15
- - pkg/shim/v2/service_linux.go:18
- - pkg/state/tests/integer_test.go:23
- - pkg/state/tests/integer_test.go:28
- - pkg/sync/rwmutex_test.go:105
- - pkg/syserr/host_linux.go:35
- - pkg/usermem/addr.go:34
- - pkg/usermem/usermem.go:171
- - pkg/usermem/usermem.go:170
- - runsc/boot/compat.go:56
- - test/cmd/test_app/fds.go:171
- - test/iptables/filter_output.go:251
- - test/packetimpact/testbench/connections.go:77
- - tools/bigquery/bigquery.go:106
- - tools/checkescape/test1/test1.go:108
- - tools/checkescape/test1/test1.go:122
- - tools/checkescape/test1/test1.go:137
- - tools/checkescape/test1/test1.go:151
- - tools/checkescape/test1/test1.go:170
- - tools/checkescape/test1/test1.go:39
- - tools/checkescape/test1/test1.go:45
- - tools/checkescape/test1/test1.go:50
- - tools/checkescape/test1/test1.go:64
- - tools/checkescape/test1/test1.go:80
- - tools/checkescape/test1/test1.go:94
+ # Generated: exempt all.
+ - pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go
analyzers:
asmdecl:
external: # Enabled.
@@ -215,6 +99,8 @@ analyzers:
printf:
external: # Enabled.
shift:
+ generated: # Disabled for generated code; these shifts are well-defined.
+ exclude: [".*"]
external: # Enabled.
stringintconv:
external:
@@ -251,3 +137,22 @@ analyzers:
external: # Enabled.
checkescape:
external: # Enabled.
+ SA4016:
+ internal:
+ exclude:
+ - pkg/gohacks/gohacks_unsafe.go # x ^ 0 always equals x.
+ SA2001:
+ internal:
+ exclude:
+ - pkg/sentry/fs/fs.go # Intentional.
+ - pkg/sentry/fs/gofer/inode.go # Intentional.
+ - pkg/refs/refcounter_test.go # Intentional.
+ ST1021:
+ internal:
+ suppress:
+ - "comment on exported type Translation" # Intentional.
+ - "comment on exported type PinnedRange" # Intentional.
+ SA5011:
+ internal:
+ exclude:
+ - pkg/sentry/fs/fdpipe/pipe_opener_test.go # False positive.
diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go
index cc3571fad..d1ca56370 100644
--- a/pkg/abi/linux/fcntl.go
+++ b/pkg/abi/linux/fcntl.go
@@ -25,6 +25,8 @@ const (
F_SETLKW = 7
F_SETOWN = 8
F_GETOWN = 9
+ F_SETSIG = 10
+ F_GETSIG = 11
F_SETOWN_EX = 15
F_GETOWN_EX = 16
F_DUPFD_CLOEXEC = 1024 + 6
diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go
index d91c97a64..1070b457c 100644
--- a/pkg/abi/linux/fuse.go
+++ b/pkg/abi/linux/fuse.go
@@ -19,16 +19,22 @@ import (
"gvisor.dev/gvisor/pkg/marshal/primitive"
)
+// FUSEOpcode is a FUSE operation code.
+//
// +marshal
type FUSEOpcode uint32
+// FUSEOpID is a FUSE operation ID.
+//
// +marshal
type FUSEOpID uint64
// FUSE_ROOT_ID is the id of root inode.
const FUSE_ROOT_ID = 1
-// Opcodes for FUSE operations. Analogous to the opcodes in include/linux/fuse.h.
+// Opcodes for FUSE operations.
+//
+// Analogous to the opcodes in include/linux/fuse.h.
const (
FUSE_LOOKUP FUSEOpcode = 1
FUSE_FORGET = 2 /* no reply */
diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go
index 1b2f76c0b..2424884c1 100644
--- a/pkg/abi/linux/sem.go
+++ b/pkg/abi/linux/sem.go
@@ -32,6 +32,23 @@ const (
SEM_STAT_ANY = 20
)
+// Information about system-wide sempahore limits and parameters.
+//
+// Source: include/uapi/linux/sem.h
+const (
+ SEMMNI = 32000
+ SEMMSL = 32000
+ SEMMNS = SEMMNI * SEMMSL
+ SEMOPM = 500
+ SEMVMX = 32767
+ SEMAEM = SEMVMX
+
+ SEMUME = SEMOPM
+ SEMMNU = SEMMNS
+ SEMMAP = SEMMNS
+ SEMUSZ = 20
+)
+
const SEM_UNDO = 0x1000
// Sembuf is equivalent to struct sembuf.
@@ -42,3 +59,21 @@ type Sembuf struct {
SemOp int16
SemFlg int16
}
+
+// SemInfo is equivalent to struct seminfo.
+//
+// Source: include/uapi/linux/sem.h
+//
+// +marshal
+type SemInfo struct {
+ SemMap uint32
+ SemMni uint32
+ SemMns uint32
+ SemMnu uint32
+ SemMsl uint32
+ SemOpm uint32
+ SemUme uint32
+ SemUsz uint32
+ SemVmx uint32
+ SemAem uint32
+}
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index d156d41e4..556892dc3 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -111,12 +111,12 @@ type SockType int
// Socket types, from linux/net.h.
const (
SOCK_STREAM SockType = 1
- SOCK_DGRAM = 2
- SOCK_RAW = 3
- SOCK_RDM = 4
- SOCK_SEQPACKET = 5
- SOCK_DCCP = 6
- SOCK_PACKET = 10
+ SOCK_DGRAM SockType = 2
+ SOCK_RAW SockType = 3
+ SOCK_RDM SockType = 4
+ SOCK_SEQPACKET SockType = 5
+ SOCK_DCCP SockType = 6
+ SOCK_PACKET SockType = 10
)
// SOCK_TYPE_MASK covers all of the above socket types. The remaining bits are
@@ -448,6 +448,8 @@ type ControlMessageCredentials struct {
// A ControlMessageIPPacketInfo is IP_PKTINFO socket control message.
//
// ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h.
+//
+// +stateify savable
type ControlMessageIPPacketInfo struct {
NIC int32
LocalAddr InetAddr
diff --git a/pkg/coverage/coverage.go b/pkg/coverage/coverage.go
index a4f4b2c5e..fdfe31417 100644
--- a/pkg/coverage/coverage.go
+++ b/pkg/coverage/coverage.go
@@ -27,6 +27,7 @@ import (
"io"
"sort"
"sync/atomic"
+ "testing"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
@@ -34,12 +35,6 @@ import (
"github.com/bazelbuild/rules_go/go/tools/coverdata"
)
-// KcovAvailable returns whether the kcov coverage interface is available. It is
-// available as long as coverage is enabled for some files.
-func KcovAvailable() bool {
- return len(coverdata.Cover.Blocks) > 0
-}
-
// coverageMu must be held while accessing coverdata.Cover. This prevents
// concurrent reads/writes from multiple threads collecting coverage data.
var coverageMu sync.RWMutex
@@ -47,6 +42,22 @@ var coverageMu sync.RWMutex
// once ensures that globalData is only initialized once.
var once sync.Once
+// blockBitLength is the number of bits used to represent coverage block index
+// in a synthetic PC (the rest are used to represent the file index). Even
+// though a PC has 64 bits, we only use the lower 32 bits because some users
+// (e.g., syzkaller) may truncate that address to a 32-bit value.
+//
+// As of this writing, there are ~1200 files that can be instrumented and at
+// most ~1200 blocks per file, so 16 bits is more than enough to represent every
+// file and every block.
+const blockBitLength = 16
+
+// KcovAvailable returns whether the kcov coverage interface is available. It is
+// available as long as coverage is enabled for some files.
+func KcovAvailable() bool {
+ return len(coverdata.Cover.Blocks) > 0
+}
+
var globalData struct {
// files is the set of covered files sorted by filename. It is calculated at
// startup.
@@ -104,14 +115,14 @@ var coveragePool = sync.Pool{
// coverage tools, we reset the global coverage data every time this function is
// run.
func ConsumeCoverageData(w io.Writer) int {
- once.Do(initCoverageData)
+ InitCoverageData()
coverageMu.Lock()
defer coverageMu.Unlock()
total := 0
var pcBuffer [8]byte
- for fileIndex, file := range globalData.files {
+ for fileNum, file := range globalData.files {
counters := coverdata.Cover.Counters[file]
for index := 0; index < len(counters); index++ {
if atomic.LoadUint32(&counters[index]) == 0 {
@@ -119,7 +130,7 @@ func ConsumeCoverageData(w io.Writer) int {
}
// Non-zero coverage data found; consume it and report as a PC.
atomic.StoreUint32(&counters[index], 0)
- pc := globalData.syntheticPCs[fileIndex][index]
+ pc := globalData.syntheticPCs[fileNum][index]
usermem.ByteOrder.PutUint64(pcBuffer[:], pc)
n, err := w.Write(pcBuffer[:])
if err != nil {
@@ -142,31 +153,84 @@ func ConsumeCoverageData(w io.Writer) int {
return total
}
-// initCoverageData initializes globalData. It should only be called once,
-// before any kcov data is written.
-func initCoverageData() {
- // First, order all files. Then calculate synthetic PCs for every block
- // (using the well-defined ordering for files as well).
- for file := range coverdata.Cover.Blocks {
- globalData.files = append(globalData.files, file)
+// InitCoverageData initializes globalData. It should be called before any kcov
+// data is written.
+func InitCoverageData() {
+ once.Do(func() {
+ // First, order all files. Then calculate synthetic PCs for every block
+ // (using the well-defined ordering for files as well).
+ for file := range coverdata.Cover.Blocks {
+ globalData.files = append(globalData.files, file)
+ }
+ sort.Strings(globalData.files)
+
+ for fileNum, file := range globalData.files {
+ blocks := coverdata.Cover.Blocks[file]
+ pcs := make([]uint64, 0, len(blocks))
+ for blockNum := range blocks {
+ pcs = append(pcs, calculateSyntheticPC(fileNum, blockNum))
+ }
+ globalData.syntheticPCs = append(globalData.syntheticPCs, pcs)
+ }
+ })
+}
+
+// Symbolize prints information about the block corresponding to pc.
+func Symbolize(out io.Writer, pc uint64) error {
+ fileNum, blockNum := syntheticPCToIndexes(pc)
+ file, err := fileFromIndex(fileNum)
+ if err != nil {
+ return err
+ }
+ block, err := blockFromIndex(file, blockNum)
+ if err != nil {
+ return err
}
- sort.Strings(globalData.files)
-
- // nextSyntheticPC is the first PC that we generate for a block.
- //
- // This uses a standard-looking kernel range for simplicity.
- //
- // FIXME(b/160639712): This is only necessary because syzkaller requires
- // addresses in the kernel range. If we can remove this constraint, then we
- // should be able to use the actual addresses.
- var nextSyntheticPC uint64 = 0xffffffff80000000
- for _, file := range globalData.files {
- blocks := coverdata.Cover.Blocks[file]
- thisFile := make([]uint64, 0, len(blocks))
- for range blocks {
- thisFile = append(thisFile, nextSyntheticPC)
- nextSyntheticPC++ // Advance.
+ writeBlock(out, pc, file, block)
+ return nil
+}
+
+// WriteAllBlocks prints all information about all blocks along with their
+// corresponding synthetic PCs.
+func WriteAllBlocks(out io.Writer) {
+ for fileNum, file := range globalData.files {
+ for blockNum, block := range coverdata.Cover.Blocks[file] {
+ writeBlock(out, calculateSyntheticPC(fileNum, blockNum), file, block)
}
- globalData.syntheticPCs = append(globalData.syntheticPCs, thisFile)
}
}
+
+func calculateSyntheticPC(fileNum int, blockNum int) uint64 {
+ return (uint64(fileNum) << blockBitLength) + uint64(blockNum)
+}
+
+func syntheticPCToIndexes(pc uint64) (fileNum int, blockNum int) {
+ return int(pc >> blockBitLength), int(pc & ((1 << blockBitLength) - 1))
+}
+
+// fileFromIndex returns the name of the file in the sorted list of instrumented files.
+func fileFromIndex(i int) (string, error) {
+ total := len(globalData.files)
+ if i < 0 || i >= total {
+ return "", fmt.Errorf("file index out of range: [%d] with length %d", i, total)
+ }
+ return globalData.files[i], nil
+}
+
+// blockFromIndex returns the i-th block in the given file.
+func blockFromIndex(file string, i int) (testing.CoverBlock, error) {
+ blocks, ok := coverdata.Cover.Blocks[file]
+ if !ok {
+ return testing.CoverBlock{}, fmt.Errorf("instrumented file %s does not exist", file)
+ }
+ total := len(blocks)
+ if i < 0 || i >= total {
+ return testing.CoverBlock{}, fmt.Errorf("block index out of range: [%d] with length %d", i, total)
+ }
+ return blocks[i], nil
+}
+
+func writeBlock(out io.Writer, pc uint64, file string, block testing.CoverBlock) {
+ io.WriteString(out, fmt.Sprintf("%#x\n", pc))
+ io.WriteString(out, fmt.Sprintf("%s:%d.%d,%d.%d\n", file, block.Line0, block.Col0, block.Line1, block.Col1))
+}
diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go
index f7f9dbf86..69eeb7528 100644
--- a/pkg/cpuid/cpuid.go
+++ b/pkg/cpuid/cpuid.go
@@ -36,3 +36,14 @@ package cpuid
// On arm64, features are numbered according to the ELF HWCAP definition.
// arch/arm64/include/uapi/asm/hwcap.h
type Feature int
+
+// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a
+// subset of the host feature set.
+type ErrIncompatible struct {
+ message string
+}
+
+// Error implements error.
+func (e ErrIncompatible) Error() string {
+ return e.message
+}
diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go
index 17a89c00d..392711e8f 100644
--- a/pkg/cpuid/cpuid_x86.go
+++ b/pkg/cpuid/cpuid_x86.go
@@ -681,17 +681,6 @@ func (fs *FeatureSet) Intel() bool {
return fs.VendorID == intelVendorID
}
-// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a
-// subset of the host feature set.
-type ErrIncompatible struct {
- message string
-}
-
-// Error implements error.
-func (e ErrIncompatible) Error() string {
- return e.message
-}
-
// CheckHostCompatible returns nil if fs is a subset of the host feature set.
func (fs *FeatureSet) CheckHostCompatible() error {
hfs := HostFeatureSet()
diff --git a/pkg/crypto/BUILD b/pkg/crypto/BUILD
new file mode 100644
index 000000000..08fa772ca
--- /dev/null
+++ b/pkg/crypto/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "crypto",
+ srcs = [
+ "crypto.go",
+ "crypto_stdlib.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/sleep/empty.s b/pkg/crypto/crypto.go
index fb37360ac..b26b55d37 100644
--- a/pkg/sleep/empty.s
+++ b/pkg/crypto/crypto.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,4 +12,5 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Empty assembly file so empty func definitions work.
+// Package crypto wraps crypto primitives.
+package crypto
diff --git a/pkg/crypto/crypto_stdlib.go b/pkg/crypto/crypto_stdlib.go
new file mode 100644
index 000000000..74a55a123
--- /dev/null
+++ b/pkg/crypto/crypto_stdlib.go
@@ -0,0 +1,32 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package crypto
+
+import (
+ "crypto/ecdsa"
+ "crypto/sha512"
+ "math/big"
+)
+
+// EcdsaVerify verifies the signature in r, s of hash using ECDSA and the
+// public key, pub. Its return value records whether the signature is valid.
+func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) bool {
+ return ecdsa.Verify(pub, hash, r, s)
+}
+
+// SumSha384 returns the SHA384 checksum of the data.
+func SumSha384(data []byte) (sum384 [sha512.Size384]byte) {
+ return sha512.Sum384(data)
+}
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD
index aa8e4e1f3..cc31d0175 100644
--- a/pkg/flipcall/BUILD
+++ b/pkg/flipcall/BUILD
@@ -11,7 +11,8 @@ go_library(
"futex_linux.go",
"io.go",
"packet_window_allocator.go",
- "packet_window_mmap.go",
+ "packet_window_mmap_amd64.go",
+ "packet_window_mmap_arm64.go",
],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/flipcall/packet_window_mmap.go b/pkg/flipcall/packet_window_mmap_amd64.go
index 869183b11..869183b11 100644
--- a/pkg/flipcall/packet_window_mmap.go
+++ b/pkg/flipcall/packet_window_mmap_amd64.go
diff --git a/pkg/syncevent/waiter_asm_unsafe.go b/pkg/flipcall/packet_window_mmap_arm64.go
index 19d6b0b15..b9c9c44f6 100644
--- a/pkg/syncevent/waiter_asm_unsafe.go
+++ b/pkg/flipcall/packet_window_mmap_arm64.go
@@ -12,13 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64 arm64
+// +build arm64
-package syncevent
+package flipcall
import (
- "unsafe"
+ "syscall"
)
-// See waiter_noasm_unsafe.go for a description of waiterUnlock.
-func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool
+// Return a memory mapping of the pwd in memory that can be shared outside the sandbox.
+func packetWindowMmap(pwd PacketWindowDescriptor) (uintptr, syscall.Errno) {
+ m, _, err := syscall.RawSyscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset))
+ return m, err
+}
diff --git a/pkg/goid/BUILD b/pkg/goid/BUILD
index d855b702c..08832a8ae 100644
--- a/pkg/goid/BUILD
+++ b/pkg/goid/BUILD
@@ -9,6 +9,7 @@ go_library(
"goid_amd64.s",
"goid_arm64.s",
],
+ stateify = False,
visibility = ["//visibility:public"],
)
diff --git a/pkg/log/json.go b/pkg/log/json.go
index bdf9d691e..8c52dcc87 100644
--- a/pkg/log/json.go
+++ b/pkg/log/json.go
@@ -27,8 +27,8 @@ type jsonLog struct {
}
// MarshalJSON implements json.Marshaler.MarashalJSON.
-func (lv Level) MarshalJSON() ([]byte, error) {
- switch lv {
+func (l Level) MarshalJSON() ([]byte, error) {
+ switch l {
case Warning:
return []byte(`"warning"`), nil
case Info:
@@ -36,20 +36,20 @@ func (lv Level) MarshalJSON() ([]byte, error) {
case Debug:
return []byte(`"debug"`), nil
default:
- return nil, fmt.Errorf("unknown level %v", lv)
+ return nil, fmt.Errorf("unknown level %v", l)
}
}
// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON. It can unmarshal
// from both string names and integers.
-func (lv *Level) UnmarshalJSON(b []byte) error {
+func (l *Level) UnmarshalJSON(b []byte) error {
switch s := string(b); s {
case "0", `"warning"`:
- *lv = Warning
+ *l = Warning
case "1", `"info"`:
- *lv = Info
+ *l = Info
case "2", `"debug"`:
- *lv = Debug
+ *l = Debug
default:
return fmt.Errorf("unknown level %q", s)
}
diff --git a/pkg/log/log.go b/pkg/log/log.go
index 37e0605ad..2e3408357 100644
--- a/pkg/log/log.go
+++ b/pkg/log/log.go
@@ -356,7 +356,7 @@ func CopyStandardLogTo(l Level) error {
case Warning:
f = Warningf
default:
- return fmt.Errorf("Unknown log level %v", l)
+ return fmt.Errorf("unknown log level %v", l)
}
stdlog.SetOutput(linewriter.NewWriter(func(p []byte) {
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index 6acee90ef..aea7dde38 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -350,9 +350,13 @@ type VerifyParams struct {
// For verifyMetadata, params.data is not needed. It only accesses params.tree
// for the raw root hash.
func verifyMetadata(params *VerifyParams, layout *Layout) error {
- root := make([]byte, layout.digestSize)
- if _, err := params.Tree.ReadAt(root, layout.blockOffset(layout.rootLevel(), 0 /* index */)); err != nil {
- return fmt.Errorf("failed to read root hash: %w", err)
+ var root []byte
+ // Only read the root hash if we expect that the Merkle tree file is non-empty.
+ if params.Size != 0 {
+ root = make([]byte, layout.digestSize)
+ if _, err := params.Tree.ReadAt(root, layout.blockOffset(layout.rootLevel(), 0 /* index */)); err != nil {
+ return fmt.Errorf("failed to read root hash: %w", err)
+ }
}
descriptor := VerityDescriptor{
Name: params.Name,
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 71e944c30..eadea390a 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -570,6 +570,8 @@ func (c *Client) Version() uint32 {
func (c *Client) Close() {
// unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already
// been called (by c.watch()).
- c.socket.Shutdown()
+ if err := c.socket.Shutdown(); err != nil {
+ log.Warningf("Socket.Shutdown() failed (FD: %d): %v", c.socket.FD(), err)
+ }
c.closedWg.Wait()
}
diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go
index 28fe081d6..8b46a2987 100644
--- a/pkg/p9/client_file.go
+++ b/pkg/p9/client_file.go
@@ -478,28 +478,23 @@ func (r *ReadWriterFile) ReadAt(p []byte, offset int64) (int, error) {
}
// Write implements part of the io.ReadWriter interface.
+//
+// Note that this may return a short write with a nil error. This violates the
+// contract of io.Writer, but is more consistent with gVisor's pattern of
+// returning errors that correspond to Linux errnos. Since short writes without
+// error are common in Linux, returning a nil error is appropriate.
func (r *ReadWriterFile) Write(p []byte) (int, error) {
n, err := r.File.WriteAt(p, r.Offset)
r.Offset += uint64(n)
- if err != nil {
- return n, err
- }
- if n < len(p) {
- return n, io.ErrShortWrite
- }
- return n, nil
+ return n, err
}
// WriteAt implements the io.WriteAt interface.
+//
+// Note that this may return a short write with a nil error. This violates the
+// contract of io.WriterAt. See comment on Write for justification.
func (r *ReadWriterFile) WriteAt(p []byte, offset int64) (int, error) {
- n, err := r.File.WriteAt(p, uint64(offset))
- if err != nil {
- return n, err
- }
- if n < len(p) {
- return n, io.ErrShortWrite
- }
- return n, nil
+ return r.File.WriteAt(p, uint64(offset))
}
// Rename implements File.Rename.
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index abd237f46..81ceb37c5 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -296,25 +296,6 @@ func (t *Tlopen) handle(cs *connState) message {
}
defer ref.DecRef()
- ref.openedMu.Lock()
- defer ref.openedMu.Unlock()
-
- // Has it been opened already?
- if ref.opened || !CanOpen(ref.mode) {
- return newErr(syscall.EINVAL)
- }
-
- if ref.mode.IsDir() {
- // Directory must be opened ReadOnly.
- if t.Flags&OpenFlagsModeMask != ReadOnly {
- return newErr(syscall.EISDIR)
- }
- // Directory not truncatable.
- if t.Flags&OpenTruncate != 0 {
- return newErr(syscall.EISDIR)
- }
- }
-
var (
qid QID
ioUnit uint32
@@ -326,6 +307,22 @@ func (t *Tlopen) handle(cs *connState) message {
return syscall.EINVAL
}
+ // Has it been opened already?
+ if ref.opened || !CanOpen(ref.mode) {
+ return syscall.EINVAL
+ }
+
+ if ref.mode.IsDir() {
+ // Directory must be opened ReadOnly.
+ if t.Flags&OpenFlagsModeMask != ReadOnly {
+ return syscall.EISDIR
+ }
+ // Directory not truncatable.
+ if t.Flags&OpenTruncate != 0 {
+ return syscall.EISDIR
+ }
+ }
+
osFile, qid, ioUnit, err = ref.file.Open(t.Flags)
return err
}); err != nil {
@@ -366,7 +363,7 @@ func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -437,7 +434,7 @@ func (t *Tsymlink) do(cs *connState, uid UID) (*Rsymlink, error) {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -476,7 +473,7 @@ func (t *Tlink) handle(cs *connState) message {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -518,7 +515,7 @@ func (t *Trenameat) handle(cs *connState) message {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -561,7 +558,7 @@ func (t *Tunlinkat) handle(cs *connState) message {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -701,13 +698,12 @@ func (t *Tread) handle(cs *connState) message {
)
if err := ref.safelyRead(func() (err error) {
// Has it been opened already?
- openFlags, opened := ref.OpenFlags()
- if !opened {
+ if !ref.opened {
return syscall.EINVAL
}
// Can it be read? Check permissions.
- if openFlags&OpenFlagsModeMask == WriteOnly {
+ if ref.openFlags&OpenFlagsModeMask == WriteOnly {
return syscall.EPERM
}
@@ -731,13 +727,12 @@ func (t *Twrite) handle(cs *connState) message {
var n int
if err := ref.safelyRead(func() (err error) {
// Has it been opened already?
- openFlags, opened := ref.OpenFlags()
- if !opened {
+ if !ref.opened {
return syscall.EINVAL
}
// Can it be written? Check permissions.
- if openFlags&OpenFlagsModeMask == ReadOnly {
+ if ref.openFlags&OpenFlagsModeMask == ReadOnly {
return syscall.EPERM
}
@@ -778,7 +773,7 @@ func (t *Tmknod) do(cs *connState, uid UID) (*Rmknod, error) {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -820,7 +815,7 @@ func (t *Tmkdir) do(cs *connState, uid UID) (*Rmkdir, error) {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -898,13 +893,12 @@ func (t *Tallocate) handle(cs *connState) message {
if err := ref.safelyWrite(func() error {
// Has it been opened already?
- openFlags, opened := ref.OpenFlags()
- if !opened {
+ if !ref.opened {
return syscall.EINVAL
}
// Can it be written? Check permissions.
- if openFlags&OpenFlagsModeMask == ReadOnly {
+ if ref.openFlags&OpenFlagsModeMask == ReadOnly {
return syscall.EBADF
}
@@ -1049,8 +1043,8 @@ func (t *Treaddir) handle(cs *connState) message {
return syscall.EINVAL
}
- // Has it been opened already?
- if _, opened := ref.OpenFlags(); !opened {
+ // Has it been opened yet?
+ if !ref.opened {
return syscall.EINVAL
}
@@ -1076,8 +1070,8 @@ func (t *Tfsync) handle(cs *connState) message {
defer ref.DecRef()
if err := ref.safelyRead(func() (err error) {
- // Has it been opened already?
- if _, opened := ref.OpenFlags(); !opened {
+ // Has it been opened yet?
+ if !ref.opened {
return syscall.EINVAL
}
@@ -1185,8 +1179,13 @@ func doWalk(cs *connState, ref *fidRef, names []string, getattr bool) (qids []QI
}
// Has it been opened already?
- if _, opened := ref.OpenFlags(); opened {
- err = syscall.EBUSY
+ err = ref.safelyRead(func() (err error) {
+ if ref.opened {
+ return syscall.EBUSY
+ }
+ return nil
+ })
+ if err != nil {
return
}
diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go
index 6e605b14c..2e3d427ae 100644
--- a/pkg/p9/p9test/client_test.go
+++ b/pkg/p9/p9test/client_test.go
@@ -678,16 +678,15 @@ func renameHelper(h *Harness, root p9.File, srcNames []string, dstNames []string
// case.
defer checkDeleted(h, dst)
} else {
+ // If the type is different than the destination, then
+ // we expect the rename to fail. We expect that this
+ // is returned.
+ //
+ // If the file being renamed to itself, this is
+ // technically allowed and a no-op, but all the
+ // triggers will fire.
if !selfRename {
- // If the type is different than the
- // destination, then we expect the rename to
- // fail. We expect ensure that this is
- // returned.
expectedErr = syscall.EINVAL
- } else {
- // This is the file being renamed to itself.
- // This is technically allowed and a no-op, but
- // all the triggers will fire.
}
dst.Close()
}
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
index 3736f12a3..8c5c434fd 100644
--- a/pkg/p9/server.go
+++ b/pkg/p9/server.go
@@ -134,12 +134,11 @@ type fidRef struct {
// The node above will be closed only when refs reaches zero.
refs int64
- // openedMu protects opened and openFlags.
- openedMu sync.Mutex
-
// opened indicates whether this has been opened already.
//
// This is updated in handlers.go.
+ //
+ // opened is protected by pathNode.opMu or renameMu (for write).
opened bool
// mode is the fidRef's mode from the walk. Only the type bits are
@@ -151,6 +150,8 @@ type fidRef struct {
// openFlags is the mode used in the open.
//
// This is updated in handlers.go.
+ //
+ // openFlags is protected by pathNode.opMu or renameMu (for write).
openFlags OpenFlags
// pathNode is the current pathNode for this FID.
@@ -177,13 +178,6 @@ type fidRef struct {
deleted uint32
}
-// OpenFlags returns the flags the file was opened with and true iff the fid was opened previously.
-func (f *fidRef) OpenFlags() (OpenFlags, bool) {
- f.openedMu.Lock()
- defer f.openedMu.Unlock()
- return f.openFlags, f.opened
-}
-
// IncRef increases the references on a fid.
func (f *fidRef) IncRef() {
atomic.AddInt64(&f.refs, 1)
diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go
index e7406b374..a29f06ddb 100644
--- a/pkg/p9/transport_test.go
+++ b/pkg/p9/transport_test.go
@@ -197,33 +197,33 @@ func BenchmarkSendRecv(b *testing.B) {
for i := 0; i < b.N; i++ {
tag, m, err := recv(server, maximumLength, msgRegistry.get)
if err != nil {
- b.Fatalf("recv got err %v expected nil", err)
+ b.Errorf("recv got err %v expected nil", err)
}
if tag != Tag(1) {
- b.Fatalf("got tag %v expected 1", tag)
+ b.Errorf("got tag %v expected 1", tag)
}
if _, ok := m.(*Rflush); !ok {
- b.Fatalf("got message %T expected *Rflush", m)
+ b.Errorf("got message %T expected *Rflush", m)
}
if err := send(server, Tag(2), &Rflush{}); err != nil {
- b.Fatalf("send got err %v expected nil", err)
+ b.Errorf("send got err %v expected nil", err)
}
}
}()
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := send(client, Tag(1), &Rflush{}); err != nil {
- b.Fatalf("send got err %v expected nil", err)
+ b.Errorf("send got err %v expected nil", err)
}
tag, m, err := recv(client, maximumLength, msgRegistry.get)
if err != nil {
- b.Fatalf("recv got err %v expected nil", err)
+ b.Errorf("recv got err %v expected nil", err)
}
if tag != Tag(2) {
- b.Fatalf("got tag %v expected 2", tag)
+ b.Errorf("got tag %v expected 2", tag)
}
if _, ok := m.(*Rflush); !ok {
- b.Fatalf("got message %v expected *Rflush", m)
+ b.Errorf("got message %v expected *Rflush", m)
}
}
}
diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go
index a1b2e0cfe..54e825b28 100644
--- a/pkg/pool/pool.go
+++ b/pkg/pool/pool.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package pool provides a trivial integer pool.
package pool
import (
diff --git a/pkg/refsvfs2/BUILD b/pkg/refsvfs2/BUILD
index bfa1daa10..0377c0876 100644
--- a/pkg/refsvfs2/BUILD
+++ b/pkg/refsvfs2/BUILD
@@ -9,7 +9,7 @@ go_template(
"refs_template.go",
],
opt_consts = [
- "logTrace",
+ "enableLogging",
],
types = [
"T",
diff --git a/pkg/refsvfs2/refs_template.go b/pkg/refsvfs2/refs_template.go
index f64b6c6ae..3fbc91aa5 100644
--- a/pkg/refsvfs2/refs_template.go
+++ b/pkg/refsvfs2/refs_template.go
@@ -74,11 +74,6 @@ func (r *Refs) LogRefs() bool {
return enableLogging
}
-// EnableLeakCheck enables reference leak checking on r.
-func (r *Refs) EnableLeakCheck() {
- refsvfs2.Register(r)
-}
-
// ReadRefs returns the current number of references. The returned count is
// inherently racy and is unsafe to use without external synchronization.
func (r *Refs) ReadRefs() int64 {
@@ -136,7 +131,7 @@ func (r *Refs) TryIncRef() bool {
func (r *Refs) DecRef(destroy func()) {
v := atomic.AddInt64(&r.refCount, -1)
if enableLogging {
- refsvfs2.LogDecRef(r, v+1)
+ refsvfs2.LogDecRef(r, v)
}
switch {
case v < 0:
@@ -153,6 +148,6 @@ func (r *Refs) DecRef(destroy func()) {
func (r *Refs) afterLoad() {
if r.ReadRefs() > 0 {
- r.EnableLeakCheck()
+ refsvfs2.Register(r)
}
}
diff --git a/pkg/safemem/block_unsafe.go b/pkg/safemem/block_unsafe.go
index e7fd30743..7857f5853 100644
--- a/pkg/safemem/block_unsafe.go
+++ b/pkg/safemem/block_unsafe.go
@@ -68,29 +68,29 @@ func blockFromSlice(slice []byte, needSafecopy bool) Block {
}
}
-// BlockFromSafePointer returns a Block equivalent to [ptr, ptr+len), which is
+// BlockFromSafePointer returns a Block equivalent to [ptr, ptr+length), which is
// safe to access without safecopy.
//
-// Preconditions: ptr+len does not overflow.
-func BlockFromSafePointer(ptr unsafe.Pointer, len int) Block {
- return blockFromPointer(ptr, len, false)
+// Preconditions: ptr+length does not overflow.
+func BlockFromSafePointer(ptr unsafe.Pointer, length int) Block {
+ return blockFromPointer(ptr, length, false)
}
// BlockFromUnsafePointer returns a Block equivalent to [ptr, ptr+len), which
// is not safe to access without safecopy.
//
// Preconditions: ptr+len does not overflow.
-func BlockFromUnsafePointer(ptr unsafe.Pointer, len int) Block {
- return blockFromPointer(ptr, len, true)
+func BlockFromUnsafePointer(ptr unsafe.Pointer, length int) Block {
+ return blockFromPointer(ptr, length, true)
}
-func blockFromPointer(ptr unsafe.Pointer, len int, needSafecopy bool) Block {
- if uptr := uintptr(ptr); uptr+uintptr(len) < uptr {
- panic(fmt.Sprintf("ptr %#x + len %#x overflows", ptr, len))
+func blockFromPointer(ptr unsafe.Pointer, length int, needSafecopy bool) Block {
+ if uptr := uintptr(ptr); uptr+uintptr(length) < uptr {
+ panic(fmt.Sprintf("ptr %#x + len %#x overflows", uptr, length))
}
return Block{
start: ptr,
- length: len,
+ length: length,
needSafecopy: needSafecopy,
}
}
diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go
index 752e2dc32..ec17ebc4d 100644
--- a/pkg/seccomp/seccomp.go
+++ b/pkg/seccomp/seccomp.go
@@ -79,7 +79,7 @@ func Install(rules SyscallRules) error {
// Perform the actual installation.
if errno := SetFilter(instrs); errno != 0 {
- return fmt.Errorf("Failed to set filter: %v", errno)
+ return fmt.Errorf("failed to set filter: %v", errno)
}
log.Infof("Seccomp filters installed.")
diff --git a/pkg/segment/test/set_functions.go b/pkg/segment/test/set_functions.go
index 7cd895cc7..652c010da 100644
--- a/pkg/segment/test/set_functions.go
+++ b/pkg/segment/test/set_functions.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package segment is a test package.
package segment
type setFunctions struct{}
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go
index d75d665ae..dd2effdf9 100644
--- a/pkg/sentry/arch/arch.go
+++ b/pkg/sentry/arch/arch.go
@@ -365,3 +365,18 @@ func (a SyscallArgument) SizeT() uint {
func (a SyscallArgument) ModeT() uint {
return uint(uint16(a.Value))
}
+
+// ErrFloatingPoint indicates a failed restore due to unusable floating point
+// state.
+type ErrFloatingPoint struct {
+ // supported is the supported floating point state.
+ supported uint64
+
+ // saved is the saved floating point state.
+ saved uint64
+}
+
+// Error returns a sensible description of the restore error.
+func (e ErrFloatingPoint) Error() string {
+ return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved)
+}
diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go
index 19ce99d25..840e53d33 100644
--- a/pkg/sentry/arch/arch_state_x86.go
+++ b/pkg/sentry/arch/arch_state_x86.go
@@ -17,27 +17,10 @@
package arch
import (
- "fmt"
-
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/usermem"
)
-// ErrFloatingPoint indicates a failed restore due to unusable floating point
-// state.
-type ErrFloatingPoint struct {
- // supported is the supported floating point state.
- supported uint64
-
- // saved is the saved floating point state.
- saved uint64
-}
-
-// Error returns a sensible description of the restore error.
-func (e ErrFloatingPoint) Error() string {
- return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved)
-}
-
// XSTATE_BV does not exist if FXSAVE is used, but FXSAVE implicitly saves x87
// and SSE state, so this is the equivalent XSTATE_BV value.
const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE
diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go
index c9fb55d00..35d2e07c3 100644
--- a/pkg/sentry/arch/signal.go
+++ b/pkg/sentry/arch/signal.go
@@ -152,23 +152,23 @@ func (s *SignalInfo) FixSignalCodeForUser() {
}
}
-// Pid returns the si_pid field.
-func (s *SignalInfo) Pid() int32 {
+// PID returns the si_pid field.
+func (s *SignalInfo) PID() int32 {
return int32(usermem.ByteOrder.Uint32(s.Fields[0:4]))
}
-// SetPid mutates the si_pid field.
-func (s *SignalInfo) SetPid(val int32) {
+// SetPID mutates the si_pid field.
+func (s *SignalInfo) SetPID(val int32) {
usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
}
-// Uid returns the si_uid field.
-func (s *SignalInfo) Uid() int32 {
+// UID returns the si_uid field.
+func (s *SignalInfo) UID() int32 {
return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
}
-// SetUid mutates the si_uid field.
-func (s *SignalInfo) SetUid(val int32) {
+// SetUID mutates the si_uid field.
+func (s *SignalInfo) SetUID(val int32) {
usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
}
@@ -251,3 +251,26 @@ func (s *SignalInfo) Arch() uint32 {
func (s *SignalInfo) SetArch(val uint32) {
usermem.ByteOrder.PutUint32(s.Fields[12:16], val)
}
+
+// Band returns the si_band field.
+func (s *SignalInfo) Band() int64 {
+ return int64(usermem.ByteOrder.Uint64(s.Fields[0:8]))
+}
+
+// SetBand mutates the si_band field.
+func (s *SignalInfo) SetBand(val int64) {
+ // Note: this assumes the platform uses `long` as `__ARCH_SI_BAND_T`.
+ // On some platforms, which gVisor doesn't support, `__ARCH_SI_BAND_T` is
+ // `int`. See siginfo.h.
+ usermem.ByteOrder.PutUint64(s.Fields[0:8], uint64(val))
+}
+
+// FD returns the si_fd field.
+func (s *SignalInfo) FD() uint32 {
+ return usermem.ByteOrder.Uint32(s.Fields[8:12])
+}
+
+// SetFD mutates the si_fd field.
+func (s *SignalInfo) SetFD(val uint32) {
+ usermem.ByteOrder.PutUint32(s.Fields[8:12], val)
+}
diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go
index 2bf3c45e1..91b8fb44f 100644
--- a/pkg/sentry/control/pprof.go
+++ b/pkg/sentry/control/pprof.go
@@ -193,7 +193,7 @@ func (p *Profile) StopTrace(_, _ *struct{}) error {
defer p.mu.Unlock()
if p.traceFile == nil {
- return errors.New("Execution tracing not started")
+ return errors.New("execution tracing not started")
}
// Similarly to the case above, if tasks have not ended traces, we will
diff --git a/pkg/sentry/control/state.go b/pkg/sentry/control/state.go
index d800f2c85..62eaca965 100644
--- a/pkg/sentry/control/state.go
+++ b/pkg/sentry/control/state.go
@@ -62,6 +62,7 @@ func (s *State) Save(o *SaveOpts, _ *struct{}) error {
Callback: func(err error) {
if err == nil {
log.Infof("Save succeeded: exiting...")
+ s.Kernel.SetSaveSuccess(false /* autosave */)
} else {
log.Warningf("Save failed: exiting...")
s.Kernel.SetSaveError(err)
diff --git a/pkg/sentry/fdimport/fdimport.go b/pkg/sentry/fdimport/fdimport.go
index 314661475..badd5b073 100644
--- a/pkg/sentry/fdimport/fdimport.go
+++ b/pkg/sentry/fdimport/fdimport.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package fdimport provides the Import function.
package fdimport
import (
diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD
index ea85ab33c..5c3e852e9 100644
--- a/pkg/sentry/fs/BUILD
+++ b/pkg/sentry/fs/BUILD
@@ -49,13 +49,13 @@ go_library(
"//pkg/amutex",
"//pkg/context",
"//pkg/log",
- "//pkg/metric",
"//pkg/p9",
"//pkg/refs",
"//pkg/secio",
"//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsmetric",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go
index ff2fe6712..8e0aa9019 100644
--- a/pkg/sentry/fs/copy_up.go
+++ b/pkg/sentry/fs/copy_up.go
@@ -336,7 +336,12 @@ func cleanupUpper(ctx context.Context, parent *Inode, name string, copyUpErr err
// copyUpBuffers is a buffer pool for copying file content. The buffer
// size is the same used by io.Copy.
-var copyUpBuffers = sync.Pool{New: func() interface{} { return make([]byte, 8*usermem.PageSize) }}
+var copyUpBuffers = sync.Pool{
+ New: func() interface{} {
+ b := make([]byte, 8*usermem.PageSize)
+ return &b
+ },
+}
// copyContentsLocked copies the contents of lower to upper. It panics if
// less than size bytes can be copied.
@@ -361,7 +366,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in
defer lowerFile.DecRef(ctx)
// Use a buffer pool to minimize allocations.
- buf := copyUpBuffers.Get().([]byte)
+ buf := copyUpBuffers.Get().(*[]byte)
defer copyUpBuffers.Put(buf)
// Transfer the contents.
@@ -371,7 +376,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in
// optimizations could be self-defeating. So we leave this as simple as possible.
var offset int64
for {
- nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(buf), offset)
+ nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(*buf), offset)
if err != nil && err != io.EOF {
return err
}
@@ -383,7 +388,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in
}
return nil
}
- nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence(buf[:nr]), offset)
+ nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence((*buf)[:nr]), offset)
if err != nil {
return err
}
diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go
index c7a11eec1..e04784db2 100644
--- a/pkg/sentry/fs/copy_up_test.go
+++ b/pkg/sentry/fs/copy_up_test.go
@@ -64,7 +64,7 @@ func TestConcurrentCopyUp(t *testing.T) {
wg.Add(1)
go func(o *overlayTestFile) {
if err := o.File.Dirent.Inode.Truncate(ctx, o.File.Dirent, truncateFileSize); err != nil {
- t.Fatalf("failed to copy up: %v", err)
+ t.Errorf("failed to copy up: %v", err)
}
wg.Done()
}(file)
diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go
index 72ea70fcf..57f904801 100644
--- a/pkg/sentry/fs/file.go
+++ b/pkg/sentry/fs/file.go
@@ -17,13 +17,12 @@ package fs
import (
"math"
"sync/atomic"
- "time"
"gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
@@ -33,28 +32,6 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-var (
- // RecordWaitTime controls writing metrics for filesystem reads.
- // Enabling this comes at a small CPU cost due to performing two
- // monotonic clock reads per read call.
- //
- // Note that this is only performed in the direct read path, and may
- // not be consistently applied for other forms of reads, such as
- // splice.
- RecordWaitTime = false
-
- reads = metric.MustCreateNewUint64Metric("/fs/reads", false /* sync */, "Number of file reads.")
- readWait = metric.MustCreateNewUint64NanosecondsMetric("/fs/read_wait", false /* sync */, "Time waiting on file reads, in nanoseconds.")
-)
-
-// IncrementWait increments the given wait time metric, if enabled.
-func IncrementWait(m *metric.Uint64Metric, start time.Time) {
- if !RecordWaitTime {
- return
- }
- m.IncrementBy(uint64(time.Since(start)))
-}
-
// FileMaxOffset is the maximum possible file offset.
const FileMaxOffset = math.MaxInt64
@@ -257,22 +234,19 @@ func (f *File) Readdir(ctx context.Context, serializer DentrySerializer) error {
//
// Returns syserror.ErrInterrupted if reading was interrupted.
func (f *File) Readv(ctx context.Context, dst usermem.IOSequence) (int64, error) {
- var start time.Time
- if RecordWaitTime {
- start = time.Now()
- }
+ start := fsmetric.StartReadWait()
+ defer fsmetric.FinishReadWait(fsmetric.ReadWait, start)
+
if !f.mu.Lock(ctx) {
- IncrementWait(readWait, start)
return 0, syserror.ErrInterrupted
}
- reads.Increment()
+ fsmetric.Reads.Increment()
n, err := f.FileOperations.Read(ctx, f, dst, f.offset)
if n > 0 && !f.flags.NonSeekable {
atomic.AddInt64(&f.offset, n)
}
f.mu.Unlock()
- IncrementWait(readWait, start)
return n, err
}
@@ -282,19 +256,16 @@ func (f *File) Readv(ctx context.Context, dst usermem.IOSequence) (int64, error)
//
// Otherwise same as Readv.
func (f *File) Preadv(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
- var start time.Time
- if RecordWaitTime {
- start = time.Now()
- }
+ start := fsmetric.StartReadWait()
+ defer fsmetric.FinishReadWait(fsmetric.ReadWait, start)
+
if !f.mu.Lock(ctx) {
- IncrementWait(readWait, start)
return 0, syserror.ErrInterrupted
}
- reads.Increment()
+ fsmetric.Reads.Increment()
n, err := f.FileOperations.Read(ctx, f, dst, offset)
f.mu.Unlock()
- IncrementWait(readWait, start)
return n, err
}
diff --git a/pkg/sentry/fs/filetest/filetest.go b/pkg/sentry/fs/filetest/filetest.go
index 8049538f2..ec3d3f96c 100644
--- a/pkg/sentry/fs/filetest/filetest.go
+++ b/pkg/sentry/fs/filetest/filetest.go
@@ -52,10 +52,10 @@ func NewTestFile(tb testing.TB) *fs.File {
// Read just fails the request.
func (*TestFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
- return 0, fmt.Errorf("Readv not implemented")
+ return 0, fmt.Errorf("TestFileOperations.Read not implemented")
}
// Write just fails the request.
func (*TestFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
- return 0, fmt.Errorf("Writev not implemented")
+ return 0, fmt.Errorf("TestFileOperations.Write not implemented")
}
diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go
index d2dbff268..a020da53b 100644
--- a/pkg/sentry/fs/fs.go
+++ b/pkg/sentry/fs/fs.go
@@ -65,7 +65,7 @@ var (
// runs with the lock held for reading. AsyncBarrier will take the lock
// for writing, thus ensuring that all Async work completes before
// AsyncBarrier returns.
- workMu sync.RWMutex
+ workMu sync.CrossGoroutineRWMutex
// asyncError is used to store up to one asynchronous execution error.
asyncError = make(chan error, 1)
diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD
index fea135eea..4c30098cd 100644
--- a/pkg/sentry/fs/gofer/BUILD
+++ b/pkg/sentry/fs/gofer/BUILD
@@ -28,7 +28,6 @@ go_library(
"//pkg/context",
"//pkg/fd",
"//pkg/log",
- "//pkg/metric",
"//pkg/p9",
"//pkg/refs",
"//pkg/safemem",
@@ -38,6 +37,7 @@ go_library(
"//pkg/sentry/fs/fdpipe",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/host",
+ "//pkg/sentry/fsmetric",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/time",
diff --git a/pkg/sentry/fs/gofer/attr.go b/pkg/sentry/fs/gofer/attr.go
index d481baf77..e5579095b 100644
--- a/pkg/sentry/fs/gofer/attr.go
+++ b/pkg/sentry/fs/gofer/attr.go
@@ -117,8 +117,6 @@ func ntype(pattr p9.Attr) fs.InodeType {
return fs.BlockDevice
case pattr.Mode.IsSocket():
return fs.Socket
- case pattr.Mode.IsRegular():
- fallthrough
default:
return fs.RegularFile
}
diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go
index c0bc63a32..bb63448cb 100644
--- a/pkg/sentry/fs/gofer/file.go
+++ b/pkg/sentry/fs/gofer/file.go
@@ -21,27 +21,17 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
-var (
- opensWX = metric.MustCreateNewUint64Metric("/gofer/opened_write_execute_file", true /* sync */, "Number of times a writable+executable file was opened from a gofer.")
- opens9P = metric.MustCreateNewUint64Metric("/gofer/opens_9p", false /* sync */, "Number of times a 9P file was opened from a gofer.")
- opensHost = metric.MustCreateNewUint64Metric("/gofer/opens_host", false /* sync */, "Number of times a host file was opened from a gofer.")
- reads9P = metric.MustCreateNewUint64Metric("/gofer/reads_9p", false /* sync */, "Number of 9P file reads from a gofer.")
- readWait9P = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_9p", false /* sync */, "Time waiting on 9P file reads from a gofer, in nanoseconds.")
- readsHost = metric.MustCreateNewUint64Metric("/gofer/reads_host", false /* sync */, "Number of host file reads from a gofer.")
- readWaitHost = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_host", false /* sync */, "Time waiting on host file reads from a gofer, in nanoseconds.")
-)
-
// fileOperations implements fs.FileOperations for a remote file system.
//
// +stateify savable
@@ -101,14 +91,14 @@ func NewFile(ctx context.Context, dirent *fs.Dirent, name string, flags fs.FileF
}
if flags.Write {
if err := dirent.Inode.CheckPermission(ctx, fs.PermMask{Execute: true}); err == nil {
- opensWX.Increment()
+ fsmetric.GoferOpensWX.Increment()
log.Warningf("Opened a writable executable: %q", name)
}
}
if handles.Host != nil {
- opensHost.Increment()
+ fsmetric.GoferOpensHost.Increment()
} else {
- opens9P.Increment()
+ fsmetric.GoferOpens9P.Increment()
}
return fs.NewFile(ctx, dirent, flags, f)
}
@@ -278,20 +268,17 @@ func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.I
// use this function rather than using a defer in Read() to avoid the performance hit of defer.
func (f *fileOperations) incrementReadCounters(start time.Time) {
if f.handles.Host != nil {
- readsHost.Increment()
- fs.IncrementWait(readWaitHost, start)
+ fsmetric.GoferReadsHost.Increment()
+ fsmetric.FinishReadWait(fsmetric.GoferReadWaitHost, start)
} else {
- reads9P.Increment()
- fs.IncrementWait(readWait9P, start)
+ fsmetric.GoferReads9P.Increment()
+ fsmetric.FinishReadWait(fsmetric.GoferReadWait9P, start)
}
}
// Read implements fs.FileOperations.Read.
func (f *fileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
- var start time.Time
- if fs.RecordWaitTime {
- start = time.Now()
- }
+ start := fsmetric.StartReadWait()
if fs.IsDir(file.Dirent.Inode.StableAttr) {
// Not all remote file systems enforce this so this client does.
f.incrementReadCounters(start)
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index 3a225fd39..9d6fdd08f 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -117,7 +117,7 @@ type inodeFileState struct {
// loading is acquired when the inodeFileState begins an asynchronous
// load. It releases when the load is complete. Callers that require all
// state to be available should call waitForLoad() to ensure that.
- loading sync.Mutex `state:".(struct{})"`
+ loading sync.CrossGoroutineMutex `state:".(struct{})"`
// savedUAttr is only allocated during S/R. It points to the save-time
// unstable attributes and is used to validate restore-time ones.
diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go
index 004910453..9b3d8166a 100644
--- a/pkg/sentry/fs/inode.go
+++ b/pkg/sentry/fs/inode.go
@@ -18,9 +18,9 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
@@ -28,8 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
-var opens = metric.MustCreateNewUint64Metric("/fs/opens", false /* sync */, "Number of file opens.")
-
// Inode is a file system object that can be simultaneously referenced by different
// components of the VFS (Dirent, fs.File, etc).
//
@@ -247,7 +245,7 @@ func (i *Inode) GetFile(ctx context.Context, d *Dirent, flags FileFlags) (*File,
if i.overlay != nil {
return overlayGetFile(ctx, i.overlay, d, flags)
}
- opens.Increment()
+ fsmetric.Opens.Increment()
return i.InodeOperations.GetFile(ctx, d, flags)
}
diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go
index f8aad2dbd..b998fb75d 100644
--- a/pkg/sentry/fs/proc/sys.go
+++ b/pkg/sentry/fs/proc/sys.go
@@ -84,6 +84,7 @@ func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode
children := map[string]*fs.Inode{
"hostname": newProcInode(ctx, &h, msrc, fs.SpecialFile, nil),
+ "sem": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))),
"shmall": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMALL, 10))),
"shmmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMAX, 10))),
"shmmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMNI, 10))),
diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD
index aa7199014..b521a86a2 100644
--- a/pkg/sentry/fs/tmpfs/BUILD
+++ b/pkg/sentry/fs/tmpfs/BUILD
@@ -15,12 +15,12 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
- "//pkg/metric",
"//pkg/safemem",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/ramfs",
+ "//pkg/sentry/fsmetric",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/pipe",
diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go
index d6c65301c..e04cd608d 100644
--- a/pkg/sentry/fs/tmpfs/inode_file.go
+++ b/pkg/sentry/fs/tmpfs/inode_file.go
@@ -18,14 +18,13 @@ import (
"fmt"
"io"
"math"
- "time"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -35,13 +34,6 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-var (
- opensRO = metric.MustCreateNewUint64Metric("/in_memory_file/opens_ro", false /* sync */, "Number of times an in-memory file was opened in read-only mode.")
- opensW = metric.MustCreateNewUint64Metric("/in_memory_file/opens_w", false /* sync */, "Number of times an in-memory file was opened in write mode.")
- reads = metric.MustCreateNewUint64Metric("/in_memory_file/reads", false /* sync */, "Number of in-memory file reads.")
- readWait = metric.MustCreateNewUint64NanosecondsMetric("/in_memory_file/read_wait", false /* sync */, "Time waiting on in-memory file reads, in nanoseconds.")
-)
-
// fileInodeOperations implements fs.InodeOperations for a regular tmpfs file.
// These files are backed by pages allocated from a platform.Memory, and may be
// directly mapped.
@@ -157,9 +149,9 @@ func (*fileInodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldPare
// GetFile implements fs.InodeOperations.GetFile.
func (f *fileInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
if flags.Write {
- opensW.Increment()
+ fsmetric.TmpfsOpensW.Increment()
} else if flags.Read {
- opensRO.Increment()
+ fsmetric.TmpfsOpensRO.Increment()
}
flags.Pread = true
flags.Pwrite = true
@@ -319,14 +311,12 @@ func (*fileInodeOperations) StatFS(context.Context) (fs.Info, error) {
}
func (f *fileInodeOperations) read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
- var start time.Time
- if fs.RecordWaitTime {
- start = time.Now()
- }
- reads.Increment()
+ start := fsmetric.StartReadWait()
+ defer fsmetric.FinishReadWait(fsmetric.TmpfsReadWait, start)
+ fsmetric.TmpfsReads.Increment()
+
// Zero length reads for tmpfs are no-ops.
if dst.NumBytes() == 0 {
- fs.IncrementWait(readWait, start)
return 0, nil
}
@@ -343,7 +333,6 @@ func (f *fileInodeOperations) read(ctx context.Context, file *fs.File, dst userm
size := f.attr.Size
f.dataMu.RUnlock()
if offset >= size {
- fs.IncrementWait(readWait, start)
return 0, io.EOF
}
@@ -354,7 +343,6 @@ func (f *fileInodeOperations) read(ctx context.Context, file *fs.File, dst userm
f.attr.AccessTime = ktime.NowFromContext(ctx)
f.attrMu.Unlock()
}
- fs.IncrementWait(readWait, start)
return n, err
}
diff --git a/pkg/sentry/fsimpl/fuse/connection_control.go b/pkg/sentry/fsimpl/fuse/connection_control.go
index 1b3459c1d..4ab894965 100644
--- a/pkg/sentry/fsimpl/fuse/connection_control.go
+++ b/pkg/sentry/fsimpl/fuse/connection_control.go
@@ -84,11 +84,7 @@ func (conn *connection) InitSend(creds *auth.Credentials, pid uint32) error {
Flags: fuseDefaultInitFlags,
}
- req, err := conn.NewRequest(creds, pid, 0, linux.FUSE_INIT, &in)
- if err != nil {
- return err
- }
-
+ req := conn.NewRequest(creds, pid, 0, linux.FUSE_INIT, &in)
// Since there is no task to block on and FUSE_INIT is the request
// to unblock other requests, use nil.
return conn.CallAsync(nil, req)
diff --git a/pkg/sentry/fsimpl/fuse/connection_test.go b/pkg/sentry/fsimpl/fuse/connection_test.go
index 91d16c1cf..d8b0d7657 100644
--- a/pkg/sentry/fsimpl/fuse/connection_test.go
+++ b/pkg/sentry/fsimpl/fuse/connection_test.go
@@ -76,10 +76,7 @@ func TestConnectionAbort(t *testing.T) {
var futNormal []*futureResponse
for i := 0; i < int(numRequests); i++ {
- req, err := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj)
- if err != nil {
- t.Fatalf("NewRequest creation failed: %v", err)
- }
+ req := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj)
fut, err := conn.callFutureLocked(task, req)
if err != nil {
t.Fatalf("callFutureLocked failed: %v", err)
@@ -105,10 +102,7 @@ func TestConnectionAbort(t *testing.T) {
}
// After abort, Call() should return directly with ENOTCONN.
- req, err := conn.NewRequest(creds, 0, 0, 0, testObj)
- if err != nil {
- t.Fatalf("NewRequest creation failed: %v", err)
- }
+ req := conn.NewRequest(creds, 0, 0, 0, testObj)
_, err = conn.Call(task, req)
if err != syserror.ENOTCONN {
t.Fatalf("Incorrect error code received for Call() after connection aborted")
diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go
index 89c3ef079..1bbe6fdb7 100644
--- a/pkg/sentry/fsimpl/fuse/dev.go
+++ b/pkg/sentry/fsimpl/fuse/dev.go
@@ -363,7 +363,7 @@ func (fd *DeviceFD) Readiness(mask waiter.EventMask) waiter.EventMask {
func (fd *DeviceFD) readinessLocked(mask waiter.EventMask) waiter.EventMask {
var ready waiter.EventMask
- if fd.fs.umounted {
+ if fd.fs == nil || fd.fs.umounted {
ready |= waiter.EventErr
return ready & mask
}
diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go
index 95c475a65..bb2d0d31a 100644
--- a/pkg/sentry/fsimpl/fuse/dev_test.go
+++ b/pkg/sentry/fsimpl/fuse/dev_test.go
@@ -219,10 +219,7 @@ func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *con
data: rand.Uint32(),
}
- req, err := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj)
- if err != nil {
- t.Fatalf("NewRequest creation failed: %v", err)
- }
+ req := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj)
// Queue up a request.
// Analogous to Call except it doesn't block on the task.
diff --git a/pkg/sentry/fsimpl/fuse/directory.go b/pkg/sentry/fsimpl/fuse/directory.go
index 8f220a04b..fcc5d9a2a 100644
--- a/pkg/sentry/fsimpl/fuse/directory.go
+++ b/pkg/sentry/fsimpl/fuse/directory.go
@@ -68,11 +68,7 @@ func (dir *directoryFD) IterDirents(ctx context.Context, callback vfs.IterDirent
}
// TODO(gVisor.dev/issue/3404): Support FUSE_READDIRPLUS.
- req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), dir.inode().nodeID, linux.FUSE_READDIR, &in)
- if err != nil {
- return err
- }
-
+ req := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), dir.inode().nodeID, linux.FUSE_READDIR, &in)
res, err := fusefs.conn.Call(task, req)
if err != nil {
return err
diff --git a/pkg/sentry/fsimpl/fuse/file.go b/pkg/sentry/fsimpl/fuse/file.go
index 83f2816b7..e138b11f8 100644
--- a/pkg/sentry/fsimpl/fuse/file.go
+++ b/pkg/sentry/fsimpl/fuse/file.go
@@ -83,12 +83,8 @@ func (fd *fileDescription) Release(ctx context.Context) {
opcode = linux.FUSE_RELEASE
}
kernelTask := kernel.TaskFromContext(ctx)
- // ignoring errors and FUSE server reply is analogous to Linux's behavior.
- req, err := conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), fd.inode().nodeID, opcode, &in)
- if err != nil {
- // No way to invoke Call() with an errored request.
- return
- }
+ // Ignoring errors and FUSE server reply is analogous to Linux's behavior.
+ req := conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), fd.inode().nodeID, opcode, &in)
// The reply will be ignored since no callback is defined in asyncCallBack().
conn.CallAsync(kernelTask, req)
}
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
index 23e827f90..3af807a21 100644
--- a/pkg/sentry/fsimpl/fuse/fusefs.go
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -119,7 +119,8 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */)
if err != nil {
- return nil, nil, err
+ log.Debugf("%s.GetFilesystem: device FD '%v' not parsable: %v", fsType.Name(), deviceDescriptorStr, err)
+ return nil, nil, syserror.EINVAL
}
kernelTask := kernel.TaskFromContext(ctx)
@@ -360,12 +361,8 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentr
in.Flags &= ^uint32(linux.O_TRUNC)
}
- req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, &in)
- if err != nil {
- return nil, err
- }
-
// Send the request and receive the reply.
+ req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, &in)
res, err := i.fs.conn.Call(kernelTask, req)
if err != nil {
return nil, err
@@ -485,10 +482,7 @@ func (i *inode) Unlink(ctx context.Context, name string, child kernfs.Inode) err
return syserror.EINVAL
}
in := linux.FUSEUnlinkIn{Name: name}
- req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in)
- if err != nil {
- return err
- }
+ req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in)
res, err := i.fs.conn.Call(kernelTask, req)
if err != nil {
return err
@@ -515,11 +509,7 @@ func (i *inode) RmDir(ctx context.Context, name string, child kernfs.Inode) erro
task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx)
in := linux.FUSERmDirIn{Name: name}
- req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in)
- if err != nil {
- return err
- }
-
+ req := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in)
res, err := i.fs.conn.Call(task, req)
if err != nil {
return err
@@ -535,10 +525,7 @@ func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMo
log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID)
return nil, syserror.EINVAL
}
- req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload)
- if err != nil {
- return nil, err
- }
+ req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload)
res, err := i.fs.conn.Call(kernelTask, req)
if err != nil {
return nil, err
@@ -574,10 +561,7 @@ func (i *inode) Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) {
log.Warningf("fusefs.Inode.Readlink: couldn't get kernel task from context")
return "", syserror.EINVAL
}
- req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{})
- if err != nil {
- return "", err
- }
+ req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{})
res, err := i.fs.conn.Call(kernelTask, req)
if err != nil {
return "", err
@@ -680,11 +664,7 @@ func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOp
GetAttrFlags: flags,
Fh: fh,
}
- req, err := i.fs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_GETATTR, &in)
- if err != nil {
- return linux.FUSEAttr{}, err
- }
-
+ req := i.fs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_GETATTR, &in)
res, err := i.fs.conn.Call(task, req)
if err != nil {
return linux.FUSEAttr{}, err
@@ -803,11 +783,7 @@ func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
UID: opts.Stat.UID,
GID: opts.Stat.GID,
}
- req, err := conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_SETATTR, &in)
- if err != nil {
- return err
- }
-
+ req := conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_SETATTR, &in)
res, err := conn.Call(task, req)
if err != nil {
return err
diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go
index 2d396e84c..23ce91849 100644
--- a/pkg/sentry/fsimpl/fuse/read_write.go
+++ b/pkg/sentry/fsimpl/fuse/read_write.go
@@ -79,13 +79,9 @@ func (fs *filesystem) ReadInPages(ctx context.Context, fd *regularFileFD, off ui
in.Offset = off + (uint64(pagesRead) << usermem.PageShift)
in.Size = pagesCanRead << usermem.PageShift
- req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_READ, &in)
- if err != nil {
- return nil, 0, err
- }
-
// TODO(gvisor.dev/issue/3247): support async read.
+ req := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_READ, &in)
res, err := fs.conn.Call(t, req)
if err != nil {
return nil, 0, err
@@ -204,11 +200,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64,
in.Offset = off + uint64(written)
in.Size = toWrite
- req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in)
- if err != nil {
- return 0, err
- }
-
+ req := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in)
req.payload = data[written : written+toWrite]
// TODO(gvisor.dev/issue/3247): support async write.
diff --git a/pkg/sentry/fsimpl/fuse/request_response.go b/pkg/sentry/fsimpl/fuse/request_response.go
index 7fa00569b..41d679358 100644
--- a/pkg/sentry/fsimpl/fuse/request_response.go
+++ b/pkg/sentry/fsimpl/fuse/request_response.go
@@ -70,6 +70,7 @@ func (r *fuseInitRes) UnmarshalBytes(src []byte) {
out.MaxPages = uint16(usermem.ByteOrder.Uint16(src[:2]))
src = src[2:]
}
+ _ = src // Remove unused warning.
}
// SizeBytes is the size of the payload of the FUSE_INIT response.
@@ -104,7 +105,7 @@ type Request struct {
}
// NewRequest creates a new request that can be sent to the FUSE server.
-func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) {
+func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) *Request {
conn.fd.mu.Lock()
defer conn.fd.mu.Unlock()
conn.fd.nextOpID += linux.FUSEOpID(reqIDStep)
@@ -130,7 +131,7 @@ func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint
id: hdr.Unique,
hdr: &hdr,
data: buf,
- }, nil
+ }
}
// futureResponse represents an in-flight request, that may or may not have
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index 4c3e9acf8..807b6ed1f 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -59,6 +59,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/lock",
"//pkg/sentry/fsimpl/host",
+ "//pkg/sentry/fsmetric",
"//pkg/sentry/hostfd",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 2294c490e..df27554d3 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
@@ -985,14 +986,11 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
switch d.fileType() {
case linux.S_IFREG:
if !d.fs.opts.regularFilesUseSpecialFileFD {
- if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, trunc); err != nil {
+ if err := d.ensureSharedHandle(ctx, ats.MayRead(), ats.MayWrite(), trunc); err != nil {
return nil, err
}
- fd := &regularFileFD{}
- fd.LockFD.Init(&d.locks)
- if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
- AllowDirectIO: true,
- }); err != nil {
+ fd, err := newRegularFileFD(mnt, d, opts.Flags)
+ if err != nil {
return nil, err
}
vfd = &fd.vfsfd
@@ -1019,6 +1017,11 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
+ if atomic.LoadInt32(&d.readFD) >= 0 {
+ fsmetric.GoferOpensHost.Increment()
+ } else {
+ fsmetric.GoferOpens9P.Increment()
+ }
return &fd.vfsfd, nil
case linux.S_IFLNK:
// Can't open symlinks without O_PATH (which is unimplemented).
@@ -1110,7 +1113,7 @@ retry:
return nil, err
}
}
- fd, err := newSpecialFileFD(h, mnt, d, &d.locks, opts.Flags)
+ fd, err := newSpecialFileFD(h, mnt, d, opts.Flags)
if err != nil {
h.close(ctx)
return nil, err
@@ -1205,11 +1208,8 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
// Finally, construct a file description representing the created file.
var childVFSFD *vfs.FileDescription
if useRegularFileFD {
- fd := &regularFileFD{}
- fd.LockFD.Init(&child.locks)
- if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &child.vfsd, &vfs.FileDescriptionOptions{
- AllowDirectIO: true,
- }); err != nil {
+ fd, err := newRegularFileFD(mnt, child, opts.Flags)
+ if err != nil {
return nil, err
}
childVFSFD = &fd.vfsfd
@@ -1221,7 +1221,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
if fdobj != nil {
h.fd = int32(fdobj.Release())
}
- fd, err := newSpecialFileFD(h, mnt, child, &d.locks, opts.Flags)
+ fd, err := newSpecialFileFD(h, mnt, child, opts.Flags)
if err != nil {
h.close(ctx)
return nil, err
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 75a836899..3cdb1e659 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -743,7 +743,9 @@ type dentry struct {
// for memory mappings. If mmapFD is -1, no such FD is available, and the
// internal page cache implementation is used for memory mappings instead.
//
- // These fields are protected by handleMu.
+ // These fields are protected by handleMu. readFD, writeFD, and mmapFD are
+ // additionally written using atomic memory operations, allowing them to be
+ // read (albeit racily) with atomic.LoadInt32() without locking handleMu.
//
// readFile and writeFile may or may not represent the same p9.File. Once
// either p9.File transitions from closed (isNil() == true) to open
@@ -1351,16 +1353,11 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
return
}
if refs > 0 {
- if d.cached {
- // This isn't strictly necessary (fs.cachedDentries is permitted to
- // contain dentries with non-zero refs, which are skipped by
- // fs.evictCachedDentryLocked() upon reaching the end of the LRU),
- // but since we are already holding fs.renameMu for writing we may
- // as well.
- d.fs.cachedDentries.Remove(d)
- d.fs.cachedDentriesLen--
- d.cached = false
- }
+ // This isn't strictly necessary (fs.cachedDentries is permitted to
+ // contain dentries with non-zero refs, which are skipped by
+ // fs.evictCachedDentryLocked() upon reaching the end of the LRU), but
+ // since we are already holding fs.renameMu for writing we may as well.
+ d.removeFromCacheLocked()
return
}
// Deleted and invalidated dentries with zero references are no longer
@@ -1369,20 +1366,18 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
if d.isDeleted() {
d.watches.HandleDeletion(ctx)
}
- if d.cached {
- d.fs.cachedDentries.Remove(d)
- d.fs.cachedDentriesLen--
- d.cached = false
- }
+ d.removeFromCacheLocked()
d.destroyLocked(ctx)
return
}
- // If d still has inotify watches and it is not deleted or invalidated, we
- // cannot cache it and allow it to be evicted. Otherwise, we will lose its
- // watches, even if a new dentry is created for the same file in the future.
- // Note that the size of d.watches cannot concurrently transition from zero
- // to non-zero, because adding a watch requires holding a reference on d.
+ // If d still has inotify watches and it is not deleted or invalidated, it
+ // can't be evicted. Otherwise, we will lose its watches, even if a new
+ // dentry is created for the same file in the future. Note that the size of
+ // d.watches cannot concurrently transition from zero to non-zero, because
+ // adding a watch requires holding a reference on d.
if d.watches.Size() > 0 {
+ // As in the refs > 0 case, this is not strictly necessary.
+ d.removeFromCacheLocked()
return
}
@@ -1413,6 +1408,15 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
}
}
+// Preconditions: d.fs.renameMu must be locked for writing.
+func (d *dentry) removeFromCacheLocked() {
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentriesLen--
+ d.cached = false
+ }
+}
+
// Precondition: fs.renameMu must be locked for writing; it may be temporarily
// unlocked.
func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) {
@@ -1426,12 +1430,10 @@ func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) {
// * fs.cachedDentriesLen != 0.
func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) {
victim := fs.cachedDentries.Back()
- fs.cachedDentries.Remove(victim)
- fs.cachedDentriesLen--
- victim.cached = false
- // victim.refs may have become non-zero from an earlier path resolution
- // since it was inserted into fs.cachedDentries.
- if atomic.LoadInt64(&victim.refs) == 0 {
+ victim.removeFromCacheLocked()
+ // victim.refs or victim.watches.Size() may have become non-zero from an
+ // earlier path resolution since it was inserted into fs.cachedDentries.
+ if atomic.LoadInt64(&victim.refs) == 0 && victim.watches.Size() == 0 {
if victim.parent != nil {
victim.parent.dirMu.Lock()
if !victim.vfsd.IsDead() {
@@ -1668,7 +1670,7 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
}
fdsToClose = append(fdsToClose, d.readFD)
invalidateTranslations = true
- d.readFD = h.fd
+ atomic.StoreInt32(&d.readFD, h.fd)
} else {
// Otherwise, we want to avoid invalidating existing
// memmap.Translations (which is expensive); instead, use
@@ -1689,15 +1691,15 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
h.fd = d.readFD
}
} else {
- d.readFD = h.fd
+ atomic.StoreInt32(&d.readFD, h.fd)
}
if d.writeFD != h.fd && d.writeFD >= 0 {
fdsToClose = append(fdsToClose, d.writeFD)
}
- d.writeFD = h.fd
- d.mmapFD = h.fd
+ atomic.StoreInt32(&d.writeFD, h.fd)
+ atomic.StoreInt32(&d.mmapFD, h.fd)
} else if openReadable && d.readFD < 0 {
- d.readFD = h.fd
+ atomic.StoreInt32(&d.readFD, h.fd)
// If the file has not been opened for writing, the new FD may
// be used for read-only memory mappings. If the file was
// previously opened for reading (without an FD), then existing
@@ -1705,10 +1707,10 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
// invalidate those mappings.
if d.writeFile.isNil() {
invalidateTranslations = !d.readFile.isNil()
- d.mmapFD = h.fd
+ atomic.StoreInt32(&d.mmapFD, h.fd)
}
} else if openWritable && d.writeFD < 0 {
- d.writeFD = h.fd
+ atomic.StoreInt32(&d.writeFD, h.fd)
if d.readFD >= 0 {
// We have an existing read-only FD, but the file has just
// been opened for writing, so we need to start supporting
@@ -1717,7 +1719,7 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
// writable memory mappings. Switch to using the internal
// page cache.
invalidateTranslations = true
- d.mmapFD = -1
+ atomic.StoreInt32(&d.mmapFD, -1)
}
} else {
// The new FD is not useful.
@@ -1729,7 +1731,7 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
// memory mappings. However, we have no writable host FD. Switch to
// using the internal page cache.
invalidateTranslations = true
- d.mmapFD = -1
+ atomic.StoreInt32(&d.mmapFD, -1)
}
// Switch to new fids.
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 652142ecc..283b220bb 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/usage"
@@ -48,6 +49,25 @@ type regularFileFD struct {
off int64
}
+func newRegularFileFD(mnt *vfs.Mount, d *dentry, flags uint32) (*regularFileFD, error) {
+ fd := &regularFileFD{}
+ fd.LockFD.Init(&d.locks)
+ if err := fd.vfsfd.Init(fd, flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
+ AllowDirectIO: true,
+ }); err != nil {
+ return nil, err
+ }
+ if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) {
+ fsmetric.GoferOpensWX.Increment()
+ }
+ if atomic.LoadInt32(&d.mmapFD) >= 0 {
+ fsmetric.GoferOpensHost.Increment()
+ } else {
+ fsmetric.GoferOpens9P.Increment()
+ }
+ return fd, nil
+}
+
// Release implements vfs.FileDescriptionImpl.Release.
func (fd *regularFileFD) Release(context.Context) {
}
@@ -89,6 +109,18 @@ func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ start := fsmetric.StartReadWait()
+ d := fd.dentry()
+ defer func() {
+ if atomic.LoadInt32(&d.readFD) >= 0 {
+ fsmetric.GoferReadsHost.Increment()
+ fsmetric.FinishReadWait(fsmetric.GoferReadWaitHost, start)
+ } else {
+ fsmetric.GoferReads9P.Increment()
+ fsmetric.FinishReadWait(fsmetric.GoferReadWait9P, start)
+ }
+ }()
+
if offset < 0 {
return 0, syserror.EINVAL
}
@@ -102,7 +134,6 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
// Check for reading at EOF before calling into MM (but not under
// InteropModeShared, which makes d.size unreliable).
- d := fd.dentry()
if d.cachedMetadataAuthoritative() && uint64(offset) >= atomic.LoadUint64(&d.size) {
return 0, io.EOF
}
@@ -647,10 +678,7 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt
// Whether or not we have a host FD, we're not allowed to use it.
return syserror.ENODEV
}
- d.handleMu.RLock()
- haveFD := d.mmapFD >= 0
- d.handleMu.RUnlock()
- if !haveFD {
+ if atomic.LoadInt32(&d.mmapFD) < 0 {
return syserror.ENODEV
}
default:
@@ -668,10 +696,7 @@ func (d *dentry) mayCachePages() bool {
if d.fs.opts.forcePageCache {
return true
}
- d.handleMu.RLock()
- haveFD := d.mmapFD >= 0
- d.handleMu.RUnlock()
- return haveFD
+ return atomic.LoadInt32(&d.mmapFD) >= 0
}
// AddMapping implements memmap.Mappable.AddMapping.
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index 625400c0b..089955a96 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -70,7 +71,7 @@ type specialFileFD struct {
buf []byte
}
-func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) {
+func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, flags uint32) (*specialFileFD, error) {
ftype := d.fileType()
seekable := ftype == linux.S_IFREG || ftype == linux.S_IFCHR || ftype == linux.S_IFBLK
haveQueue := (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && h.fd >= 0
@@ -80,7 +81,7 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks,
seekable: seekable,
haveQueue: haveQueue,
}
- fd.LockFD.Init(locks)
+ fd.LockFD.Init(&d.locks)
if haveQueue {
if err := fdnotifier.AddFD(h.fd, &fd.queue); err != nil {
return nil, err
@@ -98,6 +99,14 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks,
d.fs.syncMu.Lock()
d.fs.specialFileFDs[fd] = struct{}{}
d.fs.syncMu.Unlock()
+ if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) {
+ fsmetric.GoferOpensWX.Increment()
+ }
+ if h.fd >= 0 {
+ fsmetric.GoferOpensHost.Increment()
+ } else {
+ fsmetric.GoferOpens9P.Increment()
+ }
return fd, nil
}
@@ -161,6 +170,17 @@ func (fd *specialFileFD) Allocate(ctx context.Context, mode, offset, length uint
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ start := fsmetric.StartReadWait()
+ defer func() {
+ if fd.handle.fd >= 0 {
+ fsmetric.GoferReadsHost.Increment()
+ fsmetric.FinishReadWait(fsmetric.GoferReadWaitHost, start)
+ } else {
+ fsmetric.GoferReads9P.Increment()
+ fsmetric.FinishReadWait(fsmetric.GoferReadWait9P, start)
+ }
+ }()
+
if fd.seekable && offset < 0 {
return 0, syserror.EINVAL
}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index c14abcff4..565d723f0 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -286,7 +286,7 @@ func (d *Dentry) cacheLocked(ctx context.Context) {
refs := atomic.LoadInt64(&d.refs)
if refs == -1 {
// Dentry has already been destroyed.
- panic(fmt.Sprintf("cacheLocked called on a dentry which has already been destroyed: %v", d))
+ return
}
if refs > 0 {
if d.cached {
diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go
index 469f3a33d..27b00cf6f 100644
--- a/pkg/sentry/fsimpl/overlay/copy_up.go
+++ b/pkg/sentry/fsimpl/overlay/copy_up.go
@@ -16,7 +16,6 @@ package overlay
import (
"fmt"
- "io"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -129,25 +128,9 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
return err
}
defer newFD.DecRef(ctx)
- bufIOSeq := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size
- for {
- readN, readErr := oldFD.Read(ctx, bufIOSeq, vfs.ReadOptions{})
- if readErr != nil && readErr != io.EOF {
- cleanupUndoCopyUp()
- return readErr
- }
- total := int64(0)
- for total < readN {
- writeN, writeErr := newFD.Write(ctx, bufIOSeq.DropFirst64(total), vfs.WriteOptions{})
- total += writeN
- if writeErr != nil {
- cleanupUndoCopyUp()
- return writeErr
- }
- }
- if readErr == io.EOF {
- break
- }
+ if _, err := vfs.CopyRegularFileData(ctx, newFD, oldFD); err != nil {
+ cleanupUndoCopyUp()
+ return err
}
d.mapsMu.Lock()
defer d.mapsMu.Unlock()
diff --git a/pkg/sentry/fsimpl/overlay/regular_file.go b/pkg/sentry/fsimpl/overlay/regular_file.go
index 2b89a7a6d..25c785fd4 100644
--- a/pkg/sentry/fsimpl/overlay/regular_file.go
+++ b/pkg/sentry/fsimpl/overlay/regular_file.go
@@ -103,8 +103,8 @@ func (fd *regularFileFD) currentFDLocked(ctx context.Context) (*vfs.FileDescript
for e, mask := range fd.lowerWaiters {
fd.cachedFD.EventUnregister(e)
upperFD.EventRegister(e, mask)
- if ready&mask != 0 {
- e.Callback.Callback(e)
+ if m := ready & mask; m != 0 {
+ e.Callback.Callback(e, m)
}
}
}
diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go
index 5cf8a071a..d4f6a5a9b 100644
--- a/pkg/sentry/fsimpl/proc/task_net.go
+++ b/pkg/sentry/fsimpl/proc/task_net.go
@@ -208,7 +208,7 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error {
for _, se := range n.kernel.ListSockets() {
s := se.SockVFS2
if !s.TryIncRef() {
- log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
+ // Racing with socket destruction, this is ok.
continue
}
if family, _, _ := s.Impl().(socket.SocketVFS2).Type(); family != linux.AF_UNIX {
@@ -351,7 +351,7 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel,
for _, se := range k.ListSockets() {
s := se.SockVFS2
if !s.TryIncRef() {
- log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
+ // Racing with socket destruction, this is ok.
continue
}
sops, ok := s.Impl().(socket.SocketVFS2)
@@ -516,7 +516,7 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error {
for _, se := range d.kernel.ListSockets() {
s := se.SockVFS2
if !s.TryIncRef() {
- log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
+ // Racing with socket destruction, this is ok.
continue
}
sops, ok := s.Impl().(socket.SocketVFS2)
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index 7c7afdcfa..25c407d98 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -44,6 +44,7 @@ func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k *
return fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
"kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
"hostname": fs.newInode(ctx, root, 0444, &hostnameData{}),
+ "sem": fs.newInode(ctx, root, 0444, newStaticFile(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))),
"shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)),
"shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)),
"shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)),
diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go
index 10f1452ef..246bd87bc 100644
--- a/pkg/sentry/fsimpl/signalfd/signalfd.go
+++ b/pkg/sentry/fsimpl/signalfd/signalfd.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package signalfd provides basic signalfd file implementations.
package signalfd
import (
@@ -98,8 +99,8 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen
Signo: uint32(info.Signo),
Errno: info.Errno,
Code: info.Code,
- PID: uint32(info.Pid()),
- UID: uint32(info.Uid()),
+ PID: uint32(info.PID()),
+ UID: uint32(info.UID()),
Status: info.Status(),
Overrun: uint32(info.Overrun()),
Addr: info.Addr(),
diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
index fe520b6fd..09957c2b7 100644
--- a/pkg/sentry/fsimpl/tmpfs/BUILD
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -67,6 +67,7 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsmetric",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/time",
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index e39cd305b..9296db2fb 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -381,6 +382,8 @@ afterTrailingSymlink:
creds := rp.Credentials()
child := fs.newDentry(fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode))
parentDir.insertChildLocked(child, name)
+ child.IncRef()
+ defer child.DecRef(ctx)
unlock()
fd, err := child.open(ctx, rp, &opts, true)
if err != nil {
@@ -437,6 +440,11 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
return nil, err
}
}
+ if fd.vfsfd.IsWritable() {
+ fsmetric.TmpfsOpensW.Increment()
+ } else if fd.vfsfd.IsReadable() {
+ fsmetric.TmpfsOpensRO.Increment()
+ }
return &fd.vfsfd, nil
case *directory:
// Can't open directories writably.
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index f8e0cffb0..6255a7c84 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -359,6 +360,10 @@ func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ start := fsmetric.StartReadWait()
+ defer fsmetric.FinishReadWait(fsmetric.TmpfsReadWait, start)
+ fsmetric.TmpfsReads.Increment()
+
if offset < 0 {
return 0, syserror.EINVAL
}
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index add5dd48e..a4ad625bb 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -107,8 +107,10 @@ func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*de
// Dentries which may have a reference count of zero, and which therefore
// should be dropped once traversal is complete, are appended to ds.
//
-// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
-// !rp.Done().
+// Preconditions:
+// * fs.renameMu must be locked.
+// * d.dirMu must be locked.
+// * !rp.Done().
func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) {
if !d.isDir() {
return nil, syserror.ENOTDIR
@@ -158,15 +160,19 @@ afterSymlink:
return child, nil
}
-// verifyChild verifies the hash of child against the already verified hash of
-// the parent to ensure the child is expected. verifyChild triggers a sentry
-// panic if unexpected modifications to the file system are detected. In
-// noCrashOnVerificationFailure mode it returns a syserror instead.
-// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
+// verifyChildLocked verifies the hash of child against the already verified
+// hash of the parent to ensure the child is expected. verifyChild triggers a
+// sentry panic if unexpected modifications to the file system are detected. In
+// ErrorOnViolation mode it returns a syserror instead.
+//
+// Preconditions:
+// * fs.renameMu must be locked.
+// * d.dirMu must be locked.
+//
// TODO(b/166474175): Investigate all possible errors returned in this
// function, and make sure we differentiate all errors that indicate unexpected
// modifications to the file system from the ones that are not harmful.
-func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) {
+func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) {
vfsObj := fs.vfsfs.VirtualFilesystem()
// Get the path to the child dentry. This is only used to provide path
@@ -248,7 +254,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
}
- fdReader := vfs.FileReadWriteSeeker{
+ fdReader := FileReadWriteSeeker{
FD: parentMerkleFD,
Ctx: ctx,
}
@@ -268,7 +274,8 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// contain the hash of the children in the parent Merkle tree when
// Verify returns with success.
var buf bytes.Buffer
- if _, err := merkletree.Verify(&merkletree.VerifyParams{
+ parent.hashMu.RLock()
+ _, err = merkletree.Verify(&merkletree.VerifyParams{
Out: &buf,
File: &fdReader,
Tree: &fdReader,
@@ -284,21 +291,27 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())),
Expected: parent.hash,
DataAndTreeInSameFile: true,
- }); err != nil && err != io.EOF {
+ })
+ parent.hashMu.RUnlock()
+ if err != nil && err != io.EOF {
return nil, alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err))
}
// Cache child hash when it's verified the first time.
+ child.hashMu.Lock()
if len(child.hash) == 0 {
child.hash = buf.Bytes()
}
+ child.hashMu.Unlock()
return child, nil
}
-// verifyStatAndChildren verifies the stat and children names against the
+// verifyStatAndChildrenLocked verifies the stat and children names against the
// verified hash. The mode/uid/gid and childrenNames of the file is cached
// after verified.
-func (fs *filesystem) verifyStatAndChildren(ctx context.Context, d *dentry, stat linux.Statx) error {
+//
+// Preconditions: d.dirMu must be locked.
+func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry, stat linux.Statx) error {
vfsObj := fs.vfsfs.VirtualFilesystem()
// Get the path to the child dentry. This is only used to provide path
@@ -384,12 +397,13 @@ func (fs *filesystem) verifyStatAndChildren(ctx context.Context, d *dentry, stat
}
}
- fdReader := vfs.FileReadWriteSeeker{
+ fdReader := FileReadWriteSeeker{
FD: fd,
Ctx: ctx,
}
var buf bytes.Buffer
+ d.hashMu.RLock()
params := &merkletree.VerifyParams{
Out: &buf,
Tree: &fdReader,
@@ -407,6 +421,7 @@ func (fs *filesystem) verifyStatAndChildren(ctx context.Context, d *dentry, stat
Expected: d.hash,
DataAndTreeInSameFile: false,
}
+ d.hashMu.RUnlock()
if atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFDIR {
params.DataAndTreeInSameFile = true
}
@@ -421,7 +436,9 @@ func (fs *filesystem) verifyStatAndChildren(ctx context.Context, d *dentry, stat
return nil
}
-// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
+// Preconditions:
+// * fs.renameMu must be locked.
+// * parent.dirMu must be locked.
func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
if child, ok := parent.children[name]; ok {
// If verity is enabled on child, we should check again whether
@@ -470,7 +487,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
// be cached before enabled.
if fs.allowRuntimeEnable {
if parent.verityEnabled() {
- if _, err := fs.verifyChild(ctx, parent, child); err != nil {
+ if _, err := fs.verifyChildLocked(ctx, parent, child); err != nil {
return nil, err
}
}
@@ -486,7 +503,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
if err != nil {
return nil, err
}
- if err := fs.verifyStatAndChildren(ctx, child, stat); err != nil {
+ if err := fs.verifyStatAndChildrenLocked(ctx, child, stat); err != nil {
return nil, err
}
}
@@ -506,7 +523,9 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
return child, nil
}
-// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked.
+// Preconditions:
+// * fs.renameMu must be locked.
+// * parent.dirMu must be locked.
func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) {
vfsObj := fs.vfsfs.VirtualFilesystem()
@@ -597,13 +616,13 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// allowRuntimeEnable mode and the parent directory hasn't been enabled
// yet.
if parent.verityEnabled() {
- if _, err := fs.verifyChild(ctx, parent, child); err != nil {
+ if _, err := fs.verifyChildLocked(ctx, parent, child); err != nil {
child.destroyLocked(ctx)
return nil, err
}
}
if child.verityEnabled() {
- if err := fs.verifyStatAndChildren(ctx, child, stat); err != nil {
+ if err := fs.verifyStatAndChildrenLocked(ctx, child, stat); err != nil {
child.destroyLocked(ctx)
return nil, err
}
@@ -617,7 +636,9 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// rp.Start().Impl().(*dentry)). It does not check that the returned directory
// is searchable by the provider of rp.
//
-// Preconditions: fs.renameMu must be locked. !rp.Done().
+// Preconditions:
+// * fs.renameMu must be locked.
+// * !rp.Done().
func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
for !rp.Final() {
d.dirMu.Lock()
@@ -958,11 +979,13 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err != nil {
return linux.Statx{}, err
}
+ d.dirMu.Lock()
if d.verityEnabled() {
- if err := fs.verifyStatAndChildren(ctx, d, stat); err != nil {
+ if err := fs.verifyStatAndChildrenLocked(ctx, d, stat); err != nil {
return linux.Statx{}, err
}
}
+ d.dirMu.Unlock()
return stat, nil
}
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index 87dabe038..66029c64d 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -19,6 +19,18 @@
// The verity file system is read-only, except for one case: when
// allowRuntimeEnable is true, additional Merkle files can be generated using
// the FS_IOC_ENABLE_VERITY ioctl.
+//
+// Lock order:
+//
+// filesystem.renameMu
+// dentry.dirMu
+// fileDescription.mu
+// filesystem.verityMu
+// dentry.hashMu
+//
+// Locking dentry.dirMu in multiple dentries requires that parent dentries are
+// locked before child dentries, and that filesystem.renameMu is locked to
+// stabilize this relationship.
package verity
import (
@@ -52,6 +64,10 @@ const (
// tree file for "/foo" is "/.merkle.verity.foo".
merklePrefix = ".merkle.verity."
+ // merkleRootPrefix is the prefix of the Merkle tree root file. This
+ // needs to be different from merklePrefix to avoid name collision.
+ merkleRootPrefix = ".merkleroot.verity."
+
// merkleOffsetInParentXattr is the extended attribute name specifying the
// offset of the child hash in its parent's Merkle tree.
merkleOffsetInParentXattr = "user.merkle.offset"
@@ -76,13 +92,8 @@ const (
)
var (
- // noCrashOnVerificationFailure indicates whether the sandbox should panic
- // whenever verification fails. If true, an error is returned instead of
- // panicking. This should only be set for tests.
- //
- // TODO(b/165661693): Decide whether to panic or return error based on this
- // flag.
- noCrashOnVerificationFailure bool
+ // action specifies the action towards detected violation.
+ action ViolationAction
// verityMu synchronizes concurrent operations that enable verity and perform
// verification checks.
@@ -93,6 +104,18 @@ var (
// content.
type HashAlgorithm int
+// ViolationAction is a type specifying the action when an integrity violation
+// is detected.
+type ViolationAction int
+
+const (
+ // PanicOnViolation terminates the sentry on detected violation.
+ PanicOnViolation ViolationAction = 0
+ // ErrorOnViolation returns an error from the violating system call on
+ // detected violation.
+ ErrorOnViolation = 1
+)
+
// Currently supported hashing algorithms include SHA256 and SHA512.
const (
SHA256 HashAlgorithm = iota
@@ -187,10 +210,8 @@ type InternalFilesystemOptions struct {
// system wrapped by verity file system.
LowerGetFSOptions vfs.GetFilesystemOptions
- // NoCrashOnVerificationFailure indicates whether the sandbox should
- // panic whenever verification fails. If true, an error is returned
- // instead of panicking. This should only be set for tests.
- NoCrashOnVerificationFailure bool
+ // Action specifies the action on an integrity violation.
+ Action ViolationAction
}
// Name implements vfs.FilesystemType.Name.
@@ -202,10 +223,10 @@ func (FilesystemType) Name() string {
func (FilesystemType) Release(ctx context.Context) {}
// alertIntegrityViolation alerts a violation of integrity, which usually means
-// unexpected modification to the file system is detected. In
-// noCrashOnVerificationFailure mode, it returns EIO, otherwise it panic.
+// unexpected modification to the file system is detected. In ErrorOnViolation
+// mode, it returns EIO, otherwise it panic.
func alertIntegrityViolation(msg string) error {
- if noCrashOnVerificationFailure {
+ if action == ErrorOnViolation {
return syserror.EIO
}
panic(msg)
@@ -218,7 +239,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs")
return nil, nil, syserror.EINVAL
}
- noCrashOnVerificationFailure = iopts.NoCrashOnVerificationFailure
+ action = iopts.Action
// Mount the lower file system. The lower file system is wrapped inside
// verity, and should not be exposed or connected.
@@ -246,7 +267,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
lowerVD.IncRef()
d.lowerVD = lowerVD
- rootMerkleName := merklePrefix + iopts.RootMerkleFileName
+ rootMerkleName := merkleRootPrefix + iopts.RootMerkleFileName
lowerMerkleVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{
Root: lowerVD,
@@ -372,12 +393,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames: %v", err))
}
- if err := fs.verifyStatAndChildren(ctx, d, stat); err != nil {
+ if err := fs.verifyStatAndChildrenLocked(ctx, d, stat); err != nil {
return nil, nil, err
}
}
+ d.hashMu.Lock()
copy(d.hash, iopts.RootHash)
+ d.hashMu.Unlock()
d.vfsd.Init(d)
fs.rootDentry = d
@@ -402,7 +425,8 @@ type dentry struct {
fs *filesystem
// mode, uid, gid and size are the file mode, owner, group, and size of
- // the file in the underlying file system.
+ // the file in the underlying file system. They are set when a dentry
+ // is initialized, and never modified.
mode uint32
uid uint32
gid uint32
@@ -425,18 +449,22 @@ type dentry struct {
// childrenNames stores the name of all children of the dentry. This is
// used by verity to check whether a child is expected. This is only
- // populated by enableVerity.
+ // populated by enableVerity. childrenNames is also protected by dirMu.
childrenNames map[string]struct{}
- // lowerVD is the VirtualDentry in the underlying file system.
+ // lowerVD is the VirtualDentry in the underlying file system. It is
+ // never modified after initialized.
lowerVD vfs.VirtualDentry
// lowerMerkleVD is the VirtualDentry of the corresponding Merkle tree
- // in the underlying file system.
+ // in the underlying file system. It is never modified after
+ // initialized.
lowerMerkleVD vfs.VirtualDentry
- // hash is the calculated hash for the current file or directory.
- hash []byte
+ // hash is the calculated hash for the current file or directory. hash
+ // is protected by hashMu.
+ hashMu sync.RWMutex `state:"nosave"`
+ hash []byte
}
// newDentry creates a new dentry representing the given verity file. The
@@ -519,7 +547,9 @@ func (d *dentry) checkDropLocked(ctx context.Context) {
// destroyLocked destroys the dentry.
//
-// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0.
+// Preconditions:
+// * d.fs.renameMu must be locked for writing.
+// * d.refs == 0.
func (d *dentry) destroyLocked(ctx context.Context) {
switch atomic.LoadInt64(&d.refs) {
case 0:
@@ -599,6 +629,8 @@ func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes)
// mode, it returns true if the target has been enabled with
// ioctl(FS_IOC_ENABLE_VERITY).
func (d *dentry) verityEnabled() bool {
+ d.hashMu.RLock()
+ defer d.hashMu.RUnlock()
return !d.fs.allowRuntimeEnable || len(d.hash) != 0
}
@@ -678,11 +710,13 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
if err != nil {
return linux.Statx{}, err
}
+ fd.d.dirMu.Lock()
if fd.d.verityEnabled() {
- if err := fd.d.fs.verifyStatAndChildren(ctx, fd.d, stat); err != nil {
+ if err := fd.d.fs.verifyStatAndChildrenLocked(ctx, fd.d, stat); err != nil {
return linux.Statx{}, err
}
}
+ fd.d.dirMu.Unlock()
return stat, nil
}
@@ -718,22 +752,24 @@ func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32)
return offset, nil
}
-// generateMerkle generates a Merkle tree file for fd. If fd points to a file
-// /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The hash
-// of the generated Merkle tree and the data size is returned. If fd points to
-// a regular file, the data is the content of the file. If fd points to a
-// directory, the data is all hahes of its children, written to the Merkle tree
-// file.
-func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, error) {
- fdReader := vfs.FileReadWriteSeeker{
+// generateMerkleLocked generates a Merkle tree file for fd. If fd points to a
+// file /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The
+// hash of the generated Merkle tree and the data size is returned. If fd
+// points to a regular file, the data is the content of the file. If fd points
+// to a directory, the data is all hashes of its children, written to the Merkle
+// tree file.
+//
+// Preconditions: fd.d.fs.verityMu must be locked.
+func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, uint64, error) {
+ fdReader := FileReadWriteSeeker{
FD: fd.lowerFD,
Ctx: ctx,
}
- merkleReader := vfs.FileReadWriteSeeker{
+ merkleReader := FileReadWriteSeeker{
FD: fd.merkleReader,
Ctx: ctx,
}
- merkleWriter := vfs.FileReadWriteSeeker{
+ merkleWriter := FileReadWriteSeeker{
FD: fd.merkleWriter,
Ctx: ctx,
}
@@ -793,11 +829,14 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64,
return hash, uint64(params.Size), err
}
-// recordChildren writes the names of fd's children into the corresponding
-// Merkle tree file, and saves the offset/size of the map into xattrs.
+// recordChildrenLocked writes the names of fd's children into the
+// corresponding Merkle tree file, and saves the offset/size of the map into
+// xattrs.
//
-// Preconditions: fd.d.isDir() == true
-func (fd *fileDescription) recordChildren(ctx context.Context) error {
+// Preconditions:
+// * fd.d.fs.verityMu must be locked.
+// * fd.d.isDir() == true.
+func (fd *fileDescription) recordChildrenLocked(ctx context.Context) error {
// Record the children names in the Merkle tree file.
childrenNames, err := json.Marshal(fd.d.childrenNames)
if err != nil {
@@ -847,7 +886,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) {
return 0, alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds")
}
- hash, dataSize, err := fd.generateMerkle(ctx)
+ hash, dataSize, err := fd.generateMerkleLocked(ctx)
if err != nil {
return 0, err
}
@@ -888,11 +927,13 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) {
}
if fd.d.isDir() {
- if err := fd.recordChildren(ctx); err != nil {
+ if err := fd.recordChildrenLocked(ctx); err != nil {
return 0, err
}
}
- fd.d.hash = append(fd.d.hash, hash...)
+ fd.d.hashMu.Lock()
+ fd.d.hash = hash
+ fd.d.hashMu.Unlock()
return 0, nil
}
@@ -904,6 +945,9 @@ func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest userm
}
var metadata linux.DigestMetadata
+ fd.d.hashMu.RLock()
+ defer fd.d.hashMu.RUnlock()
+
// If allowRuntimeEnable is true, an empty fd.d.hash indicates that
// verity is not enabled for the file. If allowRuntimeEnable is false,
// this is an integrity violation because all files should have verity
@@ -940,11 +984,13 @@ func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest userm
func (fd *fileDescription) verityFlags(ctx context.Context, flags usermem.Addr) (uintptr, error) {
f := int32(0)
+ fd.d.hashMu.RLock()
// All enabled files should store a hash. This flag is not settable via
// FS_IOC_SETFLAGS.
if len(fd.d.hash) != 0 {
f |= linux.FS_VERITY_FL
}
+ fd.d.hashMu.RUnlock()
t := kernel.TaskFromContext(ctx)
if t == nil {
@@ -1013,16 +1059,17 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
}
- dataReader := vfs.FileReadWriteSeeker{
+ dataReader := FileReadWriteSeeker{
FD: fd.lowerFD,
Ctx: ctx,
}
- merkleReader := vfs.FileReadWriteSeeker{
+ merkleReader := FileReadWriteSeeker{
FD: fd.merkleReader,
Ctx: ctx,
}
+ fd.d.hashMu.RLock()
n, err := merkletree.Verify(&merkletree.VerifyParams{
Out: dst.Writer(ctx),
File: &dataReader,
@@ -1040,6 +1087,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
Expected: fd.d.hash,
DataAndTreeInSameFile: false,
})
+ fd.d.hashMu.RUnlock()
if err != nil {
return 0, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err))
}
@@ -1065,3 +1113,45 @@ func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t
func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
return fd.lowerFD.UnlockPOSIX(ctx, uid, start, length, whence)
}
+
+// FileReadWriteSeeker is a helper struct to pass a vfs.FileDescription as
+// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc.
+type FileReadWriteSeeker struct {
+ FD *vfs.FileDescription
+ Ctx context.Context
+ ROpts vfs.ReadOptions
+ WOpts vfs.WriteOptions
+}
+
+// ReadAt implements io.ReaderAt.ReadAt.
+func (f *FileReadWriteSeeker) ReadAt(p []byte, off int64) (int, error) {
+ dst := usermem.BytesIOSequence(p)
+ n, err := f.FD.PRead(f.Ctx, dst, off, f.ROpts)
+ return int(n), err
+}
+
+// Read implements io.ReadWriteSeeker.Read.
+func (f *FileReadWriteSeeker) Read(p []byte) (int, error) {
+ dst := usermem.BytesIOSequence(p)
+ n, err := f.FD.Read(f.Ctx, dst, f.ROpts)
+ return int(n), err
+}
+
+// Seek implements io.ReadWriteSeeker.Seek.
+func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) {
+ return f.FD.Seek(f.Ctx, offset, int32(whence))
+}
+
+// WriteAt implements io.WriterAt.WriteAt.
+func (f *FileReadWriteSeeker) WriteAt(p []byte, off int64) (int, error) {
+ dst := usermem.BytesIOSequence(p)
+ n, err := f.FD.PWrite(f.Ctx, dst, off, f.WOpts)
+ return int(n), err
+}
+
+// Write implements io.ReadWriteSeeker.Write.
+func (f *FileReadWriteSeeker) Write(p []byte) (int, error) {
+ buf := usermem.BytesIOSequence(p)
+ n, err := f.FD.Write(f.Ctx, buf, f.WOpts)
+ return int(n), err
+}
diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go
index 7196e74eb..30d8b4355 100644
--- a/pkg/sentry/fsimpl/verity/verity_test.go
+++ b/pkg/sentry/fsimpl/verity/verity_test.go
@@ -35,16 +35,39 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// rootMerkleFilename is the name of the root Merkle tree file.
-const rootMerkleFilename = "root.verity"
+const (
+ // rootMerkleFilename is the name of the root Merkle tree file.
+ rootMerkleFilename = "root.verity"
+ // maxDataSize is the maximum data size of a test file.
+ maxDataSize = 100000
+)
+
+var hashAlgs = []HashAlgorithm{SHA256, SHA512}
-// maxDataSize is the maximum data size written to the file for test.
-const maxDataSize = 100000
+func dentryFromVD(t *testing.T, vd vfs.VirtualDentry) *dentry {
+ t.Helper()
+ d, ok := vd.Dentry().Impl().(*dentry)
+ if !ok {
+ t.Fatalf("can't assert %T as a *dentry", vd)
+ }
+ return d
+}
+
+// dentryFromFD returns the dentry corresponding to fd.
+func dentryFromFD(t *testing.T, fd *vfs.FileDescription) *dentry {
+ t.Helper()
+ f, ok := fd.Impl().(*fileDescription)
+ if !ok {
+ t.Fatalf("can't assert %T as a *fileDescription", fd)
+ }
+ return f.d
+}
// newVerityRoot creates a new verity mount, and returns the root. The
// underlying file system is tmpfs. If the error is not nil, then cleanup
// should be called when the root is no longer needed.
func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) {
+ t.Helper()
k, err := testutil.Boot()
if err != nil {
t.Fatalf("testutil.Boot: %v", err)
@@ -69,11 +92,11 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem,
mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{
GetFilesystemOptions: vfs.GetFilesystemOptions{
InternalData: InternalFilesystemOptions{
- RootMerkleFileName: rootMerkleFilename,
- LowerName: "tmpfs",
- Alg: hashAlg,
- AllowRuntimeEnable: true,
- NoCrashOnVerificationFailure: true,
+ RootMerkleFileName: rootMerkleFilename,
+ LowerName: "tmpfs",
+ Alg: hashAlg,
+ AllowRuntimeEnable: true,
+ Action: ErrorOnViolation,
},
},
})
@@ -92,7 +115,6 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem,
t.Fatalf("testutil.CreateTask: %v", err)
}
- t.Helper()
t.Cleanup(func() {
root.DecRef(ctx)
mntns.DecRef(ctx)
@@ -100,21 +122,97 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem,
return vfsObj, root, task, nil
}
-// newFileFD creates a new file in the verity mount, and returns the FD. The FD
-// points to a file that has random data generated.
-func newFileFD(ctx context.Context, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) {
- creds := auth.CredentialsFromContext(ctx)
- lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
+// openVerityAt opens a verity file.
+//
+// TODO(chongc): release reference from opening the file when done.
+func openVerityAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, vd vfs.VirtualDentry, path string, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) {
+ return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: vd,
+ Start: vd,
+ Path: fspath.Parse(path),
+ }, &vfs.OpenOptions{
+ Flags: flags,
+ Mode: mode,
+ })
+}
- // Create the file in the underlying file system.
- lowerFD, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
- Root: lowerRoot,
- Start: lowerRoot,
- Path: fspath.Parse(filePath),
+// openLowerAt opens the file in the underlying file system.
+//
+// TODO(chongc): release reference from opening the file when done.
+func (d *dentry) openLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) {
+ return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ Path: fspath.Parse(path),
+ }, &vfs.OpenOptions{
+ Flags: flags,
+ Mode: mode,
+ })
+}
+
+// openLowerMerkleAt opens the Merkle file in the underlying file system.
+//
+// TODO(chongc): release reference from opening the file when done.
+func (d *dentry) openLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) {
+ return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: d.lowerMerkleVD,
+ Start: d.lowerMerkleVD,
}, &vfs.OpenOptions{
- Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
- Mode: linux.ModeRegular | mode,
+ Flags: flags,
+ Mode: mode,
+ })
+}
+
+// unlinkLowerAt deletes the file in the underlying file system.
+func (d *dentry) unlinkLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string) error {
+ return vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ Path: fspath.Parse(path),
})
+}
+
+// unlinkLowerMerkleAt deletes the Merkle file in the underlying file system.
+func (d *dentry) unlinkLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string) error {
+ return vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ Path: fspath.Parse(merklePrefix + path),
+ })
+}
+
+// renameLowerAt renames file name to newName in the underlying file system.
+func (d *dentry) renameLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, name string, newName string) error {
+ return vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ Path: fspath.Parse(name),
+ }, &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ Path: fspath.Parse(newName),
+ }, &vfs.RenameOptions{})
+}
+
+// renameLowerMerkleAt renames Merkle file name to newName in the underlying
+// file system.
+func (d *dentry) renameLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, name string, newName string) error {
+ return vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ Path: fspath.Parse(merklePrefix + name),
+ }, &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ Path: fspath.Parse(merklePrefix + newName),
+ }, &vfs.RenameOptions{})
+}
+
+// newFileFD creates a new file in the verity mount, and returns the FD. The FD
+// points to a file that has random data generated.
+func newFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) {
+ // Create the file in the underlying file system.
+ lowerFD, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode)
if err != nil {
return nil, 0, err
}
@@ -137,20 +235,24 @@ func newFileFD(ctx context.Context, vfsObj *vfs.VirtualFilesystem, root vfs.Virt
lowerFD.DecRef(ctx)
// Now open the verity file descriptor.
- fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filePath),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- Mode: linux.ModeRegular | mode,
- })
+ fd, err := openVerityAt(ctx, vfsObj, root, filePath, linux.O_RDONLY, mode)
return fd, dataSize, err
}
-// corruptRandomBit randomly flips a bit in the file represented by fd.
-func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error {
- // Flip a random bit in the underlying file.
+// newEmptyFileFD creates a new empty file in the verity mount, and returns the FD.
+func newEmptyFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, error) {
+ // Create the file in the underlying file system.
+ _, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode)
+ if err != nil {
+ return nil, err
+ }
+ // Now open the verity file descriptor.
+ fd, err := openVerityAt(ctx, vfsObj, root, filePath, linux.O_RDONLY, mode)
+ return fd, err
+}
+
+// flipRandomBit randomly flips a bit in the file represented by fd.
+func flipRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error {
randomPos := int64(rand.Intn(size))
byteToModify := make([]byte, 1)
if _, err := fd.PRead(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.ReadOptions{}); err != nil {
@@ -163,7 +265,14 @@ func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) er
return nil
}
-var hashAlgs = []HashAlgorithm{SHA256, SHA512}
+func enableVerity(ctx context.Context, t *testing.T, fd *vfs.FileDescription) {
+ t.Helper()
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("enable verity: %v", err)
+ }
+}
// TestOpen ensures that when a file is created, the corresponding Merkle tree
// file and the root Merkle tree file exist.
@@ -175,30 +284,18 @@ func TestOpen(t *testing.T) {
}
filename := "verity-test-file"
- if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil {
+ fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
+ if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Ensure that the corresponding Merkle tree file is created.
- lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerRoot,
- Start: lowerRoot,
- Path: fspath.Parse(merklePrefix + filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- }); err != nil {
+ if _, err = dentryFromFD(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil {
t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err)
}
// Ensure the root merkle tree file is created.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerRoot,
- Start: lowerRoot,
- Path: fspath.Parse(merklePrefix + rootMerkleFilename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- }); err != nil {
+ if _, err = dentryFromVD(t, root).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil {
t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err)
}
}
@@ -214,17 +311,13 @@ func TestPReadUnmodifiedFileSucceeds(t *testing.T) {
}
filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file and confirm a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
buf := make([]byte, size)
n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{})
@@ -248,17 +341,13 @@ func TestReadUnmodifiedFileSucceeds(t *testing.T) {
}
filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file and confirm a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
buf := make([]byte, size)
n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
@@ -272,6 +361,36 @@ func TestReadUnmodifiedFileSucceeds(t *testing.T) {
}
}
+// TestReadUnmodifiedEmptyFileSucceeds ensures that read from an untouched empty verity
+// file succeeds after enabling verity for it.
+func TestReadUnmodifiedEmptyFileSucceeds(t *testing.T) {
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-empty-file"
+ fd, err := newEmptyFileFD(ctx, t, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newEmptyFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirm a normal read succeeds.
+ enableVerity(ctx, t, fd)
+
+ var buf []byte
+ n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.Read: %v", err)
+ }
+
+ if n != 0 {
+ t.Errorf("fd.Read got read length %d, expected 0", n)
+ }
+ }
+}
+
// TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file
// succeeds after enabling verity for it.
func TestReopenUnmodifiedFileSucceeds(t *testing.T) {
@@ -282,27 +401,16 @@ func TestReopenUnmodifiedFileSucceeds(t *testing.T) {
}
filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file and confirms a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Ensure reopening the verity enabled file succeeds.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- Mode: linux.ModeRegular,
- }); err != nil {
+ if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != nil {
t.Errorf("reopen enabled file failed: %v", err)
}
}
@@ -317,43 +425,24 @@ func TestOpenNonexistentFile(t *testing.T) {
}
filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file and confirms a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Enable verity on the parent directory.
- parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- })
+ parentFD, err := openVerityAt(ctx, vfsObj, root, "", linux.O_RDONLY, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
-
- if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, parentFD)
// Ensure open an unexpected file in the parent directory fails with
// ENOENT rather than verification failure.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filename + "abc"),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- Mode: linux.ModeRegular,
- }); err != syserror.ENOENT {
+ if _, err = openVerityAt(ctx, vfsObj, root, filename+"abc", linux.O_RDONLY, linux.ModeRegular); err != syserror.ENOENT {
t.Errorf("OpenAt unexpected error: %v", err)
}
}
@@ -368,33 +457,22 @@ func TestPReadModifiedFileFails(t *testing.T) {
}
filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Open a new lowerFD that's read/writable.
- lowerVD := fd.Impl().(*fileDescription).d.lowerVD
-
- lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerVD,
- Start: lowerVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR,
- })
+ lowerFD, err := dentryFromFD(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
- if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
+ if err := flipRandomBit(ctx, lowerFD, size); err != nil {
+ t.Fatalf("flipRandomBit: %v", err)
}
// Confirm that read from the modified file fails.
@@ -415,33 +493,22 @@ func TestReadModifiedFileFails(t *testing.T) {
}
filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Open a new lowerFD that's read/writable.
- lowerVD := fd.Impl().(*fileDescription).d.lowerVD
-
- lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerVD,
- Start: lowerVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR,
- })
+ lowerFD, err := dentryFromFD(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
- if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
+ if err := flipRandomBit(ctx, lowerFD, size); err != nil {
+ t.Fatalf("flipRandomBit: %v", err)
}
// Confirm that read from the modified file fails.
@@ -462,27 +529,16 @@ func TestModifiedMerkleFails(t *testing.T) {
}
filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Open a new lowerMerkleFD that's read/writable.
- lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD
-
- lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerMerkleVD,
- Start: lowerMerkleVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR,
- })
+ lowerMerkleFD, err := dentryFromFD(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
@@ -493,14 +549,13 @@ func TestModifiedMerkleFails(t *testing.T) {
t.Errorf("lowerMerkleFD.Stat: %v", err)
}
- if err := corruptRandomBit(ctx, lowerMerkleFD, int(stat.Size)); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
+ if err := flipRandomBit(ctx, lowerMerkleFD, int(stat.Size)); err != nil {
+ t.Fatalf("flipRandomBit: %v", err)
}
// Confirm that read from a file with modified Merkle tree fails.
buf := make([]byte, size)
if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
- fmt.Println(buf)
t.Fatalf("fd.PRead succeeded with modified Merkle file")
}
}
@@ -517,42 +572,23 @@ func TestModifiedParentMerkleFails(t *testing.T) {
}
filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Enable verity on the parent directory.
- parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- })
+ parentFD, err := openVerityAt(ctx, vfsObj, root, "", linux.O_RDONLY, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
-
- if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, parentFD)
// Open a new lowerMerkleFD that's read/writable.
- parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD
-
- parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: parentLowerMerkleVD,
- Start: parentLowerMerkleVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR,
- })
+ parentLowerMerkleFD, err := dentryFromFD(t, fd).parent.openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
@@ -572,21 +608,14 @@ func TestModifiedParentMerkleFails(t *testing.T) {
if err != nil {
t.Fatalf("Failed convert size to int: %v", err)
}
- if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
+ if err := flipRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil {
+ t.Fatalf("flipRandomBit: %v", err)
}
parentLowerMerkleFD.DecRef(ctx)
// Ensure reopening the verity enabled file fails.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- Mode: linux.ModeRegular,
- }); err == nil {
+ if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err == nil {
t.Errorf("OpenAt file with modified parent Merkle succeeded")
}
}
@@ -602,18 +631,13 @@ func TestUnmodifiedStatSucceeds(t *testing.T) {
}
filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
- // Enable verity on the file and confirms stat succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("fd.Ioctl: %v", err)
- }
-
+ // Enable verity on the file and confirm that stat succeeds.
+ enableVerity(ctx, t, fd)
if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil {
t.Errorf("fd.Stat: %v", err)
}
@@ -630,17 +654,13 @@ func TestModifiedStatFails(t *testing.T) {
}
filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("fd.Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
lowerFD := fd.Impl().(*fileDescription).lowerFD
// Change the stat of the underlying file, and check that stat fails.
@@ -663,73 +683,57 @@ func TestModifiedStatFails(t *testing.T) {
// and/or the corresponding Merkle tree file fails with the verity error.
func TestOpenDeletedFileFails(t *testing.T) {
testCases := []struct {
+ name string
// The original file is removed if changeFile is true.
changeFile bool
// The Merkle tree file is removed if changeMerkleFile is true.
changeMerkleFile bool
}{
{
+ name: "FileOnly",
changeFile: true,
changeMerkleFile: false,
},
{
+ name: "MerkleOnly",
changeFile: false,
changeMerkleFile: true,
},
{
+ name: "FileAndMerkle",
changeFile: true,
changeMerkleFile: true,
},
}
for _, tc := range testCases {
- t.Run(fmt.Sprintf("changeFile:%t, changeMerkleFile:%t", tc.changeFile, tc.changeMerkleFile), func(t *testing.T) {
+ t.Run(tc.name, func(t *testing.T) {
vfsObj, root, ctx, err := newVerityRoot(t, SHA256)
if err != nil {
t.Fatalf("newVerityRoot: %v", err)
}
filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
- rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD
if tc.changeFile {
- if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: rootLowerVD,
- Start: rootLowerVD,
- Path: fspath.Parse(filename),
- }); err != nil {
+ if err := dentryFromVD(t, root).unlinkLowerAt(ctx, vfsObj, filename); err != nil {
t.Fatalf("UnlinkAt: %v", err)
}
}
if tc.changeMerkleFile {
- if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: rootLowerVD,
- Start: rootLowerVD,
- Path: fspath.Parse(merklePrefix + filename),
- }); err != nil {
+ if err := dentryFromVD(t, root).unlinkLowerMerkleAt(ctx, vfsObj, filename); err != nil {
t.Fatalf("UnlinkAt: %v", err)
}
}
// Ensure reopening the verity enabled file fails.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- Mode: linux.ModeRegular,
- }); err != syserror.EIO {
+ if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO {
t.Errorf("got OpenAt error: %v, expected EIO", err)
}
})
@@ -740,82 +744,58 @@ func TestOpenDeletedFileFails(t *testing.T) {
// and/or the corresponding Merkle tree file fails with the verity error.
func TestOpenRenamedFileFails(t *testing.T) {
testCases := []struct {
+ name string
// The original file is renamed if changeFile is true.
changeFile bool
// The Merkle tree file is renamed if changeMerkleFile is true.
changeMerkleFile bool
}{
{
+ name: "FileOnly",
changeFile: true,
changeMerkleFile: false,
},
{
+ name: "MerkleOnly",
changeFile: false,
changeMerkleFile: true,
},
{
+ name: "FileAndMerkle",
changeFile: true,
changeMerkleFile: true,
},
}
for _, tc := range testCases {
- t.Run(fmt.Sprintf("changeFile:%t, changeMerkleFile:%t", tc.changeFile, tc.changeMerkleFile), func(t *testing.T) {
+ t.Run(tc.name, func(t *testing.T) {
vfsObj, root, ctx, err := newVerityRoot(t, SHA256)
if err != nil {
t.Fatalf("newVerityRoot: %v", err)
}
filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644)
if err != nil {
t.Fatalf("newFileFD: %v", err)
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
- rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD
newFilename := "renamed-test-file"
if tc.changeFile {
- if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: rootLowerVD,
- Start: rootLowerVD,
- Path: fspath.Parse(filename),
- }, &vfs.PathOperation{
- Root: rootLowerVD,
- Start: rootLowerVD,
- Path: fspath.Parse(newFilename),
- }, &vfs.RenameOptions{}); err != nil {
+ if err := dentryFromVD(t, root).renameLowerAt(ctx, vfsObj, filename, newFilename); err != nil {
t.Fatalf("RenameAt: %v", err)
}
}
if tc.changeMerkleFile {
- if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: rootLowerVD,
- Start: rootLowerVD,
- Path: fspath.Parse(merklePrefix + filename),
- }, &vfs.PathOperation{
- Root: rootLowerVD,
- Start: rootLowerVD,
- Path: fspath.Parse(merklePrefix + newFilename),
- }, &vfs.RenameOptions{}); err != nil {
+ if err := dentryFromVD(t, root).renameLowerMerkleAt(ctx, vfsObj, filename, newFilename); err != nil {
t.Fatalf("UnlinkAt: %v", err)
}
}
// Ensure reopening the verity enabled file fails.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- Mode: linux.ModeRegular,
- }); err != syserror.EIO {
+ if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO {
t.Errorf("got OpenAt error: %v, expected EIO", err)
}
})
diff --git a/pkg/sentry/fsmetric/BUILD b/pkg/sentry/fsmetric/BUILD
new file mode 100644
index 000000000..4e86fbdd8
--- /dev/null
+++ b/pkg/sentry/fsmetric/BUILD
@@ -0,0 +1,10 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "fsmetric",
+ srcs = ["fsmetric.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = ["//pkg/metric"],
+)
diff --git a/pkg/sentry/fsmetric/fsmetric.go b/pkg/sentry/fsmetric/fsmetric.go
new file mode 100644
index 000000000..7e535b527
--- /dev/null
+++ b/pkg/sentry/fsmetric/fsmetric.go
@@ -0,0 +1,83 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fsmetric defines filesystem metrics that are used by both VFS1 and
+// VFS2.
+//
+// TODO(gvisor.dev/issue/1624): Once VFS1 is deleted, inline these metrics into
+// VFS2.
+package fsmetric
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/metric"
+)
+
+// RecordWaitTime enables the ReadWait, GoferReadWait9P, GoferReadWaitHost, and
+// TmpfsReadWait metrics. Enabling this comes at a CPU cost due to performing
+// three clock reads per read call.
+//
+// Note that this is only performed in the direct read path, and may not be
+// consistently applied for other forms of reads, such as splice.
+var RecordWaitTime = false
+
+// Metrics that apply to all filesystems.
+var (
+ Opens = metric.MustCreateNewUint64Metric("/fs/opens", false /* sync */, "Number of file opens.")
+ Reads = metric.MustCreateNewUint64Metric("/fs/reads", false /* sync */, "Number of file reads.")
+ ReadWait = metric.MustCreateNewUint64NanosecondsMetric("/fs/read_wait", false /* sync */, "Time waiting on file reads, in nanoseconds.")
+)
+
+// Metrics that only apply to fs/gofer and fsimpl/gofer.
+var (
+ GoferOpensWX = metric.MustCreateNewUint64Metric("/gofer/opened_write_execute_file", true /* sync */, "Number of times a executable file was opened writably from a gofer.")
+ GoferOpens9P = metric.MustCreateNewUint64Metric("/gofer/opens_9p", false /* sync */, "Number of times a file was opened from a gofer and did not have a host file descriptor.")
+ GoferOpensHost = metric.MustCreateNewUint64Metric("/gofer/opens_host", false /* sync */, "Number of times a file was opened from a gofer and did have a host file descriptor.")
+ GoferReads9P = metric.MustCreateNewUint64Metric("/gofer/reads_9p", false /* sync */, "Number of 9P file reads from a gofer.")
+ GoferReadWait9P = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_9p", false /* sync */, "Time waiting on 9P file reads from a gofer, in nanoseconds.")
+ GoferReadsHost = metric.MustCreateNewUint64Metric("/gofer/reads_host", false /* sync */, "Number of host file reads from a gofer.")
+ GoferReadWaitHost = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_host", false /* sync */, "Time waiting on host file reads from a gofer, in nanoseconds.")
+)
+
+// Metrics that only apply to fs/tmpfs and fsimpl/tmpfs.
+var (
+ TmpfsOpensRO = metric.MustCreateNewUint64Metric("/in_memory_file/opens_ro", false /* sync */, "Number of times an in-memory file was opened in read-only mode.")
+ TmpfsOpensW = metric.MustCreateNewUint64Metric("/in_memory_file/opens_w", false /* sync */, "Number of times an in-memory file was opened in write mode.")
+ TmpfsReads = metric.MustCreateNewUint64Metric("/in_memory_file/reads", false /* sync */, "Number of in-memory file reads.")
+ TmpfsReadWait = metric.MustCreateNewUint64NanosecondsMetric("/in_memory_file/read_wait", false /* sync */, "Time waiting on in-memory file reads, in nanoseconds.")
+)
+
+// StartReadWait indicates the beginning of a file read.
+func StartReadWait() time.Time {
+ if !RecordWaitTime {
+ return time.Time{}
+ }
+ return time.Now()
+}
+
+// FinishReadWait indicates the end of a file read whose time is accounted by
+// m. start must be the value returned by the corresponding call to
+// StartReadWait.
+//
+// FinishReadWait is marked nosplit for performance since it's often called
+// from defer statements, which prevents it from being inlined
+// (https://github.com/golang/go/issues/38471).
+//go:nosplit
+func FinishReadWait(m *metric.Uint64Metric, start time.Time) {
+ if !RecordWaitTime {
+ return
+ }
+ m.IncrementBy(uint64(time.Since(start).Nanoseconds()))
+}
diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go
index 15519f0df..61aeca044 100644
--- a/pkg/sentry/kernel/epoll/epoll.go
+++ b/pkg/sentry/kernel/epoll/epoll.go
@@ -273,7 +273,7 @@ func (e *EventPoll) ReadEvents(max int) []linux.EpollEvent {
//
// Callback is called when one of the files we're polling becomes ready. It
// moves said file to the readyList if it's currently in the waiting list.
-func (p *pollEntry) Callback(*waiter.Entry) {
+func (p *pollEntry) Callback(*waiter.Entry, waiter.EventMask) {
e := p.epoll
e.listsMu.Lock()
@@ -306,9 +306,8 @@ func (e *EventPoll) initEntryReadiness(entry *pollEntry) {
f.EventRegister(&entry.waiter, entry.mask)
// Check if the file happens to already be in a ready state.
- ready := f.Readiness(entry.mask) & entry.mask
- if ready != 0 {
- entry.Callback(&entry.waiter)
+ if ready := f.Readiness(entry.mask) & entry.mask; ready != 0 {
+ entry.Callback(&entry.waiter, ready)
}
}
diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD
index 2b3955598..f855f038b 100644
--- a/pkg/sentry/kernel/fasync/BUILD
+++ b/pkg/sentry/kernel/fasync/BUILD
@@ -8,11 +8,13 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
+ "//pkg/sentry/arch",
"//pkg/sentry/fs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
"//pkg/sync",
+ "//pkg/syserror",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go
index 153d2cd9b..b66d61c6f 100644
--- a/pkg/sentry/kernel/fasync/fasync.go
+++ b/pkg/sentry/kernel/fasync/fasync.go
@@ -17,22 +17,45 @@ package fasync
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
-// New creates a new fs.FileAsync.
-func New() fs.FileAsync {
- return &FileAsync{}
+// Table to convert waiter event masks into si_band siginfo codes.
+// Taken from fs/fcntl.c:band_table.
+var bandTable = map[waiter.EventMask]int64{
+ // POLL_IN
+ waiter.EventIn: linux.EPOLLIN | linux.EPOLLRDNORM,
+ // POLL_OUT
+ waiter.EventOut: linux.EPOLLOUT | linux.EPOLLWRNORM | linux.EPOLLWRBAND,
+ // POLL_ERR
+ waiter.EventErr: linux.EPOLLERR,
+ // POLL_PRI
+ waiter.EventPri: linux.EPOLLPRI | linux.EPOLLRDBAND,
+ // POLL_HUP
+ waiter.EventHUp: linux.EPOLLHUP | linux.EPOLLERR,
}
-// NewVFS2 creates a new vfs.FileAsync.
-func NewVFS2() vfs.FileAsync {
- return &FileAsync{}
+// New returns a function that creates a new fs.FileAsync with the given file
+// descriptor.
+func New(fd int) func() fs.FileAsync {
+ return func() fs.FileAsync {
+ return &FileAsync{fd: fd}
+ }
+}
+
+// NewVFS2 returns a function that creates a new vfs.FileAsync with the given
+// file descriptor.
+func NewVFS2(fd int) func() vfs.FileAsync {
+ return func() vfs.FileAsync {
+ return &FileAsync{fd: fd}
+ }
}
// FileAsync sends signals when the registered file is ready for IO.
@@ -42,6 +65,12 @@ type FileAsync struct {
// e is immutable after first use (which is protected by mu below).
e waiter.Entry
+ // fd is the file descriptor to notify about.
+ // It is immutable, set at allocation time. This matches Linux semantics in
+ // fs/fcntl.c:fasync_helper.
+ // The fd value is passed to the signal recipient in siginfo.si_fd.
+ fd int
+
// regMu protects registeration and unregistration actions on e.
//
// regMu must be held while registration decisions are being made
@@ -56,6 +85,10 @@ type FileAsync struct {
mu sync.Mutex `state:"nosave"`
requester *auth.Credentials
registered bool
+ // signal is the signal to deliver upon I/O being available.
+ // The default value ("zero signal") means the default SIGIO signal will be
+ // delivered.
+ signal linux.Signal
// Only one of the following is allowed to be non-nil.
recipientPG *kernel.ProcessGroup
@@ -64,10 +97,10 @@ type FileAsync struct {
}
// Callback sends a signal.
-func (a *FileAsync) Callback(e *waiter.Entry) {
+func (a *FileAsync) Callback(e *waiter.Entry, mask waiter.EventMask) {
a.mu.Lock()
+ defer a.mu.Unlock()
if !a.registered {
- a.mu.Unlock()
return
}
t := a.recipientT
@@ -80,19 +113,34 @@ func (a *FileAsync) Callback(e *waiter.Entry) {
}
if t == nil {
// No recipient has been registered.
- a.mu.Unlock()
return
}
c := t.Credentials()
// Logic from sigio_perm in fs/fcntl.c.
- if a.requester.EffectiveKUID == 0 ||
+ permCheck := (a.requester.EffectiveKUID == 0 ||
a.requester.EffectiveKUID == c.SavedKUID ||
a.requester.EffectiveKUID == c.RealKUID ||
a.requester.RealKUID == c.SavedKUID ||
- a.requester.RealKUID == c.RealKUID {
- t.SendSignal(kernel.SignalInfoPriv(linux.SIGIO))
+ a.requester.RealKUID == c.RealKUID)
+ if !permCheck {
+ return
}
- a.mu.Unlock()
+ signalInfo := &arch.SignalInfo{
+ Signo: int32(linux.SIGIO),
+ Code: arch.SignalInfoKernel,
+ }
+ if a.signal != 0 {
+ signalInfo.Signo = int32(a.signal)
+ signalInfo.SetFD(uint32(a.fd))
+ var band int64
+ for m, bandCode := range bandTable {
+ if m&mask != 0 {
+ band |= bandCode
+ }
+ }
+ signalInfo.SetBand(band)
+ }
+ t.SendSignal(signalInfo)
}
// Register sets the file which will be monitored for IO events.
@@ -186,3 +234,25 @@ func (a *FileAsync) ClearOwner() {
a.recipientTG = nil
a.recipientPG = nil
}
+
+// Signal returns which signal will be sent to the signal recipient.
+// A value of zero means the signal to deliver wasn't customized, which means
+// the default signal (SIGIO) will be delivered.
+func (a *FileAsync) Signal() linux.Signal {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ return a.signal
+}
+
+// SetSignal overrides which signal to send when I/O is available.
+// The default behavior can be reset by specifying signal zero, which means
+// to send SIGIO.
+func (a *FileAsync) SetSignal(signal linux.Signal) error {
+ if signal != 0 && !signal.IsValid() {
+ return syserror.EINVAL
+ }
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.signal = signal
+ return nil
+}
diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go
index 470d8bf83..f17f9c59c 100644
--- a/pkg/sentry/kernel/fd_table_unsafe.go
+++ b/pkg/sentry/kernel/fd_table_unsafe.go
@@ -121,18 +121,21 @@ func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2
panic("VFS1 and VFS2 files set")
}
- slice := *(*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice))
+ slicePtr := (*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice))
// Grow the table as required.
- if last := int32(len(slice)); fd >= last {
+ if last := int32(len(*slicePtr)); fd >= last {
end := fd + 1
if end < 2*last {
end = 2 * last
}
- slice = append(slice, make([]unsafe.Pointer, end-last)...)
- atomic.StorePointer(&f.slice, unsafe.Pointer(&slice))
+ newSlice := append(*slicePtr, make([]unsafe.Pointer, end-last)...)
+ slicePtr = &newSlice
+ atomic.StorePointer(&f.slice, unsafe.Pointer(slicePtr))
}
+ slice := *slicePtr
+
var desc *descriptor
if file != nil || fileVFS2 != nil {
desc = &descriptor{
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 2cdcdfc1f..b8627a54f 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -214,9 +214,11 @@ type Kernel struct {
// netlinkPorts manages allocation of netlink socket port IDs.
netlinkPorts *port.Manager
- // saveErr is the error causing the sandbox to exit during save, if
- // any. It is protected by extMu.
- saveErr error `state:"nosave"`
+ // saveStatus is nil if the sandbox has not been saved, errSaved or
+ // errAutoSaved if it has been saved successfully, or the error causing the
+ // sandbox to exit during save.
+ // It is protected by extMu.
+ saveStatus error `state:"nosave"`
// danglingEndpoints is used to save / restore tcpip.DanglingEndpoints.
danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"`
@@ -1481,12 +1483,42 @@ func (k *Kernel) NetlinkPorts() *port.Manager {
return k.netlinkPorts
}
-// SaveError returns the sandbox error that caused the kernel to exit during
-// save.
-func (k *Kernel) SaveError() error {
+var (
+ errSaved = errors.New("sandbox has been successfully saved")
+ errAutoSaved = errors.New("sandbox has been successfully auto-saved")
+)
+
+// SaveStatus returns the sandbox save status. If it was saved successfully,
+// autosaved indicates whether save was triggered by autosave. If it was not
+// saved successfully, err indicates the sandbox error that caused the kernel to
+// exit during save.
+func (k *Kernel) SaveStatus() (saved, autosaved bool, err error) {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ switch k.saveStatus {
+ case nil:
+ return false, false, nil
+ case errSaved:
+ return true, false, nil
+ case errAutoSaved:
+ return true, true, nil
+ default:
+ return false, false, k.saveStatus
+ }
+}
+
+// SetSaveSuccess sets the flag indicating that save completed successfully, if
+// no status was already set.
+func (k *Kernel) SetSaveSuccess(autosave bool) {
k.extMu.Lock()
defer k.extMu.Unlock()
- return k.saveErr
+ if k.saveStatus == nil {
+ if autosave {
+ k.saveStatus = errAutoSaved
+ } else {
+ k.saveStatus = errSaved
+ }
+ }
}
// SetSaveError sets the sandbox error that caused the kernel to exit during
@@ -1494,8 +1526,8 @@ func (k *Kernel) SaveError() error {
func (k *Kernel) SetSaveError(err error) {
k.extMu.Lock()
defer k.extMu.Unlock()
- if k.saveErr == nil {
- k.saveErr = err
+ if k.saveStatus == nil {
+ k.saveStatus = err
}
}
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
index 1abfe2201..cef58a590 100644
--- a/pkg/sentry/kernel/ptrace.go
+++ b/pkg/sentry/kernel/ptrace.go
@@ -259,8 +259,8 @@ func (t *Task) ptraceTrapLocked(code int32) {
Signo: int32(linux.SIGTRAP),
Code: code,
}
- t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t]))
- t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ t.ptraceSiginfo.SetPID(int32(t.tg.pidns.tids[t]))
+ t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
if t.beginPtraceStopLocked() {
tracer := t.Tracer()
tracer.signalStop(t, arch.CLD_TRAPPED, int32(linux.SIGTRAP))
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index b99c0bffa..db01e4a97 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -29,17 +29,17 @@ import (
)
const (
- valueMax = 32767 // SEMVMX
+ // Maximum semaphore value.
+ valueMax = linux.SEMVMX
- // semaphoresMax is "maximum number of semaphores per semaphore ID" (SEMMSL).
- semaphoresMax = 32000
+ // Maximum number of semaphore sets.
+ setsMax = linux.SEMMNI
- // setMax is "system-wide limit on the number of semaphore sets" (SEMMNI).
- setsMax = 32000
+ // Maximum number of semaphroes in a semaphore set.
+ semsMax = linux.SEMMSL
- // semaphoresTotalMax is "system-wide limit on the number of semaphores"
- // (SEMMNS = SEMMNI*SEMMSL).
- semaphoresTotalMax = 1024000000
+ // Maximum number of semaphores in all semaphroe sets.
+ semsTotalMax = linux.SEMMNS
)
// Registry maintains a set of semaphores that can be found by key or ID.
@@ -52,6 +52,9 @@ type Registry struct {
mu sync.Mutex `state:"nosave"`
semaphores map[int32]*Set
lastIDUsed int32
+ // indexes maintains a mapping between a set's index in virtual array and
+ // its identifier.
+ indexes map[int32]int32
}
// Set represents a set of semaphores that can be operated atomically.
@@ -113,6 +116,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry {
return &Registry{
userNS: userNS,
semaphores: make(map[int32]*Set),
+ indexes: make(map[int32]int32),
}
}
@@ -122,7 +126,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry {
// be found. If exclusive is true, it fails if a set with the same key already
// exists.
func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linux.FileMode, private, create, exclusive bool) (*Set, error) {
- if nsems < 0 || nsems > semaphoresMax {
+ if nsems < 0 || nsems > semsMax {
return nil, syserror.EINVAL
}
@@ -163,10 +167,13 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu
}
// Apply system limits.
+ //
+ // Map semaphores and map indexes in a registry are of the same size,
+ // check map semaphores only here for the system limit.
if len(r.semaphores) >= setsMax {
return nil, syserror.EINVAL
}
- if r.totalSems() > int(semaphoresTotalMax-nsems) {
+ if r.totalSems() > int(semsTotalMax-nsems) {
return nil, syserror.EINVAL
}
@@ -176,6 +183,53 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu
return r.newSet(ctx, key, owner, owner, perms, nsems)
}
+// IPCInfo returns information about system-wide semaphore limits and parameters.
+func (r *Registry) IPCInfo() *linux.SemInfo {
+ return &linux.SemInfo{
+ SemMap: linux.SEMMAP,
+ SemMni: linux.SEMMNI,
+ SemMns: linux.SEMMNS,
+ SemMnu: linux.SEMMNU,
+ SemMsl: linux.SEMMSL,
+ SemOpm: linux.SEMOPM,
+ SemUme: linux.SEMUME,
+ SemUsz: linux.SEMUSZ,
+ SemVmx: linux.SEMVMX,
+ SemAem: linux.SEMAEM,
+ }
+}
+
+// SemInfo returns a seminfo structure containing the same information as
+// for IPC_INFO, except that SemUsz field returns the number of existing
+// semaphore sets, and SemAem field returns the number of existing semaphores.
+func (r *Registry) SemInfo() *linux.SemInfo {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ info := r.IPCInfo()
+ info.SemUsz = uint32(len(r.semaphores))
+ info.SemAem = uint32(r.totalSems())
+
+ return info
+}
+
+// HighestIndex returns the index of the highest used entry in
+// the kernel's array.
+func (r *Registry) HighestIndex() int32 {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ // By default, highest used index is 0 even though
+ // there is no semaphroe set.
+ var highestIndex int32
+ for index := range r.indexes {
+ if index > highestIndex {
+ highestIndex = index
+ }
+ }
+ return highestIndex
+}
+
// RemoveID removes set with give 'id' from the registry and marks the set as
// dead. All waiters will be awakened and fail.
func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error {
@@ -186,6 +240,11 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error {
if set == nil {
return syserror.EINVAL
}
+ index, found := r.findIndexByID(id)
+ if !found {
+ // Inconsistent state.
+ panic(fmt.Sprintf("unable to find an index for ID: %d", id))
+ }
set.mu.Lock()
defer set.mu.Unlock()
@@ -197,6 +256,7 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error {
}
delete(r.semaphores, set.ID)
+ delete(r.indexes, index)
set.destroy()
return nil
}
@@ -220,6 +280,11 @@ func (r *Registry) newSet(ctx context.Context, key int32, owner, creator fs.File
continue
}
if r.semaphores[id] == nil {
+ index, found := r.findFirstAvailableIndex()
+ if !found {
+ panic("unable to find an available index")
+ }
+ r.indexes[index] = id
r.lastIDUsed = id
r.semaphores[id] = set
set.ID = id
@@ -238,6 +303,18 @@ func (r *Registry) FindByID(id int32) *Set {
return r.semaphores[id]
}
+// FindByIndex looks up a set given an index.
+func (r *Registry) FindByIndex(index int32) *Set {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ id, present := r.indexes[index]
+ if !present {
+ return nil
+ }
+ return r.semaphores[id]
+}
+
func (r *Registry) findByKey(key int32) *Set {
for _, v := range r.semaphores {
if v.key == key {
@@ -247,6 +324,24 @@ func (r *Registry) findByKey(key int32) *Set {
return nil
}
+func (r *Registry) findIndexByID(id int32) (int32, bool) {
+ for k, v := range r.indexes {
+ if v == id {
+ return k, true
+ }
+ }
+ return 0, false
+}
+
+func (r *Registry) findFirstAvailableIndex() (int32, bool) {
+ for index := int32(0); index < setsMax; index++ {
+ if _, present := r.indexes[index]; !present {
+ return index, true
+ }
+ }
+ return 0, false
+}
+
func (r *Registry) totalSems() int {
totalSems := 0
for _, v := range r.semaphores {
diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD
index 80a592c8f..073e14507 100644
--- a/pkg/sentry/kernel/shm/BUILD
+++ b/pkg/sentry/kernel/shm/BUILD
@@ -6,6 +6,9 @@ package(licenses = ["notice"])
go_template_instance(
name = "shm_refs",
out = "shm_refs.go",
+ consts = {
+ "enableLogging": "true",
+ },
package = "shm",
prefix = "Shm",
template = "//pkg/refsvfs2:refs_template",
diff --git a/pkg/sentry/kernel/signal.go b/pkg/sentry/kernel/signal.go
index e8cce37d0..2488ae7d5 100644
--- a/pkg/sentry/kernel/signal.go
+++ b/pkg/sentry/kernel/signal.go
@@ -73,7 +73,7 @@ func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *arch.SignalInfo
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- info.SetPid(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg)))
- info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg)))
+ info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
return info
}
diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go
index 78f718cfe..884966120 100644
--- a/pkg/sentry/kernel/signalfd/signalfd.go
+++ b/pkg/sentry/kernel/signalfd/signalfd.go
@@ -106,8 +106,8 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
Signo: uint32(info.Signo),
Errno: info.Errno,
Code: info.Code,
- PID: uint32(info.Pid()),
- UID: uint32(info.Uid()),
+ PID: uint32(info.PID()),
+ UID: uint32(info.UID()),
Status: info.Status(),
Overrun: uint32(info.Overrun()),
Addr: info.Addr(),
diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go
index a83ce219c..3fee7aa68 100644
--- a/pkg/sentry/kernel/syslog.go
+++ b/pkg/sentry/kernel/syslog.go
@@ -75,6 +75,12 @@ func (s *syslog) Log() []byte {
"Checking naughty and nice process list...", // Check it up to twice.
"Granting licence to kill(2)...", // British spelling for British movie.
"Letting the watchdogs out...",
+ "Conjuring /dev/null black hole...",
+ "Adversarially training Redcode AI...",
+ "Singleplexing /dev/ptmx...",
+ "Recruiting cron-ies...",
+ "Verifying that no non-zero bytes made their way into /dev/zero...",
+ "Accelerating teletypewriter to 9600 baud...",
}
selectMessage := func() string {
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index c5137c282..16986244c 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -368,8 +368,8 @@ func (t *Task) exitChildren() {
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- siginfo.SetPid(int32(c.tg.pidns.tids[t]))
- siginfo.SetUid(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow()))
+ siginfo.SetPID(int32(c.tg.pidns.tids[t]))
+ siginfo.SetUID(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow()))
c.tg.signalHandlers.mu.Lock()
c.sendSignalLocked(siginfo, true /* group */)
c.tg.signalHandlers.mu.Unlock()
@@ -698,8 +698,8 @@ func (t *Task) exitNotificationSignal(sig linux.Signal, receiver *Task) *arch.Si
info := &arch.SignalInfo{
Signo: int32(sig),
}
- info.SetPid(int32(receiver.tg.pidns.tids[t]))
- info.SetUid(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(receiver.tg.pidns.tids[t]))
+ info.SetUID(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
if t.exitStatus.Signaled() {
info.Code = arch.CLD_KILLED
info.SetStatus(int32(t.exitStatus.Signo))
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index 42dd3e278..75af3af79 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -914,8 +914,8 @@ func (t *Task) signalStop(target *Task, code int32, status int32) {
Signo: int32(linux.SIGCHLD),
Code: code,
}
- sigchld.SetPid(int32(t.tg.pidns.tids[target]))
- sigchld.SetUid(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ sigchld.SetPID(int32(t.tg.pidns.tids[target]))
+ sigchld.SetUID(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
sigchld.SetStatus(status)
// TODO(b/72102453): Set utime, stime.
t.sendSignalLocked(sigchld, true /* group */)
@@ -1022,8 +1022,8 @@ func (*runInterrupt) execute(t *Task) taskRunState {
Signo: int32(sig),
Code: t.ptraceCode,
}
- t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t]))
- t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ t.ptraceSiginfo.SetPID(int32(t.tg.pidns.tids[t]))
+ t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
} else {
t.ptraceCode = int32(sig)
t.ptraceSiginfo = nil
@@ -1114,11 +1114,11 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState {
if parent == nil {
// Tracer has detached and t was created by Kernel.CreateProcess().
// Pretend the parent is in an ancestor PID + user namespace.
- info.SetPid(0)
- info.SetUid(int32(auth.OverflowUID))
+ info.SetPID(0)
+ info.SetUID(int32(auth.OverflowUID))
} else {
- info.SetPid(int32(t.tg.pidns.tids[parent]))
- info.SetUid(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(t.tg.pidns.tids[parent]))
+ info.SetUID(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
}
}
t.tg.signalHandlers.mu.Lock()
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
index 7fd77925f..49e21026e 100644
--- a/pkg/sentry/memmap/memmap.go
+++ b/pkg/sentry/memmap/memmap.go
@@ -160,7 +160,7 @@ func CheckTranslateResult(required, optional MappableRange, at usermem.AccessTyp
// Translations must be contiguous and in increasing order of
// Translation.Source.
if i > 0 && ts[i-1].Source.End != t.Source.Start {
- return fmt.Errorf("Translations %+v and %+v are not contiguous", ts[i-1], t)
+ return fmt.Errorf("Translation %+v and Translation %+v are not contiguous", ts[i-1], t)
}
// At least part of each Translation must be required.
if t.Source.Intersect(required).Length() == 0 {
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
index 4c8cd38ed..5ab2ef79f 100644
--- a/pkg/sentry/mm/aio_context.go
+++ b/pkg/sentry/mm/aio_context.go
@@ -36,12 +36,12 @@ type aioManager struct {
contexts map[uint64]*AIOContext
}
-func (a *aioManager) destroy() {
- a.mu.Lock()
- defer a.mu.Unlock()
+func (mm *MemoryManager) destroyAIOManager(ctx context.Context) {
+ mm.aioManager.mu.Lock()
+ defer mm.aioManager.mu.Unlock()
- for _, ctx := range a.contexts {
- ctx.destroy()
+ for id := range mm.aioManager.contexts {
+ mm.destroyAIOContextLocked(ctx, id)
}
}
@@ -68,16 +68,26 @@ func (a *aioManager) newAIOContext(events uint32, id uint64) bool {
// be drained.
//
// Nil is returned if the context does not exist.
-func (a *aioManager) destroyAIOContext(id uint64) *AIOContext {
- a.mu.Lock()
- defer a.mu.Unlock()
- ctx, ok := a.contexts[id]
+//
+// Precondition: mm.aioManager.mu is locked.
+func (mm *MemoryManager) destroyAIOContextLocked(ctx context.Context, id uint64) *AIOContext {
+ aioCtx, ok := mm.aioManager.contexts[id]
if !ok {
return nil
}
- delete(a.contexts, id)
- ctx.destroy()
- return ctx
+
+ // Only unmaps after it assured that the address is a valid aio context to
+ // prevent random memory from been unmapped.
+ //
+ // Note: It's possible to unmap this address and map something else into
+ // the same address. Then it would be unmapping memory that it doesn't own.
+ // This is, however, the way Linux implements AIO. Keeps the same [weird]
+ // semantics in case anyone relies on it.
+ mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize)
+
+ delete(mm.aioManager.contexts, id)
+ aioCtx.destroy()
+ return aioCtx
}
// lookupAIOContext looks up the given context.
@@ -140,16 +150,21 @@ func (ctx *AIOContext) checkForDone() {
}
}
-// Prepare reserves space for a new request, returning true if available.
-// Returns false if the context is busy.
-func (ctx *AIOContext) Prepare() bool {
+// Prepare reserves space for a new request, returning nil if available.
+// Returns EAGAIN if the context is busy and EINVAL if the context is dead.
+func (ctx *AIOContext) Prepare() error {
ctx.mu.Lock()
defer ctx.mu.Unlock()
+ if ctx.dead {
+ // Context died after the caller looked it up.
+ return syserror.EINVAL
+ }
if ctx.outstanding >= ctx.maxOutstanding {
- return false
+ // Context is busy.
+ return syserror.EAGAIN
}
ctx.outstanding++
- return true
+ return nil
}
// PopRequest pops a completed request if available, this function does not do
@@ -391,20 +406,13 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint
// DestroyAIOContext destroys an asynchronous I/O context. It returns the
// destroyed context. nil if the context does not exist.
func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) *AIOContext {
- if _, ok := mm.LookupAIOContext(ctx, id); !ok {
+ if !mm.isValidAddr(ctx, id) {
return nil
}
- // Only unmaps after it assured that the address is a valid aio context to
- // prevent random memory from been unmapped.
- //
- // Note: It's possible to unmap this address and map something else into
- // the same address. Then it would be unmapping memory that it doesn't own.
- // This is, however, the way Linux implements AIO. Keeps the same [weird]
- // semantics in case anyone relies on it.
- mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize)
-
- return mm.aioManager.destroyAIOContext(id)
+ mm.aioManager.mu.Lock()
+ defer mm.aioManager.mu.Unlock()
+ return mm.destroyAIOContextLocked(ctx, id)
}
// LookupAIOContext looks up the given context. It returns false if the context
@@ -415,13 +423,18 @@ func (mm *MemoryManager) LookupAIOContext(ctx context.Context, id uint64) (*AIOC
return nil, false
}
- // Protect against 'ids' that are inaccessible (Linux also reads 4 bytes
- // from id).
- var buf [4]byte
- _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{})
- if err != nil {
+ // Protect against 'id' that is inaccessible.
+ if !mm.isValidAddr(ctx, id) {
return nil, false
}
return aioCtx, true
}
+
+// isValidAddr determines if the address `id` is valid. (Linux also reads 4
+// bytes from id).
+func (mm *MemoryManager) isValidAddr(ctx context.Context, id uint64) bool {
+ var buf [4]byte
+ _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{})
+ return err == nil
+}
diff --git a/pkg/sentry/mm/aio_context_state.go b/pkg/sentry/mm/aio_context_state.go
index 3dabac1af..e8931922f 100644
--- a/pkg/sentry/mm/aio_context_state.go
+++ b/pkg/sentry/mm/aio_context_state.go
@@ -15,6 +15,6 @@
package mm
// afterLoad is invoked by stateify.
-func (a *AIOContext) afterLoad() {
- a.requestReady = make(chan struct{}, 1)
+func (ctx *AIOContext) afterLoad() {
+ ctx.requestReady = make(chan struct{}, 1)
}
diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go
index 09dbc06a4..120707429 100644
--- a/pkg/sentry/mm/lifecycle.go
+++ b/pkg/sentry/mm/lifecycle.go
@@ -253,7 +253,7 @@ func (mm *MemoryManager) DecUsers(ctx context.Context) {
panic(fmt.Sprintf("Invalid MemoryManager.users: %d", users))
}
- mm.aioManager.destroy()
+ mm.destroyAIOManager(ctx)
mm.metadataMu.Lock()
exe := mm.executable
diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go
index acac3d357..bc53bd41e 100644
--- a/pkg/sentry/mm/mm_test.go
+++ b/pkg/sentry/mm/mm_test.go
@@ -229,3 +229,46 @@ func TestIOAfterMProtect(t *testing.T) {
t.Errorf("CopyOut got %d want 1", n)
}
}
+
+// TestAIOPrepareAfterDestroy tests that AIOContext should not be able to be
+// prepared after destruction.
+func TestAIOPrepareAfterDestroy(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mm := testMemoryManager(ctx)
+ defer mm.DecUsers(ctx)
+
+ id, err := mm.NewAIOContext(ctx, 1)
+ if err != nil {
+ t.Fatalf("mm.NewAIOContext got err %v want nil", err)
+ }
+ aioCtx, ok := mm.LookupAIOContext(ctx, id)
+ if !ok {
+ t.Fatalf("AIOContext not found")
+ }
+ mm.DestroyAIOContext(ctx, id)
+
+ // Prepare should fail because aioCtx should be destroyed.
+ if err := aioCtx.Prepare(); err != syserror.EINVAL {
+ t.Errorf("aioCtx.Prepare got err %v want nil", err)
+ } else if err == nil {
+ aioCtx.CancelPendingRequest()
+ }
+}
+
+// TestAIOLookupAfterDestroy tests that AIOContext should not be able to be
+// looked up after memory manager is destroyed.
+func TestAIOLookupAfterDestroy(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mm := testMemoryManager(ctx)
+
+ id, err := mm.NewAIOContext(ctx, 1)
+ if err != nil {
+ mm.DecUsers(ctx)
+ t.Fatalf("mm.NewAIOContext got err %v want nil", err)
+ }
+ mm.DecUsers(ctx) // This destroys the AIOContext manager.
+
+ if _, ok := mm.LookupAIOContext(ctx, id); ok {
+ t.Errorf("AIOContext found even after AIOContext manager is destroyed")
+ }
+}
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index 7c297fb9e..d99be7f46 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -423,11 +423,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.File
}
if f.opts.ManualZeroing {
- if err := f.forEachMappingSlice(fr, func(bs []byte) {
- for i := range bs {
- bs[i] = 0
- }
- }); err != nil {
+ if err := f.manuallyZero(fr); err != nil {
return memmap.FileRange{}, err
}
}
@@ -560,19 +556,39 @@ func (f *MemoryFile) Decommit(fr memmap.FileRange) error {
panic(fmt.Sprintf("invalid range: %v", fr))
}
+ if f.opts.ManualZeroing {
+ // FALLOC_FL_PUNCH_HOLE may not zero pages if ManualZeroing is in
+ // effect.
+ if err := f.manuallyZero(fr); err != nil {
+ return err
+ }
+ } else {
+ if err := f.decommitFile(fr); err != nil {
+ return err
+ }
+ }
+
+ f.markDecommitted(fr)
+ return nil
+}
+
+func (f *MemoryFile) manuallyZero(fr memmap.FileRange) error {
+ return f.forEachMappingSlice(fr, func(bs []byte) {
+ for i := range bs {
+ bs[i] = 0
+ }
+ })
+}
+
+func (f *MemoryFile) decommitFile(fr memmap.FileRange) error {
// "After a successful call, subsequent reads from this range will
// return zeroes. The FALLOC_FL_PUNCH_HOLE flag must be ORed with
// FALLOC_FL_KEEP_SIZE in mode ..." - fallocate(2)
- err := syscall.Fallocate(
+ return syscall.Fallocate(
int(f.file.Fd()),
_FALLOC_FL_PUNCH_HOLE|_FALLOC_FL_KEEP_SIZE,
int64(fr.Start),
int64(fr.Length()))
- if err != nil {
- return err
- }
- f.markDecommitted(fr)
- return nil
}
func (f *MemoryFile) markDecommitted(fr memmap.FileRange) {
@@ -1044,20 +1060,20 @@ func (f *MemoryFile) runReclaim() {
break
}
- if err := f.Decommit(fr); err != nil {
- log.Warningf("Reclaim failed to decommit %v: %v", fr, err)
- // Zero the pages manually. This won't reduce memory usage, but at
- // least ensures that the pages will be zero when reallocated.
- f.forEachMappingSlice(fr, func(bs []byte) {
- for i := range bs {
- bs[i] = 0
+ // If ManualZeroing is in effect, pages will be zeroed on allocation
+ // and may not be freed by decommitFile, so calling decommitFile is
+ // unnecessary.
+ if !f.opts.ManualZeroing {
+ if err := f.decommitFile(fr); err != nil {
+ log.Warningf("Reclaim failed to decommit %v: %v", fr, err)
+ // Zero the pages manually. This won't reduce memory usage, but at
+ // least ensures that the pages will be zero when reallocated.
+ if err := f.manuallyZero(fr); err != nil {
+ panic(fmt.Sprintf("Reclaim failed to decommit or zero %v: %v", fr, err))
}
- })
- // Pretend the pages were decommitted even though they weren't,
- // since the memory accounting implementation has no idea how to
- // deal with this.
- f.markDecommitted(fr)
+ }
}
+ f.markDecommitted(fr)
f.markReclaimed(fr)
}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
index acad4c793..f8ccb7430 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -91,6 +91,13 @@ func bluepillSigBus(c *vCPU) {
}
}
+// bluepillHandleEnosys is reponsible for handling enosys error.
+//
+//go:nosplit
+func bluepillHandleEnosys(c *vCPU) {
+ throw("run failed: ENOSYS")
+}
+
// bluepillReadyStopGuest checks whether the current vCPU is ready for interrupt injection.
//
//go:nosplit
@@ -126,3 +133,10 @@ func bluepillReadyStopGuest(c *vCPU) bool {
}
return true
}
+
+// bluepillArchHandleExit checks architecture specific exitcode.
+//
+//go:nosplit
+func bluepillArchHandleExit(c *vCPU, context unsafe.Pointer) {
+ c.die(bluepillArchContext(context), "unknown")
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
index 965ad66b5..1f09813ba 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -42,6 +42,13 @@ var (
sErrEsr: _ESR_ELx_SERR_NMI,
},
}
+
+ // vcpuExtDabt is the event of ext_dabt.
+ vcpuExtDabt = kvmVcpuEvents{
+ exception: exception{
+ extDabtPending: 1,
+ },
+ }
)
// getTLS returns the value of TPIDR_EL0 register.
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
index 9433d4da5..4d912769a 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
@@ -85,7 +85,7 @@ func bluepillStopGuest(c *vCPU) {
uintptr(c.fd),
_KVM_SET_VCPU_EVENTS,
uintptr(unsafe.Pointer(&vcpuSErrBounce))); errno != 0 {
- throw("sErr injection failed")
+ throw("bounce sErr injection failed")
}
}
@@ -93,18 +93,54 @@ func bluepillStopGuest(c *vCPU) {
//
//go:nosplit
func bluepillSigBus(c *vCPU) {
+ // Host must support ARM64_HAS_RAS_EXTN.
if _, _, errno := syscall.RawSyscall( // escapes: no.
syscall.SYS_IOCTL,
uintptr(c.fd),
_KVM_SET_VCPU_EVENTS,
uintptr(unsafe.Pointer(&vcpuSErrNMI))); errno != 0 {
- throw("sErr injection failed")
+ if errno == syscall.EINVAL {
+ throw("No ARM64_HAS_RAS_EXTN feature in host.")
+ }
+ throw("nmi sErr injection failed")
}
}
+// bluepillExtDabt is reponsible for injecting external data abort.
+//
+//go:nosplit
+func bluepillExtDabt(c *vCPU) {
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_VCPU_EVENTS,
+ uintptr(unsafe.Pointer(&vcpuExtDabt))); errno != 0 {
+ throw("ext_dabt injection failed")
+ }
+}
+
+// bluepillHandleEnosys is reponsible for handling enosys error.
+//
+//go:nosplit
+func bluepillHandleEnosys(c *vCPU) {
+ bluepillExtDabt(c)
+}
+
// bluepillReadyStopGuest checks whether the current vCPU is ready for sError injection.
//
//go:nosplit
func bluepillReadyStopGuest(c *vCPU) bool {
return true
}
+
+// bluepillArchHandleExit checks architecture specific exitcode.
+//
+//go:nosplit
+func bluepillArchHandleExit(c *vCPU, context unsafe.Pointer) {
+ switch c.runData.exitReason {
+ case _KVM_EXIT_ARM_NISV:
+ bluepillExtDabt(c)
+ default:
+ c.die(bluepillArchContext(context), "unknown")
+ }
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index 75085ac6a..8c5369377 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -148,6 +148,9 @@ func bluepillHandler(context unsafe.Pointer) {
// mode and have interrupts disabled.
bluepillSigBus(c)
continue // Rerun vCPU.
+ case syscall.ENOSYS:
+ bluepillHandleEnosys(c)
+ continue
default:
throw("run failed")
}
@@ -220,7 +223,7 @@ func bluepillHandler(context unsafe.Pointer) {
c.die(bluepillArchContext(context), "entry failed")
return
default:
- c.die(bluepillArchContext(context), "unknown")
+ bluepillArchHandleExit(c, context)
return
}
}
diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go
index 0b06a923a..9db1db4e9 100644
--- a/pkg/sentry/platform/kvm/kvm_arm64.go
+++ b/pkg/sentry/platform/kvm/kvm_arm64.go
@@ -47,10 +47,11 @@ type userRegs struct {
}
type exception struct {
- sErrPending uint8
- sErrHasEsr uint8
- pad [6]uint8
- sErrEsr uint64
+ sErrPending uint8
+ sErrHasEsr uint8
+ extDabtPending uint8
+ pad [5]uint8
+ sErrEsr uint64
}
type kvmVcpuEvents struct {
diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go
index 6abaa21c4..2492d57be 100644
--- a/pkg/sentry/platform/kvm/kvm_const.go
+++ b/pkg/sentry/platform/kvm/kvm_const.go
@@ -56,6 +56,7 @@ const (
_KVM_EXIT_FAIL_ENTRY = 0x9
_KVM_EXIT_INTERNAL_ERROR = 0x11
_KVM_EXIT_SYSTEM_EVENT = 0x18
+ _KVM_EXIT_ARM_NISV = 0x1c
)
// KVM capability options.
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index fd92c3873..3f5be276b 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -263,13 +263,6 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
return usermem.NoAccess, platform.ErrContextInterrupt
case ring0.El0SyncUndef:
return c.fault(int32(syscall.SIGILL), info)
- case ring0.El1SyncUndef:
- *info = arch.SignalInfo{
- Signo: int32(syscall.SIGILL),
- Code: 1, // ILL_ILLOPC (illegal opcode).
- }
- info.SetAddr(switchOpts.Registers.Pc) // Include address.
- return usermem.AccessType{}, platform.ErrContextSignal
default:
panic(fmt.Sprintf("unexpected vector: 0x%x", vector))
}
diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go
index f56aa3b79..571bfcc2e 100644
--- a/pkg/sentry/platform/ptrace/ptrace.go
+++ b/pkg/sentry/platform/ptrace/ptrace.go
@@ -18,8 +18,8 @@
//
// In a nutshell, it works as follows:
//
-// The creation of a new address space creates a new child processes with a
-// single thread which is traced by a single goroutine.
+// The creation of a new address space creates a new child process with a single
+// thread which is traced by a single goroutine.
//
// A context is just a collection of temporary variables. Calling Switch on a
// context does the following:
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 812ab80ef..aacd7ce70 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -590,7 +590,7 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
// facilitate vsyscall emulation. See patchSignalInfo.
patchSignalInfo(regs, &c.signalInfo)
return false
- } else if c.signalInfo.Code <= 0 && c.signalInfo.Pid() == int32(os.Getpid()) {
+ } else if c.signalInfo.Code <= 0 && c.signalInfo.PID() == int32(os.Getpid()) {
// The signal was generated by this process. That means
// that it was an interrupt or something else that we
// should bail for. Note that we ignore signals
diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD
index 679b287c3..2852b7387 100644
--- a/pkg/sentry/platform/ring0/BUILD
+++ b/pkg/sentry/platform/ring0/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "arch_genrule", "go_library")
load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
@@ -39,19 +39,19 @@ go_template_instance(
template = ":defs_arm64",
)
-genrule(
+arch_genrule(
name = "entry_impl_amd64",
srcs = ["entry_amd64.s"],
outs = ["entry_impl_amd64.s"],
- cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@",
+ cmd = "(echo -e '// build +amd64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_amd64.s)) > $@",
tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
)
-genrule(
+arch_genrule(
name = "entry_impl_arm64",
srcs = ["entry_arm64.s"],
outs = ["entry_impl_arm64.s"],
- cmd = "(echo -e '// build +arm64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@",
+ cmd = "(echo -e '// build +arm64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_arm64.s)) > $@",
tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
)
@@ -72,7 +72,6 @@ go_library(
"lib_amd64.s",
"lib_arm64.go",
"lib_arm64.s",
- "lib_arm64_unsafe.go",
"ring0.go",
],
visibility = ["//pkg/sentry:internal"],
diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go
index 327d48465..c51df2811 100644
--- a/pkg/sentry/platform/ring0/aarch64.go
+++ b/pkg/sentry/platform/ring0/aarch64.go
@@ -90,6 +90,7 @@ const (
El0SyncIa
El0SyncFpsimdAcc
El0SyncSveAcc
+ El0SyncFpsimdExc
El0SyncSys
El0SyncSpPc
El0SyncUndef
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index f489ad352..cf0bf3528 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -132,40 +132,6 @@
MOVD offset+PTRACE_R29(reg), R29; \
MOVD offset+PTRACE_R30(reg), R30;
-// NOP-s
-#define nop31Instructions() \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f;
-
#define ESR_ELx_EC_UNKNOWN (0x00)
#define ESR_ELx_EC_WFx (0x01)
/* Unallocated EC: 0x02 */
@@ -324,6 +290,18 @@
MOVD CPU_TTBR0_KVM(from), RSV_REG; \
MSR RSV_REG, TTBR0_EL1;
+TEXT ·EnableVFP(SB),NOSPLIT,$0
+ MOVD $FPEN_ENABLE, R0
+ WORD $0xd5181040 //MSR R0, CPACR_EL1
+ ISB $15
+ RET
+
+TEXT ·DisableVFP(SB),NOSPLIT,$0
+ MOVD $0, R0
+ WORD $0xd5181040 //MSR R0, CPACR_EL1
+ ISB $15
+ RET
+
#define VFP_ENABLE \
MOVD $FPEN_ENABLE, R0; \
WORD $0xd5181040; \ //MSR R0, CPACR_EL1
@@ -370,12 +348,12 @@
MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \
LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack.
-// EXCEPTION_WITH_ERROR is a common exception handler function.
-#define EXCEPTION_WITH_ERROR(user, vector) \
+// EXCEPTION_EL0 is a common el0 exception handler function.
+#define EXCEPTION_EL0(vector) \
WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
WORD $0xd538601a; \ //MRS FAR_EL1, R26
MOVD R26, CPU_FAULT_ADDR(RSV_REG); \
- MOVD $user, R3; \
+ MOVD $1, R3; \
MOVD R3, CPU_ERROR_TYPE(RSV_REG); \ // Set error type to user.
MOVD $vector, R3; \
MOVD R3, CPU_VECTOR_CODE(RSV_REG); \
@@ -383,6 +361,12 @@
MOVD R3, CPU_ERROR_CODE(RSV_REG); \
B ·kernelExitToEl1(SB);
+// EXCEPTION_EL1 is a common el1 exception handler function.
+#define EXCEPTION_EL1(vector) \
+ MOVD $vector, R3; \
+ MOVD R3, 8(RSP); \
+ B ·HaltEl1ExceptionAndResume(SB);
+
// storeAppASID writes the application's asid value.
TEXT ·storeAppASID(SB),NOSPLIT,$0-8
MOVD asid+0(FP), R1
@@ -430,6 +414,16 @@ TEXT ·HaltEl1SvcAndResume(SB),NOSPLIT,$0
CALL ·kernelSyscall(SB) // Call the trampoline.
B ·kernelExitToEl1(SB) // Resume.
+// HaltEl1ExceptionAndResume calls Hooks.KernelException and resume.
+TEXT ·HaltEl1ExceptionAndResume(SB),NOSPLIT,$0-8
+ WORD $0xd538d092 // MRS TPIDR_EL1, R18
+ MOVD CPU_SELF(RSV_REG), R3 // Load vCPU.
+ MOVD R3, 8(RSP) // First argument (vCPU).
+ MOVD vector+0(FP), R3
+ MOVD R3, 16(RSP) // Second argument (vector).
+ CALL ·kernelException(SB) // Call the trampoline.
+ B ·kernelExitToEl1(SB) // Resume.
+
// Shutdown stops the guest.
TEXT ·Shutdown(SB),NOSPLIT,$0
// PSCI EVENT.
@@ -592,39 +586,22 @@ TEXT ·El1_sync(SB),NOSPLIT,$0
B el1_invalid
el1_da:
+ EXCEPTION_EL1(El1SyncDa)
el1_ia:
- WORD $0xd538d092 //MRS TPIDR_EL1, R18
- WORD $0xd538601a //MRS FAR_EL1, R26
-
- MOVD R26, CPU_FAULT_ADDR(RSV_REG)
-
- MOVD $0, CPU_ERROR_TYPE(RSV_REG)
-
- MOVD $PageFault, R3
- MOVD R3, CPU_VECTOR_CODE(RSV_REG)
-
- B ·HaltAndResume(SB)
-
+ EXCEPTION_EL1(El1SyncIa)
el1_sp_pc:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL1(El1SyncSpPc)
el1_undef:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL1(El1SyncUndef)
el1_svc:
- MOVD $0, CPU_ERROR_CODE(RSV_REG)
- MOVD $0, CPU_ERROR_TYPE(RSV_REG)
B ·HaltEl1SvcAndResume(SB)
-
el1_dbg:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL1(El1SyncDbg)
el1_fpsimd_acc:
VFP_ENABLE
B ·kernelExitToEl1(SB) // Resume.
-
el1_invalid:
- B ·Shutdown(SB)
+ EXCEPTION_EL1(El1SyncInv)
// El1_irq is the handler for El1_irq.
TEXT ·El1_irq(SB),NOSPLIT,$0
@@ -680,28 +657,21 @@ el0_svc:
el0_da:
el0_ia:
- EXCEPTION_WITH_ERROR(1, PageFault)
-
+ EXCEPTION_EL0(PageFault)
el0_fpsimd_acc:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL0(El0SyncFpsimdAcc)
el0_sve_acc:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL0(El0SyncSveAcc)
el0_fpsimd_exc:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL0(El0SyncFpsimdExc)
el0_sp_pc:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL0(El0SyncSpPc)
el0_undef:
- EXCEPTION_WITH_ERROR(1, El0SyncUndef)
-
+ EXCEPTION_EL0(El0SyncUndef)
el0_dbg:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL0(El0SyncDbg)
el0_invalid:
- B ·Shutdown(SB)
+ EXCEPTION_EL0(El0SyncInv)
TEXT ·El0_irq(SB),NOSPLIT,$0
B ·Shutdown(SB)
@@ -760,79 +730,43 @@ TEXT ·El0_error_invalid(SB),NOSPLIT,$0
B ·Shutdown(SB)
// Vectors implements exception vector table.
+// The start address of exception vector table should be 11-bits aligned.
+// For detail, please refer to arm developer document:
+// https://developer.arm.com/documentation/100933/0100/AArch64-exception-vector-table
+// Also can refer to the code in linux kernel: arch/arm64/kernel/entry.S
TEXT ·Vectors(SB),NOSPLIT,$0
+ PCALIGN $2048
B ·El1_sync_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_irq_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_fiq_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_error_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_sync(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_irq(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_fiq(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_error(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_sync(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_irq(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_fiq(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_error(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_sync_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_irq_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_fiq_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_error_invalid(SB)
- nop31Instructions()
-
- // The exception-vector-table is required to be 11-bits aligned.
- // Please see Linux source code as reference: arch/arm64/kernel/entry.s.
- // For gvisor, I defined it as 4K in length, filled the 2nd 2K part with NOPs.
- // So that, I can safely move the 1st 2K part into the address with 11-bits alignment.
- WORD $0xd503201f //nop
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
-
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
-
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
-
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD
index 9742308d8..a9703baf6 100644
--- a/pkg/sentry/platform/ring0/gen_offsets/BUILD
+++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD
@@ -24,6 +24,9 @@ go_binary(
"defs_impl_arm64.go",
"main.go",
],
+ # Use the libc malloc to avoid any extra dependencies. This is required to
+ # pass the sentry deps test.
+ system_malloc = True,
visibility = [
"//pkg/sentry/platform/kvm:__pkg__",
"//pkg/sentry/platform/ring0:__pkg__",
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
index c1c808b96..c05284641 100644
--- a/pkg/sentry/platform/ring0/kernel_arm64.go
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -24,6 +24,10 @@ func HaltAndResume()
//go:nosplit
func HaltEl1SvcAndResume()
+// HaltEl1ExceptionAndResume calls Hooks.KernelException and resume.
+//go:nosplit
+func HaltEl1ExceptionAndResume()
+
// init initializes architecture-specific state.
func (k *Kernel) init(maxCPUs int) {
}
@@ -49,6 +53,12 @@ func IsCanonical(addr uint64) bool {
return addr <= 0x0000ffffffffffff || addr > 0xffff000000000000
}
+// SwitchToUser performs an eret.
+//
+// The return value is the exception vector.
+//
+// +checkescape:all
+//
//go:nosplit
func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
storeAppASID(uintptr(switchOpts.UserASID))
@@ -61,11 +71,13 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
regs.Pstate &= ^uint64(PsrFlagsClear)
regs.Pstate |= UserFlagsSet
+ EnableVFP()
LoadFloatingPoint(switchOpts.FloatingPointState)
kernelExitToEl0()
SaveFloatingPoint(switchOpts.FloatingPointState)
+ DisableVFP()
vector = c.vecCode
diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go
index bf1c655f4..a490bf3af 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.go
+++ b/pkg/sentry/platform/ring0/lib_arm64.go
@@ -34,13 +34,13 @@ func FlushTlbAll()
// CPACREL1 returns the value of the CPACR_EL1 register.
func CPACREL1() (value uintptr)
-// FPCR returns the value of FPCR register.
+// GetFPCR returns the value of FPCR register.
func GetFPCR() (value uintptr)
// SetFPCR writes the FPCR value.
func SetFPCR(value uintptr)
-// FPSR returns the value of FPSR register.
+// GetFPSR returns the value of FPSR register.
func GetFPSR() (value uintptr)
// SetFPSR writes the FPSR value.
@@ -59,9 +59,13 @@ func LoadFloatingPoint(*byte)
// SaveFloatingPoint saves floating point state.
func SaveFloatingPoint(*byte)
+// EnableVFP enables fpsimd.
+func EnableVFP()
+
+// DisableVFP disables fpsimd.
+func DisableVFP()
+
// Init sets function pointers based on architectural features.
//
// This must be called prior to using ring0.
-func Init() {
- rewriteVectors()
-}
+func Init() {}
diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s
index 675a8bdb7..e39b32841 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.s
+++ b/pkg/sentry/platform/ring0/lib_arm64.s
@@ -52,62 +52,47 @@ TEXT ·CPACREL1(SB),NOSPLIT,$0-8
RET
TEXT ·GetFPCR(SB),NOSPLIT,$0-8
- WORD $0xd53b4201 // MRS NZCV, R1
+ MOVD FPCR, R1
MOVD R1, ret+0(FP)
RET
TEXT ·GetFPSR(SB),NOSPLIT,$0-8
- WORD $0xd53b4421 // MRS FPSR, R1
+ MOVD FPSR, R1
MOVD R1, ret+0(FP)
RET
TEXT ·SetFPCR(SB),NOSPLIT,$0-8
MOVD addr+0(FP), R1
- WORD $0xd51b4201 // MSR R1, NZCV
+ MOVD R1, FPCR
RET
TEXT ·SetFPSR(SB),NOSPLIT,$0-8
MOVD addr+0(FP), R1
- WORD $0xd51b4421 // MSR R1, FPSR
+ MOVD R1, FPSR
RET
TEXT ·SaveVRegs(SB),NOSPLIT,$0-8
MOVD addr+0(FP), R0
// Skip aarch64_ctx, fpsr, fpcr.
- FMOVD F0, 16*1(R0)
- FMOVD F1, 16*2(R0)
- FMOVD F2, 16*3(R0)
- FMOVD F3, 16*4(R0)
- FMOVD F4, 16*5(R0)
- FMOVD F5, 16*6(R0)
- FMOVD F6, 16*7(R0)
- FMOVD F7, 16*8(R0)
- FMOVD F8, 16*9(R0)
- FMOVD F9, 16*10(R0)
- FMOVD F10, 16*11(R0)
- FMOVD F11, 16*12(R0)
- FMOVD F12, 16*13(R0)
- FMOVD F13, 16*14(R0)
- FMOVD F14, 16*15(R0)
- FMOVD F15, 16*16(R0)
- FMOVD F16, 16*17(R0)
- FMOVD F17, 16*18(R0)
- FMOVD F18, 16*19(R0)
- FMOVD F19, 16*20(R0)
- FMOVD F20, 16*21(R0)
- FMOVD F21, 16*22(R0)
- FMOVD F22, 16*23(R0)
- FMOVD F23, 16*24(R0)
- FMOVD F24, 16*25(R0)
- FMOVD F25, 16*26(R0)
- FMOVD F26, 16*27(R0)
- FMOVD F27, 16*28(R0)
- FMOVD F28, 16*29(R0)
- FMOVD F29, 16*30(R0)
- FMOVD F30, 16*31(R0)
- FMOVD F31, 16*32(R0)
- ISB $15
+ ADD $16, R0, R0
+
+ WORD $0xad000400 // stp q0, q1, [x0]
+ WORD $0xad010c02 // stp q2, q3, [x0, #32]
+ WORD $0xad021404 // stp q4, q5, [x0, #64]
+ WORD $0xad031c06 // stp q6, q7, [x0, #96]
+ WORD $0xad042408 // stp q8, q9, [x0, #128]
+ WORD $0xad052c0a // stp q10, q11, [x0, #160]
+ WORD $0xad06340c // stp q12, q13, [x0, #192]
+ WORD $0xad073c0e // stp q14, q15, [x0, #224]
+ WORD $0xad084410 // stp q16, q17, [x0, #256]
+ WORD $0xad094c12 // stp q18, q19, [x0, #288]
+ WORD $0xad0a5414 // stp q20, q21, [x0, #320]
+ WORD $0xad0b5c16 // stp q22, q23, [x0, #352]
+ WORD $0xad0c6418 // stp q24, q25, [x0, #384]
+ WORD $0xad0d6c1a // stp q26, q27, [x0, #416]
+ WORD $0xad0e741c // stp q28, q29, [x0, #448]
+ WORD $0xad0f7c1e // stp q30, q31, [x0, #480]
RET
@@ -115,39 +100,24 @@ TEXT ·LoadVRegs(SB),NOSPLIT,$0-8
MOVD addr+0(FP), R0
// Skip aarch64_ctx, fpsr, fpcr.
- FMOVD 16*1(R0), F0
- FMOVD 16*2(R0), F1
- FMOVD 16*3(R0), F2
- FMOVD 16*4(R0), F3
- FMOVD 16*5(R0), F4
- FMOVD 16*6(R0), F5
- FMOVD 16*7(R0), F6
- FMOVD 16*8(R0), F7
- FMOVD 16*9(R0), F8
- FMOVD 16*10(R0), F9
- FMOVD 16*11(R0), F10
- FMOVD 16*12(R0), F11
- FMOVD 16*13(R0), F12
- FMOVD 16*14(R0), F13
- FMOVD 16*15(R0), F14
- FMOVD 16*16(R0), F15
- FMOVD 16*17(R0), F16
- FMOVD 16*18(R0), F17
- FMOVD 16*19(R0), F18
- FMOVD 16*20(R0), F19
- FMOVD 16*21(R0), F20
- FMOVD 16*22(R0), F21
- FMOVD 16*23(R0), F22
- FMOVD 16*24(R0), F23
- FMOVD 16*25(R0), F24
- FMOVD 16*26(R0), F25
- FMOVD 16*27(R0), F26
- FMOVD 16*28(R0), F27
- FMOVD 16*29(R0), F28
- FMOVD 16*30(R0), F29
- FMOVD 16*31(R0), F30
- FMOVD 16*32(R0), F31
- ISB $15
+ ADD $16, R0, R0
+
+ WORD $0xad400400 // ldp q0, q1, [x0]
+ WORD $0xad410c02 // ldp q2, q3, [x0, #32]
+ WORD $0xad421404 // ldp q4, q5, [x0, #64]
+ WORD $0xad431c06 // ldp q6, q7, [x0, #96]
+ WORD $0xad442408 // ldp q8, q9, [x0, #128]
+ WORD $0xad452c0a // ldp q10, q11, [x0, #160]
+ WORD $0xad46340c // ldp q12, q13, [x0, #192]
+ WORD $0xad473c0e // ldp q14, q15, [x0, #224]
+ WORD $0xad484410 // ldp q16, q17, [x0, #256]
+ WORD $0xad494c12 // ldp q18, q19, [x0, #288]
+ WORD $0xad4a5414 // ldp q20, q21, [x0, #320]
+ WORD $0xad4b5c16 // ldp q22, q23, [x0, #352]
+ WORD $0xad4c6418 // ldp q24, q25, [x0, #384]
+ WORD $0xad4d6c1a // ldp q26, q27, [x0, #416]
+ WORD $0xad4e741c // ldp q28, q29, [x0, #448]
+ WORD $0xad4f7c1e // ldp q30, q31, [x0, #480]
RET
diff --git a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go b/pkg/sentry/platform/ring0/lib_arm64_unsafe.go
deleted file mode 100644
index c05166fea..000000000
--- a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go
+++ /dev/null
@@ -1,108 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build arm64
-
-package ring0
-
-import (
- "reflect"
- "syscall"
- "unsafe"
-
- "gvisor.dev/gvisor/pkg/safecopy"
- "gvisor.dev/gvisor/pkg/usermem"
-)
-
-const (
- nopInstruction = 0xd503201f
- instSize = unsafe.Sizeof(uint32(0))
- vectorsRawLen = 0x800
-)
-
-func unsafeSlice(addr uintptr, length int) (slice []uint32) {
- hdr := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
- hdr.Data = addr
- hdr.Len = length / int(instSize)
- hdr.Cap = length / int(instSize)
- return slice
-}
-
-// Work around: move ring0.Vectors() into a specific address with 11-bits alignment.
-//
-// According to the design documentation of Arm64,
-// the start address of exception vector table should be 11-bits aligned.
-// Please see the code in linux kernel as reference: arch/arm64/kernel/entry.S
-// But, we can't align a function's start address to a specific address by using golang.
-// We have raised this question in golang community:
-// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I
-// This function will be removed when golang supports this feature.
-//
-// There are 2 jobs were implemented in this function:
-// 1, move the start address of exception vector table into the specific address.
-// 2, modify the offset of each instruction.
-func rewriteVectors() {
- vectorsBegin := reflect.ValueOf(Vectors).Pointer()
-
- // The exception-vector-table is required to be 11-bits aligned.
- // And the size is 0x800.
- // Please see the documentation as reference:
- // https://developer.arm.com/docs/100933/0100/aarch64-exception-vector-table
- //
- // But, golang does not allow to set a function's address to a specific value.
- // So, for gvisor, I defined the size of exception-vector-table as 4K,
- // filled the 2nd 2K part with NOP-s.
- // So that, I can safely move the 1st 2K part into the address with 11-bits alignment.
- //
- // So, the prerequisite for this function to work correctly is:
- // vectorsSafeLen >= 0x1000
- // vectorsRawLen = 0x800
- vectorsSafeLen := int(safecopy.FindEndAddress(vectorsBegin) - vectorsBegin)
- if vectorsSafeLen < 2*vectorsRawLen {
- panic("Can't update vectors")
- }
-
- vectorsSafeTable := unsafeSlice(vectorsBegin, vectorsSafeLen) // Now a []uint32
- vectorsRawLen32 := vectorsRawLen / int(instSize)
-
- offset := vectorsBegin & (1<<11 - 1)
- if offset != 0 {
- offset = 1<<11 - offset
- }
-
- pageBegin := (vectorsBegin + offset) & ^uintptr(usermem.PageSize-1)
-
- _, _, errno := syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC))
- if errno != 0 {
- panic(errno.Error())
- }
-
- offset = offset / instSize // By index, not bytes.
- // Move exception-vector-table into the specific address, should uses memmove here.
- for i := 1; i <= vectorsRawLen32; i++ {
- vectorsSafeTable[int(offset)+vectorsRawLen32-i] = vectorsSafeTable[vectorsRawLen32-i]
- }
-
- // Adjust branch since instruction was moved forward.
- for i := 0; i < vectorsRawLen32; i++ {
- if vectorsSafeTable[int(offset)+i] != nopInstruction {
- vectorsSafeTable[int(offset)+i] -= uint32(offset)
- }
- }
-
- _, _, errno = syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_EXEC))
- if errno != 0 {
- panic(errno.Error())
- }
-}
diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go
index 53bc3353c..b5652deb9 100644
--- a/pkg/sentry/platform/ring0/offsets_arm64.go
+++ b/pkg/sentry/platform/ring0/offsets_arm64.go
@@ -70,6 +70,7 @@ func Emit(w io.Writer) {
fmt.Fprintf(w, "#define El0SyncIa 0x%02x\n", El0SyncIa)
fmt.Fprintf(w, "#define El0SyncFpsimdAcc 0x%02x\n", El0SyncFpsimdAcc)
fmt.Fprintf(w, "#define El0SyncSveAcc 0x%02x\n", El0SyncSveAcc)
+ fmt.Fprintf(w, "#define El0SyncFpsimdExc 0x%02x\n", El0SyncFpsimdExc)
fmt.Fprintf(w, "#define El0SyncSys 0x%02x\n", El0SyncSys)
fmt.Fprintf(w, "#define El0SyncSpPc 0x%02x\n", El0SyncSpPc)
fmt.Fprintf(w, "#define El0SyncUndef 0x%02x\n", El0SyncUndef)
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go
index bc16a1622..7605d0cb2 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go
@@ -58,6 +58,15 @@ type PageTables struct {
readOnlyShared bool
}
+// Init initializes a set of PageTables.
+//
+//go:nosplit
+func (p *PageTables) Init(allocator Allocator) {
+ p.Allocator = allocator
+ p.root = p.Allocator.NewPTEs()
+ p.rootPhysical = p.Allocator.PhysicalFor(p.root)
+}
+
// NewWithUpper returns new PageTables.
//
// upperSharedPageTables are used for mapping the upper of addresses,
@@ -73,14 +82,17 @@ type PageTables struct {
func NewWithUpper(a Allocator, upperSharedPageTables *PageTables, upperStart uintptr) *PageTables {
p := new(PageTables)
p.Init(a)
+
if upperSharedPageTables != nil {
if !upperSharedPageTables.readOnlyShared {
panic("Only read-only shared pagetables can be used as upper")
}
p.upperSharedPageTables = upperSharedPageTables
p.upperStart = upperStart
- p.cloneUpperShared()
}
+
+ p.InitArch(a)
+
return p
}
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
index a4e416af7..520161755 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
@@ -24,6 +24,14 @@ import (
// archPageTables is architecture-specific data.
type archPageTables struct {
+ // root is the pagetable root for kernel space.
+ root *PTEs
+
+ // rootPhysical is the cached physical address of the root.
+ //
+ // This is saved only to prevent constant translation.
+ rootPhysical uintptr
+
asid uint16
}
@@ -38,7 +46,7 @@ func (p *PageTables) TTBR0_EL1(noFlush bool, asid uint16) uint64 {
//
//go:nosplit
func (p *PageTables) TTBR1_EL1(noFlush bool, asid uint16) uint64 {
- return uint64(p.upperSharedPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset
+ return uint64(p.archPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset
}
// Bits in page table entries.
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
index e7ab887e5..4bdde8448 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
@@ -41,13 +41,13 @@ const (
entriesPerPage = 512
)
-// Init initializes a set of PageTables.
+// InitArch does some additional initialization related to the architecture.
//
//go:nosplit
-func (p *PageTables) Init(allocator Allocator) {
- p.Allocator = allocator
- p.root = p.Allocator.NewPTEs()
- p.rootPhysical = p.Allocator.PhysicalFor(p.root)
+func (p *PageTables) InitArch(allocator Allocator) {
+ if p.upperSharedPageTables != nil {
+ p.cloneUpperShared()
+ }
}
func pgdIndex(upperStart uintptr) uintptr {
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
index 5392bf27a..ad0e30c88 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
@@ -42,13 +42,16 @@ const (
entriesPerPage = 512
)
-// Init initializes a set of PageTables.
+// InitArch does some additional initialization related to the architecture.
//
//go:nosplit
-func (p *PageTables) Init(allocator Allocator) {
- p.Allocator = allocator
- p.root = p.Allocator.NewPTEs()
- p.rootPhysical = p.Allocator.PhysicalFor(p.root)
+func (p *PageTables) InitArch(allocator Allocator) {
+ if p.upperSharedPageTables != nil {
+ p.cloneUpperShared()
+ } else {
+ p.archPageTables.root = p.Allocator.NewPTEs()
+ p.archPageTables.rootPhysical = p.Allocator.PhysicalFor(p.archPageTables.root)
+ }
}
// cloneUpperShared clone the upper from the upper shared page tables.
@@ -59,7 +62,8 @@ func (p *PageTables) cloneUpperShared() {
panic("upperStart should be the same as upperBottom")
}
- // nothing to do for arm.
+ p.archPageTables.root = p.upperSharedPageTables.archPageTables.root
+ p.archPageTables.rootPhysical = p.upperSharedPageTables.archPageTables.rootPhysical
}
// PTEs is a collection of entries.
diff --git a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
index 157c9a7cc..c261d393a 100644
--- a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
+++ b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
@@ -116,7 +116,7 @@ func next(start uintptr, size uintptr) uintptr {
func (w *Walker) iterateRangeCanonical(start, end uintptr) {
pgdEntryIndex := w.pageTables.root
if start >= upperBottom {
- pgdEntryIndex = w.pageTables.upperSharedPageTables.root
+ pgdEntryIndex = w.pageTables.archPageTables.root
}
for pgdIndex := (uint16((start & pgdMask) >> pgdShift)); start < end && pgdIndex < entriesPerPage; pgdIndex++ {
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index a3f775d15..cc1f6bfcc 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -20,6 +20,7 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/tcpip",
+ "//pkg/tcpip/header",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
index ca16d0381..fb7c5dc61 100644
--- a/pkg/sentry/socket/control/BUILD
+++ b/pkg/sentry/socket/control/BUILD
@@ -23,7 +23,6 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
"//pkg/syserror",
- "//pkg/tcpip",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 70ccf77a7..b88cdca48 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -344,21 +343,34 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
}
// PackIPPacketInfo packs an IP_PKTINFO socket control message.
-func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte {
- var p linux.ControlMessageIPPacketInfo
- p.NIC = int32(packetInfo.NIC)
- copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
- copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
-
+func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketInfo, buf []byte) []byte {
return putCmsgStruct(
buf,
linux.SOL_IP,
linux.IP_PKTINFO,
t.Arch().Width(),
- p,
+ packetInfo,
)
}
+// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message.
+func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte {
+ var level uint32
+ var optType uint32
+ switch originalDstAddress.(type) {
+ case *linux.SockAddrInet:
+ level = linux.SOL_IP
+ optType = linux.IP_RECVORIGDSTADDR
+ case *linux.SockAddrInet6:
+ level = linux.SOL_IPV6
+ optType = linux.IPV6_RECVORIGDSTADDR
+ default:
+ panic("invalid address type, must be an IP address for IP_RECVORIGINALDSTADDR cmsg")
+ }
+ return putCmsgStruct(
+ buf, level, optType, t.Arch().Width(), originalDstAddress)
+}
+
// PackControlMessages packs control messages into the given buffer.
//
// We skip control messages specific to Unix domain sockets.
@@ -384,7 +396,11 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
}
if cmsgs.IP.HasIPPacketInfo {
- buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf)
+ buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf)
+ }
+
+ if cmsgs.IP.OriginalDstAddress != nil {
+ buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf)
}
return buf
@@ -416,17 +432,15 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
space += cmsgSpace(t, linux.SizeOfControlMessageTClass)
}
- return space
-}
+ if cmsgs.IP.HasIPPacketInfo {
+ space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo)
+ }
-// NewIPPacketInfo returns the IPPacketInfo struct.
-func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo {
- var p tcpip.IPPacketInfo
- p.NIC = tcpip.NICID(packetInfo.NIC)
- copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:])
- copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:])
+ if cmsgs.IP.OriginalDstAddress != nil {
+ space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes())
+ }
- return p
+ return space
}
// Parse parses a raw socket control message into portable objects.
@@ -489,6 +503,14 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
cmsgs.Unix.Credentials = scmCreds
i += binary.AlignUp(length, width)
+ case linux.SO_TIMESTAMP:
+ if length < linux.SizeOfTimeval {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ cmsgs.IP.HasTimestamp = true
+ binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], usermem.ByteOrder, &cmsgs.IP.Timestamp)
+ i += binary.AlignUp(length, width)
+
default:
// Unknown message type.
return socket.ControlMessages{}, syserror.EINVAL
@@ -512,7 +534,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
var packetInfo linux.ControlMessageIPPacketInfo
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
- cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo)
+ cmsgs.IP.PacketInfo = packetInfo
+ i += binary.AlignUp(length, width)
+
+ case linux.IP_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet
+ if length < addr.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr)
+ cmsgs.IP.OriginalDstAddress = &addr
i += binary.AlignUp(length, width)
default:
@@ -528,6 +559,15 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
i += binary.AlignUp(length, width)
+ case linux.IPV6_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet6
+ if length < addr.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr)
+ cmsgs.IP.OriginalDstAddress = &addr
+ i += binary.AlignUp(length, width)
+
default:
return socket.ControlMessages{}, syserror.EINVAL
}
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 7d3c4a01c..1f220c343 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -331,17 +331,17 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR:
optlen = sizeofInt32
}
case linux.SOL_IPV6:
switch name {
- case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR:
optlen = sizeofInt32
}
case linux.SOL_SOCKET:
switch name {
- case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
+ case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP:
optlen = sizeofInt32
case linux.SO_LINGER:
optlen = syscall.SizeofLinger
@@ -377,24 +377,24 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR:
optlen = sizeofInt32
case linux.IP_PKTINFO:
optlen = linux.SizeOfControlMessageIPPacketInfo
}
case linux.SOL_IPV6:
switch name {
- case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR:
optlen = sizeofInt32
}
case linux.SOL_SOCKET:
switch name {
- case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
+ case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP:
optlen = sizeofInt32
}
case linux.SOL_TCP:
switch name {
- case linux.TCP_NODELAY:
+ case linux.TCP_NODELAY, linux.TCP_INQ:
optlen = sizeofInt32
}
}
@@ -416,6 +416,37 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
return nil
}
+func (s *socketOpsCommon) recvMsgFromHost(iovs []syscall.Iovec, flags int, senderRequested bool, controlLen uint64) (uint64, int, []byte, []byte, error) {
+ // We always do a non-blocking recv*().
+ sysflags := flags | syscall.MSG_DONTWAIT
+
+ msg := syscall.Msghdr{}
+ if len(iovs) > 0 {
+ msg.Iov = &iovs[0]
+ msg.Iovlen = uint64(len(iovs))
+ }
+ var senderAddrBuf []byte
+ if senderRequested {
+ senderAddrBuf = make([]byte, sizeofSockaddr)
+ msg.Name = &senderAddrBuf[0]
+ msg.Namelen = uint32(sizeofSockaddr)
+ }
+ var controlBuf []byte
+ if controlLen > 0 {
+ if controlLen > maxControlLen {
+ controlLen = maxControlLen
+ }
+ controlBuf = make([]byte, controlLen)
+ msg.Control = &controlBuf[0]
+ msg.Controllen = controlLen
+ }
+ n, err := recvmsg(s.fd, &msg, sysflags)
+ if err != nil {
+ return 0 /* n */, 0 /* mFlags */, nil /* senderAddrBuf */, nil /* controlBuf */, err
+ }
+ return n, int(msg.Flags), senderAddrBuf[:msg.Namelen], controlBuf[:msg.Controllen], err
+}
+
// RecvMsg implements socket.Socket.RecvMsg.
func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
// Only allow known and safe flags.
@@ -427,56 +458,36 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
}
- var senderAddr linux.SockAddr
var senderAddrBuf []byte
- if senderRequested {
- senderAddrBuf = make([]byte, sizeofSockaddr)
- }
-
var controlBuf []byte
var msgFlags int
-
- recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
- // Refuse to do anything if any part of dst.Addrs was unusable.
- if uint64(dst.NumBytes()) != dsts.NumBytes() {
- return 0, nil
- }
- if dsts.IsEmpty() {
- return 0, nil
- }
-
- // We always do a non-blocking recv*().
- sysflags := flags | syscall.MSG_DONTWAIT
-
- iovs := safemem.IovecsFromBlockSeq(dsts)
- msg := syscall.Msghdr{
- Iov: &iovs[0],
- Iovlen: uint64(len(iovs)),
- }
- if len(senderAddrBuf) != 0 {
- msg.Name = &senderAddrBuf[0]
- msg.Namelen = uint32(len(senderAddrBuf))
- }
- if controlLen > 0 {
- if controlLen > maxControlLen {
- controlLen = maxControlLen
+ copyToDst := func() (int64, error) {
+ var n uint64
+ var err error
+ if dst.NumBytes() == 0 {
+ // We want to make the recvmsg(2) call to the host even if dst is empty
+ // to fetch control messages, sender address or errors if any occur.
+ n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(nil, flags, senderRequested, controlLen)
+ return int64(n), err
+ }
+
+ recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of dst.Addrs was unusable.
+ if uint64(dst.NumBytes()) != dsts.NumBytes() {
+ return 0, nil
+ }
+ if dsts.IsEmpty() {
+ return 0, nil
}
- controlBuf = make([]byte, controlLen)
- msg.Control = &controlBuf[0]
- msg.Controllen = controlLen
- }
- n, err := recvmsg(s.fd, &msg, sysflags)
- if err != nil {
- return 0, err
- }
- senderAddrBuf = senderAddrBuf[:msg.Namelen]
- msgFlags = int(msg.Flags)
- controlLen = uint64(msg.Controllen)
- return n, nil
- })
+
+ n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(safemem.IovecsFromBlockSeq(dsts), flags, senderRequested, controlLen)
+ return n, err
+ })
+ return dst.CopyOutFrom(t, recvmsgToBlocks)
+ }
var ch chan struct{}
- n, err := dst.CopyOutFrom(t, recvmsgToBlocks)
+ n, err := copyToDst()
if flags&syscall.MSG_DONTWAIT == 0 {
for err == syserror.ErrWouldBlock {
// We only expect blocking to come from the actual syscall, in which
@@ -494,48 +505,75 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
s.EventRegister(&e, waiter.EventIn)
defer s.EventUnregister(&e)
}
- n, err = dst.CopyOutFrom(t, recvmsgToBlocks)
+ n, err = copyToDst()
}
}
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
}
+ var senderAddr linux.SockAddr
if senderRequested {
senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf)
}
- unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen])
+ unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf)
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
}
+ return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), parseUnixControlMessages(unixControlMessages), nil
+}
+func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) socket.ControlMessages {
controlMessages := socket.ControlMessages{}
for _, unixCmsg := range unixControlMessages {
switch unixCmsg.Header.Level {
- case syscall.SOL_IP:
+ case linux.SOL_SOCKET:
+ switch unixCmsg.Header.Type {
+ case linux.SO_TIMESTAMP:
+ controlMessages.IP.HasTimestamp = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfTimeval], usermem.ByteOrder, &controlMessages.IP.Timestamp)
+ }
+
+ case linux.SOL_IP:
switch unixCmsg.Header.Type {
- case syscall.IP_TOS:
+ case linux.IP_TOS:
controlMessages.IP.HasTOS = true
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS)
- case syscall.IP_PKTINFO:
+ case linux.IP_PKTINFO:
controlMessages.IP.HasIPPacketInfo = true
var packetInfo linux.ControlMessageIPPacketInfo
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
- controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo)
+ controlMessages.IP.PacketInfo = packetInfo
+
+ case linux.IP_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet
+ binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr)
+ controlMessages.IP.OriginalDstAddress = &addr
}
- case syscall.SOL_IPV6:
+ case linux.SOL_IPV6:
switch unixCmsg.Header.Type {
- case syscall.IPV6_TCLASS:
+ case linux.IPV6_TCLASS:
controlMessages.IP.HasTClass = true
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass)
+
+ case linux.IPV6_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet6
+ binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr)
+ controlMessages.IP.OriginalDstAddress = &addr
+ }
+
+ case linux.SOL_TCP:
+ switch unixCmsg.Header.Type {
+ case linux.TCP_INQ:
+ controlMessages.IP.HasInq = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], usermem.ByteOrder, &controlMessages.IP.Inq)
}
}
}
-
- return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil
+ return controlMessages
}
// SendMsg implements socket.Socket.SendMsg.
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 3baad098b..057f4d294 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -120,9 +120,6 @@ type socketOpsCommon struct {
// fixed buffer but only consume this many bytes.
sendBufferSize uint32
- // passcred indicates if this socket wants SCM credentials.
- passcred bool
-
// filter indicates that this socket has a BPF filter "installed".
//
// TODO(gvisor.dev/issue/1119): We don't actually support filtering,
@@ -201,10 +198,7 @@ func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
// Passcred implements transport.Credentialer.Passcred.
func (s *socketOpsCommon) Passcred() bool {
- s.mu.Lock()
- passcred := s.passcred
- s.mu.Unlock()
- return passcred
+ return s.ep.SocketOptions().GetPassCred()
}
// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
@@ -419,9 +413,7 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
}
passcred := usermem.ByteOrder.Uint32(opt)
- s.mu.Lock()
- s.passcred = passcred != 0
- s.mu.Unlock()
+ s.ep.SocketOptions().SetPassCred(passcred != 0)
return nil
case linux.SO_ATTACH_FILTER:
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 7d0ae15ca..23d5cab9c 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -84,69 +84,95 @@ var Metrics = tcpip.Stats{
MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."),
DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."),
ICMP: tcpip.ICMPStats{
- V4PacketsSent: tcpip.ICMPv4SentPacketStats{
- ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
- Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."),
- SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."),
- Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."),
- Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."),
- TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."),
- InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."),
- InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."),
+ V4: tcpip.ICMPv4Stats{
+ PacketsSent: tcpip.ICMPv4SentPacketStats{
+ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
+ Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."),
+ },
+ Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."),
+ },
+ PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{
+ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
+ Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."),
+ },
+ Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."),
},
- Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."),
},
- V4PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{
- ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
- Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."),
- SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."),
- Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."),
- Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."),
- TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."),
- InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."),
- InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."),
+ V6: tcpip.ICMPv6Stats{
+ PacketsSent: tcpip.ICMPv6SentPacketStats{
+ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."),
+ },
+ Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."),
+ },
+ PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{
+ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."),
+ },
+ Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."),
},
- Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."),
},
- V6PacketsSent: tcpip.ICMPv6SentPacketStats{
- ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."),
- PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."),
- RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."),
- RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."),
- NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."),
- NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."),
- RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."),
+ },
+ IGMP: tcpip.IGMPStats{
+ PacketsSent: tcpip.IGMPSentPacketStats{
+ IGMPPacketStats: tcpip.IGMPPacketStats{
+ MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Total number of IGMP Membership Query messages sent by netstack."),
+ V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Total number of IGMPv1 Membership Report messages sent by netstack."),
+ V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Total number of IGMPv2 Membership Report messages sent by netstack."),
+ LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Total number of IGMP Leave Group messages sent by netstack."),
},
- Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."),
+ Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Total number of IGMP packets dropped by netstack due to link layer errors."),
},
- V6PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{
- ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."),
- PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."),
- RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."),
- RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."),
- NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."),
- NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."),
- RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."),
+ PacketsReceived: tcpip.IGMPReceivedPacketStats{
+ IGMPPacketStats: tcpip.IGMPPacketStats{
+ MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Total number of IGMP Membership Query messages received by netstack."),
+ V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Total number of IGMPv1 Membership Report messages received by netstack."),
+ V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Total number of IGMPv2 Membership Report messages received by netstack."),
+ LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Total number of IGMP Leave Group messages received by netstack."),
},
- Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."),
+ Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Total number of IGMP packets received by netstack that could not be parsed."),
+ ChecksumErrors: mustCreateMetric("/netstack/igmp/packets_received/checksum_errors", "Total number of received IGMP packets with bad checksums."),
+ Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Total number of unrecognized IGMP packets received by netstack."),
},
},
IP: tcpip.IPStats{
@@ -209,18 +235,6 @@ const sizeOfInt32 int = 4
var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL)
-// ntohs converts a 16-bit number from network byte order to host byte order. It
-// assumes that the host is little endian.
-func ntohs(v uint16) uint16 {
- return v<<8 | v>>8
-}
-
-// htons converts a 16-bit number from host byte order to network byte order. It
-// assumes that the host is little endian.
-func htons(v uint16) uint16 {
- return ntohs(v)
-}
-
// commonEndpoint represents the intersection of a tcpip.Endpoint and a
// transport.Endpoint.
type commonEndpoint interface {
@@ -240,10 +254,6 @@ type commonEndpoint interface {
// transport.Endpoint.SetSockOpt.
SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error
- // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and
- // transport.Endpoint.SetSockOptBool.
- SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
-
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
// transport.Endpoint.SetSockOptInt.
SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
@@ -252,18 +262,20 @@ type commonEndpoint interface {
// transport.Endpoint.GetSockOpt.
GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error
- // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and
- // transport.Endpoint.GetSockOpt.
- GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and
// transport.Endpoint.GetSockOpt.
GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
- // LastError implements tcpip.Endpoint.LastError.
+ // State returns a socket's lifecycle state. The returned value is
+ // protocol-specific and is primarily used for diagnostics.
+ State() uint32
+
+ // LastError implements tcpip.Endpoint.LastError and
+ // transport.Endpoint.LastError.
LastError() *tcpip.Error
- // SocketOptions implements tcpip.Endpoint.SocketOptions.
+ // SocketOptions implements tcpip.Endpoint.SocketOptions and
+ // transport.Endpoint.SocketOptions.
SocketOptions() *tcpip.SocketOptions
}
@@ -308,7 +320,7 @@ type socketOpsCommon struct {
readView buffer.View
// readCM holds control message information for the last packet read
// from Endpoint.
- readCM tcpip.ControlMessages
+ readCM socket.IPControlMessages
sender tcpip.FullAddress
linkPacketInfo tcpip.LinkPacketInfo
@@ -332,9 +344,7 @@ type socketOpsCommon struct {
// New creates a new endpoint socket.
func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
if skType == linux.SOCK_STREAM {
- if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ endpoint.SocketOptions().SetDelayOption(true)
}
dirent := socket.NewDirent(t, netstackDevice)
@@ -363,88 +373,6 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
return tcpip.Address(addr)
}
-// AddressAndFamily reads an sockaddr struct from the given address and
-// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
-// AF_INET6, and AF_PACKET addresses.
-//
-// AddressAndFamily returns an address and its family.
-func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
- // Make sure we have at least 2 bytes for the address family.
- if len(addr) < 2 {
- return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
- }
-
- // Get the rest of the fields based on the address family.
- switch family := usermem.ByteOrder.Uint16(addr); family {
- case linux.AF_UNIX:
- path := addr[2:]
- if len(path) > linux.UnixPathMax {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- // Drop the terminating NUL (if one exists) and everything after
- // it for filesystem (non-abstract) addresses.
- if len(path) > 0 && path[0] != 0 {
- if n := bytes.IndexByte(path[1:], 0); n >= 0 {
- path = path[:n+1]
- }
- }
- return tcpip.FullAddress{
- Addr: tcpip.Address(path),
- }, family, nil
-
- case linux.AF_INET:
- var a linux.SockAddrInet
- if len(addr) < sockAddrInetSize {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
-
- out := tcpip.FullAddress{
- Addr: bytesToIPAddress(a.Addr[:]),
- Port: ntohs(a.Port),
- }
- return out, family, nil
-
- case linux.AF_INET6:
- var a linux.SockAddrInet6
- if len(addr) < sockAddrInet6Size {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
-
- out := tcpip.FullAddress{
- Addr: bytesToIPAddress(a.Addr[:]),
- Port: ntohs(a.Port),
- }
- if isLinkLocal(out.Addr) {
- out.NIC = tcpip.NICID(a.Scope_id)
- }
- return out, family, nil
-
- case linux.AF_PACKET:
- var a linux.SockAddrLink
- if len(addr) < sockAddrLinkSize {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a)
- if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
-
- // TODO(gvisor.dev/issue/173): Return protocol too.
- return tcpip.FullAddress{
- NIC: tcpip.NICID(a.InterfaceIndex),
- Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
- }, family, nil
-
- case linux.AF_UNSPEC:
- return tcpip.FullAddress{}, family, nil
-
- default:
- return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported
- }
-}
-
func (s *socketOpsCommon) isPacketBased() bool {
return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW
}
@@ -480,7 +408,7 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error {
}
s.readView = v
- s.readCM = cms
+ s.readCM = socket.NewIPControlMessages(s.family, cms)
atomic.StoreUint32(&s.readViewHasData, 1)
return nil
@@ -500,11 +428,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) {
return
}
- var v tcpip.LingerOption
- if err := s.Endpoint.GetSockOpt(&v); err != nil {
- return
- }
-
+ v := s.Endpoint.SocketOptions().GetLinger()
// The case for zero timeout is handled in tcp endpoint close function.
// Close is blocked until either:
// 1. The endpoint state is not in any of the states: FIN-WAIT1,
@@ -721,11 +645,7 @@ func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error {
return nil
}
if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 {
- v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption)
- if err != nil {
- return syserr.TranslateNetstackError(err)
- }
- if !v {
+ if !s.Endpoint.SocketOptions().GetV6Only() {
return nil
}
}
@@ -749,7 +669,7 @@ func (s *socketOpsCommon) mapFamily(addr tcpip.FullAddress, family uint16) tcpip
// Connect implements the linux syscall connect(2) for sockets backed by
// tpcip.Endpoint.
func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
- addr, family, err := AddressAndFamily(sockaddr)
+ addr, family, err := socket.AddressAndFamily(sockaddr)
if err != nil {
return err
}
@@ -830,7 +750,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
} else {
var err *syserr.Error
- addr, family, err = AddressAndFamily(sockaddr)
+ addr, family, err = socket.AddressAndFamily(sockaddr)
if err != nil {
return err
}
@@ -921,7 +841,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
var addr linux.SockAddr
var addrLen uint32
if peerAddr != nil {
- addr, addrLen = ConvertAddress(s.family, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(s.family, *peerAddr)
}
fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
@@ -1005,7 +925,7 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
case linux.SOL_TCP:
- return getSockOptTCP(t, ep, name, outLen)
+ return getSockOptTCP(t, s, ep, name, outLen)
case linux.SOL_IPV6:
return getSockOptIPv6(t, s, ep, name, outPtr, outLen)
@@ -1041,7 +961,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// Get the last error and convert it.
- err := ep.LastError()
+ err := ep.SocketOptions().GetLastError()
if err == nil {
optP := primitive.Int32(0)
return &optP, nil
@@ -1068,13 +988,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.PasscredOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetPassCred()))
+ return &v, nil
case linux.SO_SNDBUF:
if outLen < sizeOfInt32 {
@@ -1115,25 +1030,16 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReuseAddress()))
+ return &v, nil
case linux.SO_REUSEPORT:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReusePortOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReusePort()))
+ return &v, nil
case linux.SO_BINDTODEVICE:
var v tcpip.BindToDeviceOption
@@ -1174,24 +1080,16 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetKeepAlive()))
+ return &v, nil
case linux.SO_LINGER:
if outLen < linux.SizeOfLinger {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.LingerOption
var linger linux.Linger
- if err := ep.GetSockOpt(&v); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ v := ep.SocketOptions().GetLinger()
if v.Enabled {
linger.OnOff = 1
@@ -1222,34 +1120,26 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.OutOfBandInlineOption
- if err := ep.GetSockOpt(&v); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(v)
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetOutOfBandInline()))
+ return &v, nil
case linux.SO_NO_CHECK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.NoChecksumOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetNoChecksum()))
+ return &v, nil
case linux.SO_ACCEPTCONN:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.AcceptConnOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ // This option is only viable for TCP endpoints.
+ var v bool
+ if _, skType, skProto := s.Type(); isTCPSocket(skType, skProto) {
+ v = tcp.EndpointState(ep.State()) == tcp.StateListen
}
vP := primitive.Int32(boolToInt32(v))
return &vP, nil
@@ -1261,46 +1151,36 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// getSockOptTCP implements GetSockOpt when level is SOL_TCP.
-func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
+func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
+ if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) {
+ log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto)
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.TCP_NODELAY:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.DelayOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(!v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(!ep.SocketOptions().GetDelayOption()))
+ return &v, nil
case linux.TCP_CORK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.CorkOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetCorkOption()))
+ return &v, nil
case linux.TCP_QUICKACK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.QuickAckOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetQuickAck()))
+ return &v, nil
case linux.TCP_MAXSEG:
if outLen < sizeOfInt32 {
@@ -1474,19 +1354,24 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal
// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
+ family, skType, _ := s.Type()
+ if family != linux.AF_INET6 {
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IPV6_V6ONLY:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.V6OnlyOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetV6Only()))
+ return &v, nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1518,13 +1403,16 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass()))
+ return &v, nil
+
+ case linux.IPV6_RECVORIGDSTADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
}
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
+ return &v, nil
case linux.IP6T_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet6{})) {
@@ -1536,7 +1424,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.TranslateNetstackError(err)
}
- a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v))
+ a, _ := socket.ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v))
return a.(*linux.SockAddrInet6), nil
case linux.IP6T_SO_GET_INFO:
@@ -1545,7 +1433,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return nil, syserr.ErrProtocolNotAvailable
}
@@ -1565,7 +1453,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrInvalidArgument
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return nil, syserr.ErrProtocolNotAvailable
}
@@ -1585,7 +1473,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return nil, syserr.ErrProtocolNotAvailable
}
@@ -1607,6 +1495,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
// getSockOptIP implements GetSockOpt when level is SOL_IP.
func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IP_TTL:
if outLen < sizeOfInt32 {
@@ -1649,7 +1542,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.TranslateNetstackError(err)
}
- a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
+ a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
return &a.(*linux.SockAddrInet).Addr, nil
@@ -1658,13 +1551,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetMulticastLoop()))
+ return &v, nil
case linux.IP_TOS:
// Length handling for parity with Linux.
@@ -1688,26 +1576,32 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS()))
+ return &v, nil
+
+ case linux.IP_PKTINFO:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
}
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceivePacketInfo()))
+ return &v, nil
- case linux.IP_PKTINFO:
+ case linux.IP_HDRINCL:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded()))
+ return &v, nil
+
+ case linux.IP_RECVORIGDSTADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
}
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
+ return &v, nil
case linux.SO_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet{})) {
@@ -1719,7 +1613,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.TranslateNetstackError(err)
}
- a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v))
+ a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress(v))
return a.(*linux.SockAddrInet), nil
case linux.IPT_SO_GET_INFO:
@@ -1826,7 +1720,7 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int
return setSockOptSocket(t, s, ep, name, optVal)
case linux.SOL_TCP:
- return setSockOptTCP(t, ep, name, optVal)
+ return setSockOptTCP(t, s, ep, name, optVal)
case linux.SOL_IPV6:
return setSockOptIPv6(t, s, ep, name, optVal)
@@ -1876,7 +1770,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0))
+ ep.SocketOptions().SetReuseAddress(v != 0)
+ return nil
case linux.SO_REUSEPORT:
if len(optVal) < sizeOfInt32 {
@@ -1884,7 +1779,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0))
+ ep.SocketOptions().SetReusePort(v != 0)
+ return nil
case linux.SO_BINDTODEVICE:
n := bytes.IndexByte(optVal, 0)
@@ -1923,7 +1819,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0))
+ ep.SocketOptions().SetPassCred(v != 0)
+ return nil
case linux.SO_KEEPALIVE:
if len(optVal) < sizeOfInt32 {
@@ -1931,7 +1828,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0))
+ ep.SocketOptions().SetKeepAlive(v != 0)
+ return nil
case linux.SO_SNDTIMEO:
if len(optVal) < linux.SizeOfTimeval {
@@ -1970,8 +1868,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
- opt := tcpip.OutOfBandInlineOption(v)
- return syserr.TranslateNetstackError(ep.SetSockOpt(&opt))
+ ep.SocketOptions().SetOutOfBandInline(v != 0)
+ return nil
case linux.SO_NO_CHECK:
if len(optVal) < sizeOfInt32 {
@@ -1979,7 +1877,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0))
+ ep.SocketOptions().SetNoChecksum(v != 0)
+ return nil
case linux.SO_LINGER:
if len(optVal) < linux.SizeOfLinger {
@@ -1993,10 +1892,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
- return syserr.TranslateNetstackError(
- ep.SetSockOpt(&tcpip.LingerOption{
- Enabled: v.OnOff != 0,
- Timeout: time.Second * time.Duration(v.Linger)}))
+ ep.SocketOptions().SetLinger(tcpip.LingerOption{
+ Enabled: v.OnOff != 0,
+ Timeout: time.Second * time.Duration(v.Linger),
+ })
+ return nil
case linux.SO_DETACH_FILTER:
// optval is ignored.
@@ -2011,7 +1911,12 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
// setSockOptTCP implements SetSockOpt when level is SOL_TCP.
-func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) {
+ log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto)
+ return syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.TCP_NODELAY:
if len(optVal) < sizeOfInt32 {
@@ -2019,7 +1924,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0))
+ ep.SocketOptions().SetDelayOption(v == 0)
+ return nil
case linux.TCP_CORK:
if len(optVal) < sizeOfInt32 {
@@ -2027,7 +1933,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0))
+ ep.SocketOptions().SetCorkOption(v != 0)
+ return nil
case linux.TCP_QUICKACK:
if len(optVal) < sizeOfInt32 {
@@ -2035,7 +1942,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0))
+ ep.SocketOptions().SetQuickAck(v != 0)
+ return nil
case linux.TCP_MAXSEG:
if len(optVal) < sizeOfInt32 {
@@ -2147,14 +2055,31 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
// setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6.
func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return syserr.ErrUnknownProtocolOption
+ }
+
+ family, skType, skProto := s.Type()
+ if family != linux.AF_INET6 {
+ return syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IPV6_V6ONLY:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
}
+ if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial {
+ return syserr.ErrInvalidEndpointState
+ } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial {
+ return syserr.ErrInvalidEndpointState
+ }
+
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0))
+ ep.SocketOptions().SetV6Only(v != 0)
+ return nil
case linux.IPV6_ADD_MEMBERSHIP,
linux.IPV6_DROP_MEMBERSHIP,
@@ -2174,6 +2099,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
t.Kernel().EmitUnimplementedEvent(t)
+ case linux.IPV6_RECVORIGDSTADDR:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+
+ ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
+ return nil
+
case linux.IPV6_TCLASS:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -2193,7 +2127,8 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0))
+ ep.SocketOptions().SetReceiveTClass(v != 0)
+ return nil
case linux.IP6T_SO_SET_REPLACE:
if len(optVal) < linux.SizeOfIP6TReplace {
@@ -2201,7 +2136,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return syserr.ErrProtocolNotAvailable
}
@@ -2276,6 +2211,11 @@ func parseIntOrChar(buf []byte) (int32, *syserr.Error) {
// setSockOptIP implements SetSockOpt when level is SOL_IP.
func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IP_MULTICAST_TTL:
v, err := parseIntOrChar(optVal)
@@ -2328,7 +2268,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.MulticastInterfaceOption{
NIC: tcpip.NICID(req.InterfaceIndex),
- InterfaceAddr: bytesToIPAddress(req.InterfaceAddr[:]),
+ InterfaceAddr: socket.BytesToIPAddress(req.InterfaceAddr[:]),
}))
case linux.IP_MULTICAST_LOOP:
@@ -2337,7 +2277,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0))
+ ep.SocketOptions().SetMulticastLoop(v != 0)
+ return nil
case linux.MCAST_JOIN_GROUP:
// FIXME(b/124219304): Implement MCAST_JOIN_GROUP.
@@ -2373,7 +2314,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0))
+ ep.SocketOptions().SetReceiveTOS(v != 0)
+ return nil
case linux.IP_PKTINFO:
if len(optVal) == 0 {
@@ -2383,7 +2325,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0))
+ ep.SocketOptions().SetReceivePacketInfo(v != 0)
+ return nil
case linux.IP_HDRINCL:
if len(optVal) == 0 {
@@ -2393,7 +2336,20 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0))
+ ep.SocketOptions().SetHeaderIncluded(v != 0)
+ return nil
+
+ case linux.IP_RECVORIGDSTADDR:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
+ return nil
case linux.IPT_SO_SET_REPLACE:
if len(optVal) < linux.SizeOfIPTReplace {
@@ -2433,7 +2389,6 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
linux.IP_RECVERR,
linux.IP_RECVFRAGSIZE,
linux.IP_RECVOPTS,
- linux.IP_RECVORIGDSTADDR,
linux.IP_RECVTTL,
linux.IP_RETOPTS,
linux.IP_TRANSPARENT,
@@ -2511,7 +2466,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_RECVFRAGSIZE,
linux.IPV6_RECVHOPLIMIT,
linux.IPV6_RECVHOPOPTS,
- linux.IPV6_RECVORIGDSTADDR,
linux.IPV6_RECVPATHMTU,
linux.IPV6_RECVPKTINFO,
linux.IPV6_RECVRTHDR,
@@ -2535,7 +2489,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
switch name {
case linux.IP_TOS,
linux.IP_TTL,
- linux.IP_HDRINCL,
linux.IP_OPTIONS,
linux.IP_ROUTER_ALERT,
linux.IP_RECVOPTS,
@@ -2582,72 +2535,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
}
}
-// isLinkLocal determines if the given IPv6 address is link-local. This is the
-// case when it has the fe80::/10 prefix. This check is used to determine when
-// the NICID is relevant for a given IPv6 address.
-func isLinkLocal(addr tcpip.Address) bool {
- return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80
-}
-
-// ConvertAddress converts the given address to a native format.
-func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) {
- switch family {
- case linux.AF_UNIX:
- var out linux.SockAddrUnix
- out.Family = linux.AF_UNIX
- l := len([]byte(addr.Addr))
- for i := 0; i < l; i++ {
- out.Path[i] = int8(addr.Addr[i])
- }
-
- // Linux returns the used length of the address struct (including the
- // null terminator) for filesystem paths. The Family field is 2 bytes.
- // It is sometimes allowed to exclude the null terminator if the
- // address length is the max. Abstract and empty paths always return
- // the full exact length.
- if l == 0 || out.Path[0] == 0 || l == len(out.Path) {
- return &out, uint32(2 + l)
- }
- return &out, uint32(3 + l)
-
- case linux.AF_INET:
- var out linux.SockAddrInet
- copy(out.Addr[:], addr.Addr)
- out.Family = linux.AF_INET
- out.Port = htons(addr.Port)
- return &out, uint32(sockAddrInetSize)
-
- case linux.AF_INET6:
- var out linux.SockAddrInet6
- if len(addr.Addr) == header.IPv4AddressSize {
- // Copy address in v4-mapped format.
- copy(out.Addr[12:], addr.Addr)
- out.Addr[10] = 0xff
- out.Addr[11] = 0xff
- } else {
- copy(out.Addr[:], addr.Addr)
- }
- out.Family = linux.AF_INET6
- out.Port = htons(addr.Port)
- if isLinkLocal(addr.Addr) {
- out.Scope_id = uint32(addr.NIC)
- }
- return &out, uint32(sockAddrInet6Size)
-
- case linux.AF_PACKET:
- // TODO(gvisor.dev/issue/173): Return protocol too.
- var out linux.SockAddrLink
- out.Family = linux.AF_PACKET
- out.InterfaceIndex = int32(addr.NIC)
- out.HardwareAddrLen = header.EthernetAddressSize
- copy(out.HardwareAddr[:], addr.Addr)
- return &out, uint32(sockAddrLinkSize)
-
- default:
- return nil, 0
- }
-}
-
// GetSockName implements the linux syscall getsockname(2) for sockets backed by
// tcpip.Endpoint.
func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
@@ -2656,7 +2543,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := ConvertAddress(s.family, addr)
+ a, l := socket.ConvertAddress(s.family, addr)
return a, l, nil
}
@@ -2668,7 +2555,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := ConvertAddress(s.family, addr)
+ a, l := socket.ConvertAddress(s.family, addr)
return a, l, nil
}
@@ -2686,7 +2573,7 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ
// Always do at least one fetchReadView, even if the number of bytes to
// read is 0.
err = s.fetchReadView()
- if err != nil {
+ if err != nil || len(s.readView) == 0 {
break
}
if dst.NumBytes() == 0 {
@@ -2709,15 +2596,20 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ
}
copied += n
s.readView.TrimFront(n)
- if len(s.readView) == 0 {
- atomic.StoreUint32(&s.readViewHasData, 0)
- }
dst = dst.DropFirst(n)
if e != nil {
err = syserr.FromError(e)
break
}
+ // If we are done reading requested data then stop.
+ if dst.NumBytes() == 0 {
+ break
+ }
+ }
+
+ if len(s.readView) == 0 {
+ atomic.StoreUint32(&s.readViewHasData, 0)
}
// If we managed to copy something, we must deliver it.
@@ -2812,10 +2704,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
var addr linux.SockAddr
var addrLen uint32
if isPacket && senderRequested {
- addr, addrLen = ConvertAddress(s.family, s.sender)
+ addr, addrLen = socket.ConvertAddress(s.family, s.sender)
switch v := addr.(type) {
case *linux.SockAddrLink:
- v.Protocol = htons(uint16(s.linkPacketInfo.Protocol))
+ v.Protocol = socket.Htons(uint16(s.linkPacketInfo.Protocol))
v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType)
}
}
@@ -2833,7 +2725,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
// We need to peek beyond the first message.
dst = dst.DropFirst(n)
num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) {
- n, _, err := s.Endpoint.Peek(dsts)
+ n, err := s.Endpoint.Peek(dsts)
// TODO(b/78348848): Handle peek timestamp.
if err != nil {
return int64(n), syserr.TranslateNetstackError(err).ToError()
@@ -2877,15 +2769,16 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
return socket.ControlMessages{
- IP: tcpip.ControlMessages{
- HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
- Timestamp: s.readCM.Timestamp,
- HasTOS: s.readCM.HasTOS,
- TOS: s.readCM.TOS,
- HasTClass: s.readCM.HasTClass,
- TClass: s.readCM.TClass,
- HasIPPacketInfo: s.readCM.HasIPPacketInfo,
- PacketInfo: s.readCM.PacketInfo,
+ IP: socket.IPControlMessages{
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ HasTClass: s.readCM.HasTClass,
+ TClass: s.readCM.TClass,
+ HasIPPacketInfo: s.readCM.HasIPPacketInfo,
+ PacketInfo: s.readCM.PacketInfo,
+ OriginalDstAddress: s.readCM.OriginalDstAddress,
},
}
}
@@ -2980,7 +2873,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
var addr *tcpip.FullAddress
if len(to) > 0 {
- addrBuf, family, err := AddressAndFamily(to)
+ addrBuf, family, err := socket.AddressAndFamily(to)
if err != nil {
return 0, err
}
@@ -3399,6 +3292,18 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 {
return rv
}
+func isTCPSocket(skType linux.SockType, skProto int) bool {
+ return skType == linux.SOCK_STREAM && (skProto == 0 || skProto == syscall.IPPROTO_TCP)
+}
+
+func isUDPSocket(skType linux.SockType, skProto int) bool {
+ return skType == linux.SOCK_DGRAM && (skProto == 0 || skProto == syscall.IPPROTO_UDP)
+}
+
+func isICMPSocket(skType linux.SockType, skProto int) bool {
+ return skType == linux.SOCK_DGRAM && (skProto == syscall.IPPROTO_ICMP || skProto == syscall.IPPROTO_ICMPV6)
+}
+
// State implements socket.Socket.State. State translates the internal state
// returned by netstack to values defined by Linux.
func (s *socketOpsCommon) State() uint32 {
@@ -3408,7 +3313,7 @@ func (s *socketOpsCommon) State() uint32 {
}
switch {
- case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP:
+ case isTCPSocket(s.skType, s.protocol):
// TCP socket.
switch tcp.EndpointState(s.Endpoint.State()) {
case tcp.StateEstablished:
@@ -3437,7 +3342,7 @@ func (s *socketOpsCommon) State() uint32 {
// Internal or unknown state.
return 0
}
- case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP:
+ case isUDPSocket(s.skType, s.protocol):
// UDP socket.
switch udp.EndpointState(s.Endpoint.State()) {
case udp.StateInitial, udp.StateBound, udp.StateClosed:
@@ -3447,7 +3352,7 @@ func (s *socketOpsCommon) State() uint32 {
default:
return 0
}
- case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6:
+ case isICMPSocket(s.skType, s.protocol):
// TODO(b/112063468): Export states for ICMP sockets.
case s.skType == linux.SOCK_RAW:
// TODO(b/112063468): Export states for raw sockets.
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index b0d9e4d9e..b756bfca0 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -51,9 +51,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{})
// NewVFS2 creates a new endpoint socket.
func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*vfs.FileDescription, *syserr.Error) {
if skType == linux.SOCK_STREAM {
- if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ endpoint.SocketOptions().SetDelayOption(true)
}
mnt := t.Kernel().SocketMount()
@@ -191,7 +189,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
var addrLen uint32
if peerAddr != nil {
// Get address of the peer and write it to peer slice.
- addr, addrLen = ConvertAddress(s.family, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(s.family, *peerAddr)
}
fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index ead3b2b79..c847ff1c7 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -158,7 +158,7 @@ func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol
// protocol is passed in network byte order, but netstack wants it in
// host order.
- netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol)))
+ netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol)))
wq := &waiter.Queue{}
ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq)
diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go
index 2a01143f6..0af805246 100644
--- a/pkg/sentry/socket/netstack/provider_vfs2.go
+++ b/pkg/sentry/socket/netstack/provider_vfs2.go
@@ -102,7 +102,7 @@ func packetSocketVFS2(t *kernel.Task, epStack *Stack, stype linux.SockType, prot
// protocol is passed in network byte order, but netstack wants it in
// host order.
- netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol)))
+ netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol)))
wq := &waiter.Queue{}
ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq)
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index fa9ac9059..cc0fadeb5 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -324,12 +324,12 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
0, // Support Ip/FragCreates.
}
case *inet.StatSNMPICMP:
- in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats
- out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats
+ in := Metrics.ICMP.V4.PacketsReceived.ICMPv4PacketStats
+ out := Metrics.ICMP.V4.PacketsSent.ICMPv4PacketStats
// TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPICMP{
0, // Icmp/InMsgs.
- Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors.
+ Metrics.ICMP.V4.PacketsSent.Dropped.Value(), // InErrors.
0, // Icmp/InCsumErrors.
in.DstUnreachable.Value(), // InDestUnreachs.
in.TimeExceeded.Value(), // InTimeExcds.
@@ -343,18 +343,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
in.InfoRequest.Value(), // InAddrMasks.
in.InfoReply.Value(), // InAddrMaskReps.
0, // Icmp/OutMsgs.
- Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors.
- out.DstUnreachable.Value(), // OutDestUnreachs.
- out.TimeExceeded.Value(), // OutTimeExcds.
- out.ParamProblem.Value(), // OutParmProbs.
- out.SrcQuench.Value(), // OutSrcQuenchs.
- out.Redirect.Value(), // OutRedirects.
- out.Echo.Value(), // OutEchos.
- out.EchoReply.Value(), // OutEchoReps.
- out.Timestamp.Value(), // OutTimestamps.
- out.TimestampReply.Value(), // OutTimestampReps.
- out.InfoRequest.Value(), // OutAddrMasks.
- out.InfoReply.Value(), // OutAddrMaskReps.
+ Metrics.ICMP.V4.PacketsReceived.Invalid.Value(), // OutErrors.
+ out.DstUnreachable.Value(), // OutDestUnreachs.
+ out.TimeExceeded.Value(), // OutTimeExcds.
+ out.ParamProblem.Value(), // OutParmProbs.
+ out.SrcQuench.Value(), // OutSrcQuenchs.
+ out.Redirect.Value(), // OutRedirects.
+ out.Echo.Value(), // OutEchos.
+ out.EchoReply.Value(), // OutEchoReps.
+ out.Timestamp.Value(), // OutTimestamps.
+ out.TimestampReply.Value(), // OutTimestampReps.
+ out.InfoRequest.Value(), // OutAddrMasks.
+ out.InfoReply.Value(), // OutAddrMaskReps.
}
case *inet.StatSNMPTCP:
tcp := Metrics.TCP
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index fd31479e5..bcc426e33 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -18,6 +18,7 @@
package socket
import (
+ "bytes"
"fmt"
"sync/atomic"
"syscall"
@@ -35,6 +36,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -42,7 +44,79 @@ import (
// control messages.
type ControlMessages struct {
Unix transport.ControlMessages
- IP tcpip.ControlMessages
+ IP IPControlMessages
+}
+
+// packetInfoToLinux converts IPPacketInfo from tcpip format to Linux format.
+func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPacketInfo {
+ var p linux.ControlMessageIPPacketInfo
+ p.NIC = int32(packetInfo.NIC)
+ copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
+ copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
+ return p
+}
+
+// NewIPControlMessages converts the tcpip ControlMessgaes (which does not
+// have Linux specific format) to Linux format.
+func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessages {
+ var orgDstAddr linux.SockAddr
+ if cmgs.HasOriginalDstAddress {
+ orgDstAddr, _ = ConvertAddress(family, cmgs.OriginalDstAddress)
+ }
+ return IPControlMessages{
+ HasTimestamp: cmgs.HasTimestamp,
+ Timestamp: cmgs.Timestamp,
+ HasInq: cmgs.HasInq,
+ Inq: cmgs.Inq,
+ HasTOS: cmgs.HasTOS,
+ TOS: cmgs.TOS,
+ HasTClass: cmgs.HasTClass,
+ TClass: cmgs.TClass,
+ HasIPPacketInfo: cmgs.HasIPPacketInfo,
+ PacketInfo: packetInfoToLinux(cmgs.PacketInfo),
+ OriginalDstAddress: orgDstAddr,
+ }
+}
+
+// IPControlMessages contains socket control messages for IP sockets.
+// This can contain Linux specific structures unlike tcpip.ControlMessages.
+//
+// +stateify savable
+type IPControlMessages struct {
+ // HasTimestamp indicates whether Timestamp is valid/set.
+ HasTimestamp bool
+
+ // Timestamp is the time (in ns) that the last packet used to create
+ // the read data was received.
+ Timestamp int64
+
+ // HasInq indicates whether Inq is valid/set.
+ HasInq bool
+
+ // Inq is the number of bytes ready to be received.
+ Inq int32
+
+ // HasTOS indicates whether Tos is valid/set.
+ HasTOS bool
+
+ // TOS is the IPv4 type of service of the associated packet.
+ TOS uint8
+
+ // HasTClass indicates whether TClass is valid/set.
+ HasTClass bool
+
+ // TClass is the IPv6 traffic class of the associated packet.
+ TClass uint32
+
+ // HasIPPacketInfo indicates whether PacketInfo is set.
+ HasIPPacketInfo bool
+
+ // PacketInfo holds interface and address data on an incoming packet.
+ PacketInfo linux.ControlMessageIPPacketInfo
+
+ // OriginalDestinationAddress holds the original destination address
+ // and port of the incoming packet.
+ OriginalDstAddress linux.SockAddr
}
// Release releases Unix domain socket credentials and rights.
@@ -460,3 +534,176 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr {
panic(fmt.Sprintf("Unsupported socket family %v", family))
}
}
+
+var sockAddrLinkSize = (&linux.SockAddrLink{}).SizeBytes()
+var sockAddrInetSize = (&linux.SockAddrInet{}).SizeBytes()
+var sockAddrInet6Size = (&linux.SockAddrInet6{}).SizeBytes()
+
+// Ntohs converts a 16-bit number from network byte order to host byte order. It
+// assumes that the host is little endian.
+func Ntohs(v uint16) uint16 {
+ return v<<8 | v>>8
+}
+
+// Htons converts a 16-bit number from host byte order to network byte order. It
+// assumes that the host is little endian.
+func Htons(v uint16) uint16 {
+ return Ntohs(v)
+}
+
+// isLinkLocal determines if the given IPv6 address is link-local. This is the
+// case when it has the fe80::/10 prefix. This check is used to determine when
+// the NICID is relevant for a given IPv6 address.
+func isLinkLocal(addr tcpip.Address) bool {
+ return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80
+}
+
+// ConvertAddress converts the given address to a native format.
+func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) {
+ switch family {
+ case linux.AF_UNIX:
+ var out linux.SockAddrUnix
+ out.Family = linux.AF_UNIX
+ l := len([]byte(addr.Addr))
+ for i := 0; i < l; i++ {
+ out.Path[i] = int8(addr.Addr[i])
+ }
+
+ // Linux returns the used length of the address struct (including the
+ // null terminator) for filesystem paths. The Family field is 2 bytes.
+ // It is sometimes allowed to exclude the null terminator if the
+ // address length is the max. Abstract and empty paths always return
+ // the full exact length.
+ if l == 0 || out.Path[0] == 0 || l == len(out.Path) {
+ return &out, uint32(2 + l)
+ }
+ return &out, uint32(3 + l)
+
+ case linux.AF_INET:
+ var out linux.SockAddrInet
+ copy(out.Addr[:], addr.Addr)
+ out.Family = linux.AF_INET
+ out.Port = Htons(addr.Port)
+ return &out, uint32(sockAddrInetSize)
+
+ case linux.AF_INET6:
+ var out linux.SockAddrInet6
+ if len(addr.Addr) == header.IPv4AddressSize {
+ // Copy address in v4-mapped format.
+ copy(out.Addr[12:], addr.Addr)
+ out.Addr[10] = 0xff
+ out.Addr[11] = 0xff
+ } else {
+ copy(out.Addr[:], addr.Addr)
+ }
+ out.Family = linux.AF_INET6
+ out.Port = Htons(addr.Port)
+ if isLinkLocal(addr.Addr) {
+ out.Scope_id = uint32(addr.NIC)
+ }
+ return &out, uint32(sockAddrInet6Size)
+
+ case linux.AF_PACKET:
+ // TODO(gvisor.dev/issue/173): Return protocol too.
+ var out linux.SockAddrLink
+ out.Family = linux.AF_PACKET
+ out.InterfaceIndex = int32(addr.NIC)
+ out.HardwareAddrLen = header.EthernetAddressSize
+ copy(out.HardwareAddr[:], addr.Addr)
+ return &out, uint32(sockAddrLinkSize)
+
+ default:
+ return nil, 0
+ }
+}
+
+// BytesToIPAddress converts an IPv4 or IPv6 address from the user to the
+// netstack representation taking any addresses into account.
+func BytesToIPAddress(addr []byte) tcpip.Address {
+ if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) {
+ return ""
+ }
+ return tcpip.Address(addr)
+}
+
+// AddressAndFamily reads an sockaddr struct from the given address and
+// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
+// AF_INET6, and AF_PACKET addresses.
+//
+// AddressAndFamily returns an address and its family.
+func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
+ // Make sure we have at least 2 bytes for the address family.
+ if len(addr) < 2 {
+ return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
+ }
+
+ // Get the rest of the fields based on the address family.
+ switch family := usermem.ByteOrder.Uint16(addr); family {
+ case linux.AF_UNIX:
+ path := addr[2:]
+ if len(path) > linux.UnixPathMax {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ // Drop the terminating NUL (if one exists) and everything after
+ // it for filesystem (non-abstract) addresses.
+ if len(path) > 0 && path[0] != 0 {
+ if n := bytes.IndexByte(path[1:], 0); n >= 0 {
+ path = path[:n+1]
+ }
+ }
+ return tcpip.FullAddress{
+ Addr: tcpip.Address(path),
+ }, family, nil
+
+ case linux.AF_INET:
+ var a linux.SockAddrInet
+ if len(addr) < sockAddrInetSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
+
+ out := tcpip.FullAddress{
+ Addr: BytesToIPAddress(a.Addr[:]),
+ Port: Ntohs(a.Port),
+ }
+ return out, family, nil
+
+ case linux.AF_INET6:
+ var a linux.SockAddrInet6
+ if len(addr) < sockAddrInet6Size {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
+
+ out := tcpip.FullAddress{
+ Addr: BytesToIPAddress(a.Addr[:]),
+ Port: Ntohs(a.Port),
+ }
+ if isLinkLocal(out.Addr) {
+ out.NIC = tcpip.NICID(a.Scope_id)
+ }
+ return out, family, nil
+
+ case linux.AF_PACKET:
+ var a linux.SockAddrLink
+ if len(addr) < sockAddrLinkSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+ if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/173): Return protocol too.
+ return tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }, family, nil
+
+ case linux.AF_UNSPEC:
+ return tcpip.FullAddress{}, family, nil
+
+ default:
+ return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported
+ }
+}
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index 6d9e502bd..9f7aca305 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -118,28 +118,24 @@ var (
// NewConnectioned creates a new unbound connectionedEndpoint.
func NewConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) Endpoint {
- return &connectionedEndpoint{
+ return newConnectioned(ctx, stype, uid)
+}
+
+func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) *connectionedEndpoint {
+ ep := &connectionedEndpoint{
baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
id: uid.UniqueID(),
idGenerator: uid,
stype: stype,
}
+ ep.ops.InitHandler(ep)
+ return ep
}
// NewPair allocates a new pair of connected unix-domain connectionedEndpoints.
func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
- a := &connectionedEndpoint{
- baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
- id: uid.UniqueID(),
- idGenerator: uid,
- stype: stype,
- }
- b := &connectionedEndpoint{
- baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
- id: uid.UniqueID(),
- idGenerator: uid,
- stype: stype,
- }
+ a := newConnectioned(ctx, stype, uid)
+ b := newConnectioned(ctx, stype, uid)
q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit}
q1.InitRefs()
@@ -171,12 +167,14 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E
// NewExternal creates a new externally backed Endpoint. It behaves like a
// socketpair.
func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint {
- return &connectionedEndpoint{
+ ep := &connectionedEndpoint{
baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected},
id: uid.UniqueID(),
idGenerator: uid,
stype: stype,
}
+ ep.ops.InitHandler(ep)
+ return ep
}
// ID implements ConnectingEndpoint.ID.
@@ -298,6 +296,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn
idGenerator: e.idGenerator,
stype: e.stype,
}
+ ne.ops.InitHandler(ne)
readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit}
readQueue.InitRefs()
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index 1406971bc..0813ad87d 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -44,6 +44,7 @@ func NewConnectionless(ctx context.Context) Endpoint {
q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}
q.InitRefs()
ep.receiver = &queueReceiver{readQueue: &q}
+ ep.ops.InitHandler(ep)
return ep
}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 18a50e9f8..099a56281 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -16,8 +16,6 @@
package transport
import (
- "sync/atomic"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
@@ -180,10 +178,6 @@ type Endpoint interface {
// SetSockOpt sets a socket option.
SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error
- // SetSockOptBool sets a socket option for simple cases when a value has
- // the int type.
- SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
-
// SetSockOptInt sets a socket option for simple cases when a value has
// the int type.
SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
@@ -191,10 +185,6 @@ type Endpoint interface {
// GetSockOpt gets a socket option.
GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error
- // GetSockOptBool gets a socket option for simple cases when a return
- // value has the int type.
- GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
-
// GetSockOptInt gets a socket option for simple cases when a return
// value has the int type.
GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
@@ -203,10 +193,11 @@ type Endpoint interface {
// procfs.
State() uint32
- // LastError implements tcpip.Endpoint.LastError.
+ // LastError clears and returns the last error reported by the endpoint.
LastError() *tcpip.Error
- // SocketOptions implements tcpip.Endpoint.SocketOptions.
+ // SocketOptions returns the structure which contains all the socket
+ // level options.
SocketOptions() *tcpip.SocketOptions
}
@@ -739,10 +730,7 @@ func (e *connectedEndpoint) CloseUnread() {
// +stateify savable
type baseEndpoint struct {
*waiter.Queue
-
- // passcred specifies whether SCM_CREDENTIALS socket control messages are
- // enabled on this endpoint. Must be accessed atomically.
- passcred int32
+ tcpip.DefaultSocketOptionsHandler
// Mutex protects the below fields.
sync.Mutex `state:"nosave"`
@@ -758,9 +746,7 @@ type baseEndpoint struct {
// or may be used if the endpoint is connected.
path string
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
-
+ // ops is used to get socket level options.
ops tcpip.SocketOptions
}
@@ -786,7 +772,7 @@ func (e *baseEndpoint) EventUnregister(we *waiter.Entry) {
// Passcred implements Credentialer.Passcred.
func (e *baseEndpoint) Passcred() bool {
- return atomic.LoadInt32(&e.passcred) != 0
+ return e.SocketOptions().GetPassCred()
}
// ConnectedPasscred implements Credentialer.ConnectedPasscred.
@@ -796,14 +782,6 @@ func (e *baseEndpoint) ConnectedPasscred() bool {
return e.connected != nil && e.connected.Passcred()
}
-func (e *baseEndpoint) setPasscred(pc bool) {
- if pc {
- atomic.StoreInt32(&e.passcred, 1)
- } else {
- atomic.StoreInt32(&e.passcred, 0)
- }
-}
-
// Connected implements ConnectingEndpoint.Connected.
func (e *baseEndpoint) Connected() bool {
return e.receiver != nil && e.connected != nil
@@ -859,23 +837,6 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess
// SetSockOpt sets a socket option.
func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
- case *tcpip.LingerOption:
- e.Lock()
- e.linger = *v
- e.Unlock()
- }
- return nil
-}
-
-func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- switch opt {
- case tcpip.PasscredOption:
- e.setPasscred(v)
- case tcpip.ReuseAddressOption:
- default:
- log.Warningf("Unsupported socket option: %d", opt)
- }
return nil
}
@@ -889,20 +850,6 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
-func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
- return false, nil
-
- case tcpip.PasscredOption:
- return e.Passcred(), nil
-
- default:
- log.Warningf("Unsupported socket option: %d", opt)
- return false, tcpip.ErrUnknownProtocolOption
- }
-}
-
func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -966,17 +913,8 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- e.Lock()
- *o = e.linger
- e.Unlock()
- return nil
-
- default:
- log.Warningf("Unsupported socket option: %T", opt)
- return tcpip.ErrUnknownProtocolOption
- }
+ log.Warningf("Unsupported socket option: %T", opt)
+ return tcpip.ErrUnknownProtocolOption
}
// LastError implements Endpoint.LastError.
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 3e520d2ee..c59297c80 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -115,9 +115,6 @@ type socketOpsCommon struct {
// bound, they cannot be modified.
abstractName string
abstractNamespace *kernel.AbstractSocketNamespace
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
}
func (s *socketOpsCommon) isPacket() bool {
@@ -139,7 +136,7 @@ func (s *socketOpsCommon) Endpoint() transport.Endpoint {
// extractPath extracts and validates the address.
func extractPath(sockaddr []byte) (string, *syserr.Error) {
- addr, family, err := netstack.AddressAndFamily(sockaddr)
+ addr, family, err := socket.AddressAndFamily(sockaddr)
if err != nil {
if err == syserr.ErrAddressFamilyNotSupported {
err = syserr.ErrInvalidArgument
@@ -172,7 +169,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
+ a, l := socket.ConvertAddress(linux.AF_UNIX, addr)
return a, l, nil
}
@@ -184,7 +181,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
+ a, l := socket.ConvertAddress(linux.AF_UNIX, addr)
return a, l, nil
}
@@ -258,7 +255,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
var addr linux.SockAddr
var addrLen uint32
if peerAddr != nil {
- addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr)
}
fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
@@ -650,7 +647,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
var from linux.SockAddr
var fromLen uint32
if r.From != nil && len([]byte(r.From.Addr)) != 0 {
- from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
+ from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From)
}
if r.ControlTrunc {
@@ -685,7 +682,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
var from linux.SockAddr
var fromLen uint32
if r.From != nil {
- from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
+ from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From)
}
if r.ControlTrunc {
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index eaf0b0d26..27f705bb2 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -172,7 +172,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
var addr linux.SockAddr
var addrLen uint32
if peerAddr != nil {
- addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr)
}
fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index a920180d3..d36a64ffc 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -32,8 +32,8 @@ go_library(
"//pkg/seccomp",
"//pkg/sentry/arch",
"//pkg/sentry/kernel",
+ "//pkg/sentry/socket",
"//pkg/sentry/socket/netlink",
- "//pkg/sentry/socket/netstack",
"//pkg/sentry/syscalls/linux",
"//pkg/usermem",
],
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index cc5f70cd4..d943a7cb1 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -23,8 +23,8 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
- "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -341,7 +341,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string {
switch family {
case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX:
- fa, _, err := netstack.AddressAndFamily(b)
+ fa, _, err := socket.AddressAndFamily(b)
if err != nil {
return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err)
}
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index bb1f715e2..b815e498f 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{
63: syscalls.Supported("uname", Uname),
64: syscalls.Supported("semget", Semget),
65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
- 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil),
+ 66: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_STAT_ANY not supported.", nil),
67: syscalls.Supported("shmdt", Shmdt),
68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
@@ -619,7 +619,7 @@ var ARM64 = &kernel.SyscallTable{
188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
190: syscalls.Supported("semget", Semget),
- 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil),
+ 191: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_STAT_ANY not supported.", nil),
192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}),
193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil),
diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go
index 0bf313a13..c2285f796 100644
--- a/pkg/sentry/syscalls/linux/sys_aio.go
+++ b/pkg/sentry/syscalls/linux/sys_aio.go
@@ -307,9 +307,8 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user
if !ok {
return syserror.EINVAL
}
- if ready := ctx.Prepare(); !ready {
- // Context is busy.
- return syserror.EAGAIN
+ if err := ctx.Prepare(); err != nil {
+ return err
}
if eventFile != nil {
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 519066a47..8db587401 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -646,7 +646,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil {
return 0, nil, err
}
- fSetOwn(t, file, set)
+ fSetOwn(t, int(fd), file, set)
return 0, nil, nil
case linux.FIOGETOWN, linux.SIOCGPGRP:
@@ -901,8 +901,8 @@ func fGetOwn(t *kernel.Task, file *fs.File) int32 {
//
// If who is positive, it represents a PID. If negative, it represents a PGID.
// If the PID or PGID is invalid, the owner is silently unset.
-func fSetOwn(t *kernel.Task, file *fs.File, who int32) error {
- a := file.Async(fasync.New).(*fasync.FileAsync)
+func fSetOwn(t *kernel.Task, fd int, file *fs.File, who int32) error {
+ a := file.Async(fasync.New(fd)).(*fasync.FileAsync)
if who < 0 {
// Check for overflow before flipping the sign.
if who-1 > who {
@@ -1049,7 +1049,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.F_GETOWN:
return uintptr(fGetOwn(t, file)), nil, nil
case linux.F_SETOWN:
- return 0, nil, fSetOwn(t, file, args[2].Int())
+ return 0, nil, fSetOwn(t, int(fd), file, args[2].Int())
case linux.F_GETOWN_EX:
addr := args[2].Pointer()
owner := fGetOwnEx(t, file)
@@ -1062,7 +1062,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, err
}
- a := file.Async(fasync.New).(*fasync.FileAsync)
+ a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync)
switch owner.Type {
case linux.F_OWNER_TID:
task := t.PIDNamespace().TaskWithID(kernel.ThreadID(owner.PID))
@@ -1111,6 +1111,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
n, err := sz.SetFifoSize(int64(args[2].Int()))
return uintptr(n), nil, err
+ case linux.F_GETSIG:
+ a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync)
+ return uintptr(a.Signal()), nil, nil
+ case linux.F_SETSIG:
+ a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync)
+ return 0, nil, a.SetSignal(linux.Signal(args[2].Int()))
default:
// Everything else is not yet supported.
return 0, nil, syserror.EINVAL
diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go
index e383a0a87..1166cd7bb 100644
--- a/pkg/sentry/syscalls/linux/sys_sem.go
+++ b/pkg/sentry/syscalls/linux/sys_sem.go
@@ -146,11 +146,37 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
v, err := getNCnt(t, id, num)
return uintptr(v), nil, err
- case linux.IPC_INFO,
- linux.SEM_INFO,
- linux.SEM_STAT,
- linux.SEM_STAT_ANY:
+ case linux.IPC_INFO:
+ buf := args[3].Pointer()
+ r := t.IPCNamespace().SemaphoreRegistry()
+ info := r.IPCInfo()
+ if _, err := info.CopyOut(t, buf); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(r.HighestIndex()), nil, nil
+
+ case linux.SEM_INFO:
+ buf := args[3].Pointer()
+ r := t.IPCNamespace().SemaphoreRegistry()
+ info := r.SemInfo()
+ if _, err := info.CopyOut(t, buf); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(r.HighestIndex()), nil, nil
+
+ case linux.SEM_STAT:
+ arg := args[3].Pointer()
+ // id is an index in SEM_STAT.
+ semid, ds, err := semStat(t, id)
+ if err != nil {
+ return 0, nil, err
+ }
+ if _, err := ds.CopyOut(t, arg); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(semid), nil, err
+ case linux.SEM_STAT_ANY:
t.Kernel().EmitUnimplementedEvent(t)
fallthrough
@@ -195,6 +221,17 @@ func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) {
return set.GetStat(creds)
}
+func semStat(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByIndex(index)
+ if set == nil {
+ return 0, nil, syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ ds, err := set.GetStat(creds)
+ return set.ID, ds, err
+}
+
func setVal(t *kernel.Task, id int32, num int32, val int16) error {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go
index e748d33d8..d639c9bf7 100644
--- a/pkg/sentry/syscalls/linux/sys_signal.go
+++ b/pkg/sentry/syscalls/linux/sys_signal.go
@@ -88,8 +88,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- info.SetPid(int32(target.PIDNamespace().IDOfTask(t)))
- info.SetUid(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(target.PIDNamespace().IDOfTask(t)))
+ info.SetUID(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow()))
if err := target.SendGroupSignal(info); err != syserror.ESRCH {
return 0, nil, err
}
@@ -127,8 +127,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- info.SetPid(int32(tg.PIDNamespace().IDOfTask(t)))
- info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
+ info.SetPID(int32(tg.PIDNamespace().IDOfTask(t)))
+ info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
err := tg.SendSignal(info)
if err == syserror.ESRCH {
// ESRCH is ignored because it means the task
@@ -171,8 +171,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- info.SetPid(int32(tg.PIDNamespace().IDOfTask(t)))
- info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
+ info.SetPID(int32(tg.PIDNamespace().IDOfTask(t)))
+ info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
// See note above regarding ESRCH race above.
if err := tg.SendSignal(info); err != syserror.ESRCH {
lastErr = err
@@ -189,8 +189,8 @@ func tkillSigInfo(sender, receiver *kernel.Task, sig linux.Signal) *arch.SignalI
Signo: int32(sig),
Code: arch.SignalInfoTkill,
}
- info.SetPid(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup())))
- info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup())))
+ info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
return info
}
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index 983f8d396..8e7ac0ffe 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -413,8 +413,8 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
si := arch.SignalInfo{
Signo: int32(linux.SIGCHLD),
}
- si.SetPid(int32(wr.TID))
- si.SetUid(int32(wr.UID))
+ si.SetPID(int32(wr.TID))
+ si.SetUID(int32(wr.UID))
// TODO(b/73541790): convert kernel.ExitStatus to functions and make
// WaitResult.Status a linux.WaitStatus.
s := syscall.WaitStatus(wr.Status)
diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go
index 6d0a38330..1365a5a62 100644
--- a/pkg/sentry/syscalls/linux/vfs2/aio.go
+++ b/pkg/sentry/syscalls/linux/vfs2/aio.go
@@ -130,9 +130,8 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user
if !ok {
return syserror.EINVAL
}
- if ready := aioCtx.Prepare(); !ready {
- // Context is busy.
- return syserror.EAGAIN
+ if err := aioCtx.Prepare(); err != nil {
+ return err
}
if eventFD != nil {
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
index 36e89700e..7dd9ef857 100644
--- a/pkg/sentry/syscalls/linux/vfs2/fd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -165,7 +165,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
ownerType = linux.F_OWNER_PGRP
who = -who
}
- return 0, nil, setAsyncOwner(t, file, ownerType, who)
+ return 0, nil, setAsyncOwner(t, int(fd), file, ownerType, who)
case linux.F_GETOWN_EX:
owner, hasOwner := getAsyncOwner(t, file)
if !hasOwner {
@@ -179,7 +179,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, err
}
- return 0, nil, setAsyncOwner(t, file, owner.Type, owner.PID)
+ return 0, nil, setAsyncOwner(t, int(fd), file, owner.Type, owner.PID)
case linux.F_SETPIPE_SZ:
pipefile, ok := file.Impl().(*pipe.VFSPipeFD)
if !ok {
@@ -207,6 +207,16 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, err
case linux.F_SETLK, linux.F_SETLKW:
return 0, nil, posixLock(t, args, file, cmd)
+ case linux.F_GETSIG:
+ a := file.AsyncHandler()
+ if a == nil {
+ // Default behavior aka SIGIO.
+ return 0, nil, nil
+ }
+ return uintptr(a.(*fasync.FileAsync).Signal()), nil, nil
+ case linux.F_SETSIG:
+ a := file.SetAsyncHandler(fasync.NewVFS2(int(fd))).(*fasync.FileAsync)
+ return 0, nil, a.SetSignal(linux.Signal(args[2].Int()))
default:
// Everything else is not yet supported.
return 0, nil, syserror.EINVAL
@@ -241,7 +251,7 @@ func getAsyncOwner(t *kernel.Task, fd *vfs.FileDescription) (ownerEx linux.FOwne
}
}
-func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32) error {
+func setAsyncOwner(t *kernel.Task, fd int, file *vfs.FileDescription, ownerType, pid int32) error {
switch ownerType {
case linux.F_OWNER_TID, linux.F_OWNER_PID, linux.F_OWNER_PGRP:
// Acceptable type.
@@ -249,7 +259,7 @@ func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32
return syserror.EINVAL
}
- a := fd.SetAsyncHandler(fasync.NewVFS2).(*fasync.FileAsync)
+ a := file.SetAsyncHandler(fasync.NewVFS2(fd)).(*fasync.FileAsync)
if pid == 0 {
a.ClearOwner()
return nil
diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
index 2806c3f6f..20c264fef 100644
--- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go
+++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
@@ -100,7 +100,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
ownerType = linux.F_OWNER_PGRP
who = -who
}
- return 0, nil, setAsyncOwner(t, file, ownerType, who)
+ return 0, nil, setAsyncOwner(t, int(fd), file, ownerType, who)
}
ret, err := file.Ioctl(t, t.MemoryManager(), args)
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 440c9307c..a3868bf16 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -105,6 +105,7 @@ go_library(
"//pkg/sentry/arch",
"//pkg/sentry/fs",
"//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsmetric",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index a98aac52b..072655fe8 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -204,8 +204,8 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event lin
file.EventRegister(&epi.waiter, wmask)
// Check if the file is already ready.
- if file.Readiness(wmask)&wmask != 0 {
- epi.Callback(nil)
+ if m := file.Readiness(wmask) & wmask; m != 0 {
+ epi.Callback(nil, m)
}
// Add epi to file.epolls so that it is removed when the last
@@ -274,8 +274,8 @@ func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, event
file.EventRegister(&epi.waiter, wmask)
// Check if the file is already ready with the new mask.
- if file.Readiness(wmask)&wmask != 0 {
- epi.Callback(nil)
+ if m := file.Readiness(wmask) & wmask; m != 0 {
+ epi.Callback(nil, m)
}
return nil
@@ -311,7 +311,7 @@ func (ep *EpollInstance) DeleteInterest(file *FileDescription, num int32) error
}
// Callback implements waiter.EntryCallback.Callback.
-func (epi *epollInterest) Callback(*waiter.Entry) {
+func (epi *epollInterest) Callback(*waiter.Entry, waiter.EventMask) {
newReady := false
epi.epoll.mu.Lock()
if !epi.ready {
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 936f9fc71..5321ac80a 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -15,12 +15,14 @@
package vfs
import (
+ "io"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sync"
@@ -42,7 +44,7 @@ import (
type FileDescription struct {
FileDescriptionRefs
- // flagsMu protects statusFlags and asyncHandler below.
+ // flagsMu protects `statusFlags`, `saved`, and `asyncHandler` below.
flagsMu sync.Mutex `state:"nosave"`
// statusFlags contains status flags, "initialized by open(2) and possibly
@@ -51,6 +53,11 @@ type FileDescription struct {
// access to asyncHandler.
statusFlags uint32
+ // saved is true after beforeSave is called. This is used to prevent
+ // double-unregistration of asyncHandler. This does not work properly for
+ // save-resume, which is not currently supported in gVisor (see b/26588733).
+ saved bool `state:"nosave"`
+
// asyncHandler handles O_ASYNC signal generation. It is set with the
// F_SETOWN or F_SETOWN_EX fcntls. For asyncHandler to be used, O_ASYNC must
// also be set by fcntl(2).
@@ -183,7 +190,7 @@ func (fd *FileDescription) DecRef(ctx context.Context) {
}
fd.vd.DecRef(ctx)
fd.flagsMu.Lock()
- if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
+ if !fd.saved && fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
fd.asyncHandler.Unregister(fd)
}
fd.asyncHandler = nil
@@ -583,7 +590,11 @@ func (fd *FileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
if !fd.readable {
return 0, syserror.EBADF
}
- return fd.impl.PRead(ctx, dst, offset, opts)
+ start := fsmetric.StartReadWait()
+ n, err := fd.impl.PRead(ctx, dst, offset, opts)
+ fsmetric.Reads.Increment()
+ fsmetric.FinishReadWait(fsmetric.ReadWait, start)
+ return n, err
}
// Read is similar to PRead, but does not specify an offset.
@@ -591,7 +602,11 @@ func (fd *FileDescription) Read(ctx context.Context, dst usermem.IOSequence, opt
if !fd.readable {
return 0, syserror.EBADF
}
- return fd.impl.Read(ctx, dst, opts)
+ start := fsmetric.StartReadWait()
+ n, err := fd.impl.Read(ctx, dst, opts)
+ fsmetric.Reads.Increment()
+ fsmetric.FinishReadWait(fsmetric.ReadWait, start)
+ return n, err
}
// PWrite writes src to the file represented by fd, starting at the given
@@ -825,44 +840,27 @@ func (fd *FileDescription) SetAsyncHandler(newHandler func() FileAsync) FileAsyn
return fd.asyncHandler
}
-// FileReadWriteSeeker is a helper struct to pass a FileDescription as
-// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc.
-type FileReadWriteSeeker struct {
- FD *FileDescription
- Ctx context.Context
- ROpts ReadOptions
- WOpts WriteOptions
-}
-
-// ReadAt implements io.ReaderAt.ReadAt.
-func (f *FileReadWriteSeeker) ReadAt(p []byte, off int64) (int, error) {
- dst := usermem.BytesIOSequence(p)
- n, err := f.FD.PRead(f.Ctx, dst, off, f.ROpts)
- return int(n), err
-}
-
-// Read implements io.ReadWriteSeeker.Read.
-func (f *FileReadWriteSeeker) Read(p []byte) (int, error) {
- dst := usermem.BytesIOSequence(p)
- n, err := f.FD.Read(f.Ctx, dst, f.ROpts)
- return int(n), err
-}
-
-// Seek implements io.ReadWriteSeeker.Seek.
-func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) {
- return f.FD.Seek(f.Ctx, offset, int32(whence))
-}
-
-// WriteAt implements io.WriterAt.WriteAt.
-func (f *FileReadWriteSeeker) WriteAt(p []byte, off int64) (int, error) {
- dst := usermem.BytesIOSequence(p)
- n, err := f.FD.PWrite(f.Ctx, dst, off, f.WOpts)
- return int(n), err
-}
-
-// Write implements io.ReadWriteSeeker.Write.
-func (f *FileReadWriteSeeker) Write(p []byte) (int, error) {
- buf := usermem.BytesIOSequence(p)
- n, err := f.FD.Write(f.Ctx, buf, f.WOpts)
- return int(n), err
+// CopyRegularFileData copies data from srcFD to dstFD until reading from srcFD
+// returns EOF or an error. It returns the number of bytes copied.
+func CopyRegularFileData(ctx context.Context, dstFD, srcFD *FileDescription) (int64, error) {
+ done := int64(0)
+ buf := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size
+ for {
+ readN, readErr := srcFD.Read(ctx, buf, ReadOptions{})
+ if readErr != nil && readErr != io.EOF {
+ return done, readErr
+ }
+ src := buf.TakeFirst64(readN)
+ for src.NumBytes() != 0 {
+ writeN, writeErr := dstFD.Write(ctx, src, WriteOptions{})
+ done += writeN
+ src = src.DropFirst64(writeN)
+ if writeErr != nil {
+ return done, writeErr
+ }
+ }
+ if readErr == io.EOF {
+ return done, nil
+ }
+ }
}
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index cb48c37a1..0df023713 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -12,11 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build go1.12
-// +build !go1.17
-
-// Check go:linkname function signatures when updating Go version.
-
package vfs
import (
@@ -41,6 +36,15 @@ type mountKey struct {
point unsafe.Pointer // *Dentry
}
+var (
+ mountKeyHasher = sync.MapKeyHasher(map[mountKey]struct{}(nil))
+ mountKeySeed = sync.RandUintptr()
+)
+
+func (k *mountKey) hash() uintptr {
+ return mountKeyHasher(gohacks.Noescape(unsafe.Pointer(k)), mountKeySeed)
+}
+
func (mnt *Mount) parent() *Mount {
return (*Mount)(atomic.LoadPointer(&mnt.key.parent))
}
@@ -56,23 +60,17 @@ func (mnt *Mount) getKey() VirtualDentry {
}
}
-func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() }
-
// Invariant: mnt.key.parent == nil. vd.Ok().
func (mnt *Mount) setKey(vd VirtualDentry) {
atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount))
atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry))
}
-func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) }
-
// mountTable maps (mount parent, mount point) pairs to mounts. It supports
// efficient concurrent lookup, even in the presence of concurrent mutators
// (provided mutation is sufficiently uncommon).
//
// mountTable.Init() must be called on new mountTables before use.
-//
-// +stateify savable
type mountTable struct {
// mountTable is implemented as a seqcount-protected hash table that
// resolves collisions with linear probing, featuring Robin Hood insertion
@@ -84,8 +82,7 @@ type mountTable struct {
// intrinsics and inline assembly, limiting the performance of this
// approach.)
- seq sync.SeqCount `state:"nosave"`
- seed uint32 // for hashing keys
+ seq sync.SeqCount `state:"nosave"`
// size holds both length (number of elements) and capacity (number of
// slots): capacity is stored as its base-2 log (referred to as order) in
@@ -150,7 +147,6 @@ func init() {
// Init must be called exactly once on each mountTable before use.
func (mt *mountTable) Init() {
- mt.seed = rand32()
mt.size = mtInitOrder
mt.slots = newMountTableSlots(mtInitCap)
}
@@ -167,7 +163,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer {
// Lookup may be called even if there are concurrent mutators of mt.
func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount {
key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)}
- hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes)
+ hash := key.hash()
loop:
for {
@@ -247,7 +243,7 @@ func (mt *mountTable) Insert(mount *Mount) {
// * mt.seq must be in a writer critical section.
// * mt must not already contain a Mount with the same mount point and parent.
func (mt *mountTable) insertSeqed(mount *Mount) {
- hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
+ hash := mount.key.hash()
// We're under the maximum load factor if:
//
@@ -346,7 +342,7 @@ func (mt *mountTable) Remove(mount *Mount) {
// * mt.seq must be in a writer critical section.
// * mt must contain mount.
func (mt *mountTable) removeSeqed(mount *Mount) {
- hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
+ hash := mount.key.hash()
tcap := uintptr(1) << (mt.size & mtSizeOrderMask)
mask := tcap - 1
slots := mt.slots
@@ -386,9 +382,3 @@ func (mt *mountTable) removeSeqed(mount *Mount) {
off = (off + mountSlotBytes) & offmask
}
}
-
-//go:linkname memhash runtime.memhash
-func memhash(p unsafe.Pointer, seed, s uintptr) uintptr
-
-//go:linkname rand32 runtime.fastrand
-func rand32() uint32
diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go
index 7723ed643..8998a82dd 100644
--- a/pkg/sentry/vfs/save_restore.go
+++ b/pkg/sentry/vfs/save_restore.go
@@ -18,8 +18,10 @@ import (
"fmt"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refsvfs2"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// FilesystemImplSaveRestoreExtension is an optional extension to
@@ -99,6 +101,9 @@ func (vfs *VirtualFilesystem) saveMounts() []*Mount {
return mounts
}
+// saveKey is called by stateify.
+func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() }
+
// loadMounts is called by stateify.
func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) {
if mounts == nil {
@@ -110,6 +115,9 @@ func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) {
}
}
+// loadKey is called by stateify.
+func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) }
+
func (mnt *Mount) afterLoad() {
if atomic.LoadInt64(&mnt.refs) != 0 {
refsvfs2.Register(mnt)
@@ -120,5 +128,20 @@ func (mnt *Mount) afterLoad() {
func (epi *epollInterest) afterLoad() {
// Mark all epollInterests as ready after restore so that the next call to
// EpollInstance.ReadEvents() rechecks their readiness.
- epi.Callback(nil)
+ epi.Callback(nil, waiter.EventMaskFromLinux(epi.mask))
+}
+
+// beforeSave is called by stateify.
+func (fd *FileDescription) beforeSave() {
+ fd.saved = true
+ if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
+ fd.asyncHandler.Unregister(fd)
+ }
+}
+
+// afterLoad is called by stateify.
+func (fd *FileDescription) afterLoad() {
+ if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
+ fd.asyncHandler.Register(fd)
+ }
}
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 48d6252f7..6fd1bb0b2 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -41,6 +41,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sync"
@@ -381,6 +382,8 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia
// OpenAt returns a FileDescription providing access to the file at the given
// path. A reference is taken on the returned FileDescription.
func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) {
+ fsmetric.Opens.Increment()
+
// Remove:
//
// - O_CLOEXEC, which affects file descriptors and therefore must be
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
index 1d1062aeb..8e3146d8d 100644
--- a/pkg/sentry/watchdog/watchdog.go
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -338,6 +338,7 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo
tid := w.k.TaskSet().Root.IDOfTask(t)
buf.WriteString(fmt.Sprintf("\tTask tid: %v (goroutine %d), entered RunSys state %v ago.\n", tid, t.GoroutineID(), now.Sub(o.lastUpdateTime)))
}
+ buf.WriteString("Search for 'goroutine <id>' in the stack dump to find the offending goroutine(s)")
// Force stack dump only if a new task is detected.
w.doAction(w.TaskTimeoutAction, newTaskFound, &buf)
diff --git a/pkg/shim/v1/proc/process.go b/pkg/shim/v1/proc/process.go
index d462c3eef..e8315326d 100644
--- a/pkg/shim/v1/proc/process.go
+++ b/pkg/shim/v1/proc/process.go
@@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package proc contains process-related utilities.
package proc
import (
diff --git a/pkg/shim/v1/shim/BUILD b/pkg/shim/v1/shim/BUILD
index 05c595bc9..e5b6bf186 100644
--- a/pkg/shim/v1/shim/BUILD
+++ b/pkg/shim/v1/shim/BUILD
@@ -8,6 +8,7 @@ go_library(
"api.go",
"platform.go",
"service.go",
+ "shim.go",
],
visibility = [
"//pkg/shim:__subpackages__",
diff --git a/pkg/sleep/commit_asm.go b/pkg/shim/v1/shim/shim.go
index 75728a97d..1855a8769 100644
--- a/pkg/sleep/commit_asm.go
+++ b/pkg/shim/v1/shim/shim.go
@@ -1,10 +1,11 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2018 The containerd Authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
-// http://www.apache.org/licenses/LICENSE-2.0
+// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,9 +13,5 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64 arm64
-
-package sleep
-
-// See commit_noasm.go for a description of commitSleep.
-func commitSleep(g uintptr, waitingG *uintptr) bool
+// Package shim contains the core containerd shim implementation.
+package shim
diff --git a/pkg/shim/v1/utils/utils.go b/pkg/shim/v1/utils/utils.go
index 07e346654..21e75d16d 100644
--- a/pkg/shim/v1/utils/utils.go
+++ b/pkg/shim/v1/utils/utils.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package utils contains utility functions.
package utils
import (
diff --git a/pkg/shim/v2/BUILD b/pkg/shim/v2/BUILD
index f37fefddc..b0e8daa51 100644
--- a/pkg/shim/v2/BUILD
+++ b/pkg/shim/v2/BUILD
@@ -22,6 +22,7 @@ go_library(
"//runsc/specutils",
"@com_github_burntsushi_toml//:go_default_library",
"@com_github_containerd_cgroups//:go_default_library",
+ "@com_github_containerd_cgroups//stats/v1:go_default_library",
"@com_github_containerd_console//:go_default_library",
"@com_github_containerd_containerd//api/events:go_default_library",
"@com_github_containerd_containerd//api/types/task:go_default_library",
diff --git a/pkg/shim/v2/service.go b/pkg/shim/v2/service.go
index 2e39d2c4a..6aaf5fab8 100644
--- a/pkg/shim/v2/service.go
+++ b/pkg/shim/v2/service.go
@@ -28,6 +28,7 @@ import (
"github.com/BurntSushi/toml"
"github.com/containerd/cgroups"
+ cgroupsstats "github.com/containerd/cgroups/stats/v1"
"github.com/containerd/console"
"github.com/containerd/containerd/api/events"
"github.com/containerd/containerd/api/types/task"
@@ -67,9 +68,15 @@ var (
var _ = (taskAPI.TaskService)(&service{})
-// configFile is the default config file name. For containerd 1.2,
-// we assume that a config.toml should exist in the runtime root.
-const configFile = "config.toml"
+const (
+ // configFile is the default config file name. For containerd 1.2,
+ // we assume that a config.toml should exist in the runtime root.
+ configFile = "config.toml"
+
+ // shimAddressPath is the relative path to a file that contains the address
+ // to the shim UDS. See service.shimAddress.
+ shimAddressPath = "address"
+)
// New returns a new shim service that can be used via GRPC.
func New(ctx context.Context, id string, publisher shim.Publisher, cancel func()) (shim.Shim, error) {
@@ -101,6 +108,11 @@ func New(ctx context.Context, id string, publisher shim.Publisher, cancel func()
return nil, fmt.Errorf("failed to initialized platform behavior: %w", err)
}
go s.forward(ctx, publisher)
+
+ if address, err := shim.ReadAddress(shimAddressPath); err == nil {
+ s.shimAddress = address
+ }
+
return s, nil
}
@@ -152,6 +164,9 @@ type service struct {
// cancel is a function that needs to be called before the shim stops. The
// function is provided by the caller to New().
cancel func()
+
+ // shimAddress is the location of the UDS used to communicate to containerd.
+ shimAddress string
}
func (s *service) newCommand(ctx context.Context, containerdBinary, containerdAddress string) (*exec.Cmd, error) {
@@ -191,38 +206,58 @@ func (s *service) StartShim(ctx context.Context, id, containerdBinary, container
if err != nil {
return "", err
}
- address, err := shim.SocketAddress(ctx, id)
+ address, err := shim.SocketAddress(ctx, containerdAddress, id)
if err != nil {
return "", err
}
socket, err := shim.NewSocket(address)
if err != nil {
- return "", err
+ // The only time where this would happen is if there is a bug and the socket
+ // was not cleaned up in the cleanup method of the shim or we are using the
+ // grouping functionality where the new process should be run with the same
+ // shim as an existing container.
+ if !shim.SocketEaddrinuse(err) {
+ return "", fmt.Errorf("create new shim socket: %w", err)
+ }
+ if shim.CanConnect(address) {
+ if err := shim.WriteAddress(shimAddressPath, address); err != nil {
+ return "", fmt.Errorf("write existing socket for shim: %w", err)
+ }
+ return address, nil
+ }
+ if err := shim.RemoveSocket(address); err != nil {
+ return "", fmt.Errorf("remove pre-existing socket: %w", err)
+ }
+ if socket, err = shim.NewSocket(address); err != nil {
+ return "", fmt.Errorf("try create new shim socket 2x: %w", err)
+ }
}
- defer socket.Close()
+ cu := cleanup.Make(func() {
+ socket.Close()
+ _ = shim.RemoveSocket(address)
+ })
+ defer cu.Clean()
+
f, err := socket.File()
if err != nil {
return "", err
}
- defer f.Close()
cmd.ExtraFiles = append(cmd.ExtraFiles, f)
log.L.Debugf("Executing: %q %s", cmd.Path, cmd.Args)
if err := cmd.Start(); err != nil {
+ f.Close()
return "", err
}
- cu := cleanup.Make(func() {
- cmd.Process.Kill()
- })
- defer cu.Clean()
+ cu.Add(func() { cmd.Process.Kill() })
// make sure to wait after start
go cmd.Wait()
if err := shim.WritePidFile("shim.pid", cmd.Process.Pid); err != nil {
return "", err
}
- if err := shim.WriteAddress("address", address); err != nil {
+ if err := shim.WriteAddress(shimAddressPath, address); err != nil {
return "", err
}
if err := shim.SetScore(cmd.Process.Pid); err != nil {
@@ -675,8 +710,11 @@ func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*task
func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*types.Empty, error) {
log.L.Debugf("Shutdown, id: %s", r.ID)
s.cancel()
+ if s.shimAddress != "" {
+ _ = shim.RemoveSocket(s.shimAddress)
+ }
os.Exit(0)
- return empty, nil
+ panic("Should not get here")
}
func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI.StatsResponse, error) {
@@ -698,48 +736,48 @@ func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI.
// as runc.
//
// [0]: https://github.com/google/gvisor/blob/277a0d5a1fbe8272d4729c01ee4c6e374d047ebc/runsc/boot/events.go#L61-L81
- metrics := &cgroups.Metrics{
- CPU: &cgroups.CPUStat{
- Usage: &cgroups.CPUUsage{
+ metrics := &cgroupsstats.Metrics{
+ CPU: &cgroupsstats.CPUStat{
+ Usage: &cgroupsstats.CPUUsage{
Total: stats.Cpu.Usage.Total,
Kernel: stats.Cpu.Usage.Kernel,
User: stats.Cpu.Usage.User,
PerCPU: stats.Cpu.Usage.Percpu,
},
- Throttling: &cgroups.Throttle{
+ Throttling: &cgroupsstats.Throttle{
Periods: stats.Cpu.Throttling.Periods,
ThrottledPeriods: stats.Cpu.Throttling.ThrottledPeriods,
ThrottledTime: stats.Cpu.Throttling.ThrottledTime,
},
},
- Memory: &cgroups.MemoryStat{
+ Memory: &cgroupsstats.MemoryStat{
Cache: stats.Memory.Cache,
- Usage: &cgroups.MemoryEntry{
+ Usage: &cgroupsstats.MemoryEntry{
Limit: stats.Memory.Usage.Limit,
Usage: stats.Memory.Usage.Usage,
Max: stats.Memory.Usage.Max,
Failcnt: stats.Memory.Usage.Failcnt,
},
- Swap: &cgroups.MemoryEntry{
+ Swap: &cgroupsstats.MemoryEntry{
Limit: stats.Memory.Swap.Limit,
Usage: stats.Memory.Swap.Usage,
Max: stats.Memory.Swap.Max,
Failcnt: stats.Memory.Swap.Failcnt,
},
- Kernel: &cgroups.MemoryEntry{
+ Kernel: &cgroupsstats.MemoryEntry{
Limit: stats.Memory.Kernel.Limit,
Usage: stats.Memory.Kernel.Usage,
Max: stats.Memory.Kernel.Max,
Failcnt: stats.Memory.Kernel.Failcnt,
},
- KernelTCP: &cgroups.MemoryEntry{
+ KernelTCP: &cgroupsstats.MemoryEntry{
Limit: stats.Memory.KernelTCP.Limit,
Usage: stats.Memory.KernelTCP.Usage,
Max: stats.Memory.KernelTCP.Max,
Failcnt: stats.Memory.KernelTCP.Failcnt,
},
},
- Pids: &cgroups.PidsStat{
+ Pids: &cgroupsstats.PidsStat{
Current: stats.Pids.Current,
Limit: stats.Pids.Limit,
},
@@ -843,9 +881,7 @@ func (s *service) getContainerPids(ctx context.Context, id string) ([]uint32, er
func (s *service) forward(ctx context.Context, publisher shim.Publisher) {
for e := range s.events {
- ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
err := publisher.Publish(ctx, getTopic(e), e)
- cancel()
if err != nil {
// Should not happen.
panic(fmt.Errorf("post event: %w", err))
diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD
index ae0fe1522..48bcdd62b 100644
--- a/pkg/sleep/BUILD
+++ b/pkg/sleep/BUILD
@@ -5,10 +5,6 @@ package(licenses = ["notice"])
go_library(
name = "sleep",
srcs = [
- "commit_amd64.s",
- "commit_arm64.s",
- "commit_asm.go",
- "commit_noasm.go",
"sleep_unsafe.go",
],
visibility = ["//:sandbox"],
diff --git a/pkg/sleep/commit_amd64.s b/pkg/sleep/commit_amd64.s
deleted file mode 100644
index bc4ac2c3c..000000000
--- a/pkg/sleep/commit_amd64.s
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "textflag.h"
-
-#define preparingG 1
-
-// See commit_noasm.go for a description of commitSleep.
-//
-// func commitSleep(g uintptr, waitingG *uintptr) bool
-TEXT ·commitSleep(SB),NOSPLIT,$0-24
- MOVQ waitingG+8(FP), CX
- MOVQ g+0(FP), DX
-
- // Store the G in waitingG if it's still preparingG. If it's anything
- // else it means a waker has aborted the sleep.
- MOVQ $preparingG, AX
- LOCK
- CMPXCHGQ DX, 0(CX)
-
- SETEQ AX
- MOVB AX, ret+16(FP)
-
- RET
diff --git a/pkg/sleep/commit_arm64.s b/pkg/sleep/commit_arm64.s
deleted file mode 100644
index d0ef15b20..000000000
--- a/pkg/sleep/commit_arm64.s
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "textflag.h"
-
-#define preparingG 1
-
-// See commit_noasm.go for a description of commitSleep.
-//
-// func commitSleep(g uintptr, waitingG *uintptr) bool
-TEXT ·commitSleep(SB),NOSPLIT,$0-24
- MOVD waitingG+8(FP), R0
- MOVD $preparingG, R1
- MOVD G+0(FP), R2
-
- // Store the G in waitingG if it's still preparingG. If it's anything
- // else it means a waker has aborted the sleep.
-again:
- LDAXR (R0), R3
- CMP R1, R3
- BNE ok
- STLXR R2, (R0), R3
- CBNZ R3, again
-ok:
- CSET EQ, R0
- MOVB R0, ret+16(FP)
- RET
diff --git a/pkg/sleep/commit_noasm.go b/pkg/sleep/commit_noasm.go
deleted file mode 100644
index f59061f37..000000000
--- a/pkg/sleep/commit_noasm.go
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build !race
-// +build !amd64,!arm64
-
-package sleep
-
-import "sync/atomic"
-
-// commitSleep signals to wakers that the given g is now sleeping. Wakers can
-// then fetch it and wake it.
-//
-// The commit may fail if wakers have been asserted after our last check, in
-// which case they will have set s.waitingG to zero.
-//
-// It is written in assembly because it is called from g0, so it doesn't have
-// a race context.
-func commitSleep(g uintptr, waitingG *uintptr) bool {
- // Try to store the G so that wakers know who to wake.
- return atomic.CompareAndSwapUintptr(waitingG, preparingG, g)
-}
diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go
index 19bce2afb..c44206b1e 100644
--- a/pkg/sleep/sleep_unsafe.go
+++ b/pkg/sleep/sleep_unsafe.go
@@ -12,11 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build go1.11
-// +build !go1.17
-
-// Check go:linkname function signatures when updating Go version.
-
// Package sleep allows goroutines to efficiently sleep on multiple sources of
// notifications (wakers). It offers O(1) complexity, which is different from
// multi-channel selects which have O(n) complexity (where n is the number of
@@ -91,12 +86,6 @@ var (
assertedSleeper Sleeper
)
-//go:linkname gopark runtime.gopark
-func gopark(unlockf func(uintptr, *uintptr) bool, wg *uintptr, reason uint8, traceEv byte, traceskip int)
-
-//go:linkname goready runtime.goready
-func goready(g uintptr, traceskip int)
-
// Sleeper allows a goroutine to sleep and receive wake up notifications from
// Wakers in an efficient way.
//
@@ -189,7 +178,7 @@ func (s *Sleeper) nextWaker(block bool) *Waker {
// See:runtime2.go in the go runtime package for
// the values to pass as the waitReason here.
const waitReasonSelect = 9
- gopark(commitSleep, &s.waitingG, waitReasonSelect, traceEvGoBlockSelect, 0)
+ sync.Gopark(commitSleep, unsafe.Pointer(&s.waitingG), sync.WaitReasonSelect, sync.TraceEvGoBlockSelect, 0)
}
// Pull the shared list out and reverse it in the local
@@ -212,6 +201,18 @@ func (s *Sleeper) nextWaker(block bool) *Waker {
return w
}
+// commitSleep signals to wakers that the given g is now sleeping. Wakers can
+// then fetch it and wake it.
+//
+// The commit may fail if wakers have been asserted after our last check, in
+// which case they will have set s.waitingG to zero.
+//
+//go:norace
+//go:nosplit
+func commitSleep(g uintptr, waitingG unsafe.Pointer) bool {
+ return sync.RaceUncheckedAtomicCompareAndSwapUintptr((*uintptr)(waitingG), preparingG, g)
+}
+
// Fetch fetches the next wake-up notification. If a notification is immediately
// available, it is returned right away. Otherwise, the behavior depends on the
// value of 'block': if true, the current goroutine blocks until a notification
@@ -311,7 +312,7 @@ func (s *Sleeper) enqueueAssertedWaker(w *Waker) {
case 0, preparingG:
default:
// We managed to get a G. Wake it up.
- goready(g, 0)
+ sync.Goready(g, 0)
}
}
diff --git a/pkg/state/tests/integer_test.go b/pkg/state/tests/integer_test.go
index d3931c952..2b1609af0 100644
--- a/pkg/state/tests/integer_test.go
+++ b/pkg/state/tests/integer_test.go
@@ -20,21 +20,21 @@ import (
)
var (
- allIntTs = []int{-1, 0, 1}
- allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8}
- allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16}
- allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32}
- allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64}
- allUintTs = []uint{0, 1}
- allUintptrs = []uintptr{0, 1, ^uintptr(0)}
- allUint8s = []uint8{0, 1, math.MaxUint8}
- allUint16s = []uint16{0, 1, math.MaxUint16}
- allUint32s = []uint32{0, 1, math.MaxUint32}
- allUint64s = []uint64{0, 1, math.MaxUint64}
+ allBasicInts = []int{-1, 0, 1}
+ allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8}
+ allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16}
+ allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32}
+ allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64}
+ allBasicUints = []uint{0, 1}
+ allUintptrs = []uintptr{0, 1, ^uintptr(0)}
+ allUint8s = []uint8{0, 1, math.MaxUint8}
+ allUint16s = []uint16{0, 1, math.MaxUint16}
+ allUint32s = []uint32{0, 1, math.MaxUint32}
+ allUint64s = []uint64{0, 1, math.MaxUint64}
)
var allInts = flatten(
- allIntTs,
+ allBasicInts,
allInt8s,
allInt16s,
allInt32s,
@@ -42,7 +42,7 @@ var allInts = flatten(
)
var allUints = flatten(
- allUintTs,
+ allBasicUints,
allUintptrs,
allUint8s,
allUint16s,
diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD
index 68535c3b1..28e62abbb 100644
--- a/pkg/sync/BUILD
+++ b/pkg/sync/BUILD
@@ -10,15 +10,34 @@ exports_files(["LICENSE"])
go_template(
name = "generic_atomicptr",
- srcs = ["atomicptr_unsafe.go"],
+ srcs = ["generic_atomicptr_unsafe.go"],
types = [
"Value",
],
)
go_template(
+ name = "generic_atomicptrmap",
+ srcs = ["generic_atomicptrmap_unsafe.go"],
+ opt_consts = [
+ "ShardOrder",
+ ],
+ opt_types = [
+ "Hasher",
+ ],
+ types = [
+ "Key",
+ "Value",
+ ],
+ deps = [
+ ":sync",
+ "//pkg/gohacks",
+ ],
+)
+
+go_template(
name = "generic_seqatomic",
- srcs = ["seqatomic_unsafe.go"],
+ srcs = ["generic_seqatomic_unsafe.go"],
types = [
"Value",
],
@@ -31,18 +50,26 @@ go_library(
name = "sync",
srcs = [
"aliases.go",
- "memmove_unsafe.go",
+ "checklocks_off_unsafe.go",
+ "checklocks_on_unsafe.go",
+ "goyield_go113_unsafe.go",
+ "goyield_unsafe.go",
"mutex_unsafe.go",
"nocopy.go",
"norace_unsafe.go",
+ "race_amd64.s",
+ "race_arm64.s",
"race_unsafe.go",
+ "runtime_unsafe.go",
"rwmutex_unsafe.go",
"seqcount.go",
- "spin_unsafe.go",
"sync.go",
],
marshal = False,
stateify = False,
+ deps = [
+ "//pkg/goid",
+ ],
)
go_test(
diff --git a/pkg/sync/atomicptrmaptest/BUILD b/pkg/sync/atomicptrmaptest/BUILD
new file mode 100644
index 000000000..3f71ae97d
--- /dev/null
+++ b/pkg/sync/atomicptrmaptest/BUILD
@@ -0,0 +1,57 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(
+ default_visibility = ["//visibility:private"],
+ licenses = ["notice"],
+)
+
+go_template_instance(
+ name = "test_atomicptrmap",
+ out = "test_atomicptrmap_unsafe.go",
+ package = "atomicptrmap",
+ prefix = "test",
+ template = "//pkg/sync:generic_atomicptrmap",
+ types = {
+ "Key": "int64",
+ "Value": "testValue",
+ },
+)
+
+go_template_instance(
+ name = "test_atomicptrmap_sharded",
+ out = "test_atomicptrmap_sharded_unsafe.go",
+ consts = {
+ "ShardOrder": "4",
+ },
+ package = "atomicptrmap",
+ prefix = "test",
+ suffix = "Sharded",
+ template = "//pkg/sync:generic_atomicptrmap",
+ types = {
+ "Key": "int64",
+ "Value": "testValue",
+ },
+)
+
+go_library(
+ name = "atomicptrmap",
+ testonly = 1,
+ srcs = [
+ "atomicptrmap.go",
+ "test_atomicptrmap_sharded_unsafe.go",
+ "test_atomicptrmap_unsafe.go",
+ ],
+ deps = [
+ "//pkg/gohacks",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "atomicptrmap_test",
+ size = "small",
+ srcs = ["atomicptrmap_test.go"],
+ library = ":atomicptrmap",
+ deps = ["//pkg/sync"],
+)
diff --git a/tools/vm/test.cc b/pkg/sync/atomicptrmaptest/atomicptrmap.go
index c0ceacda1..867821ce9 100644
--- a/tools/vm/test.cc
+++ b/pkg/sync/atomicptrmaptest/atomicptrmap.go
@@ -12,16 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "gtest/gtest.h"
+// Package atomicptrmap instantiates generic_atomicptrmap for testing.
+package atomicptrmap
-namespace {
-
-TEST(Image, Sanity0) {
- // Do nothing (in shard 0).
-}
-
-TEST(Image, Sanity1) {
- // Do nothing (in shard 1).
+type testValue struct {
+ val int
}
-
-} // namespace
diff --git a/pkg/sync/atomicptrmaptest/atomicptrmap_test.go b/pkg/sync/atomicptrmaptest/atomicptrmap_test.go
new file mode 100644
index 000000000..75a9997ef
--- /dev/null
+++ b/pkg/sync/atomicptrmaptest/atomicptrmap_test.go
@@ -0,0 +1,635 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package atomicptrmap
+
+import (
+ "context"
+ "fmt"
+ "math/rand"
+ "reflect"
+ "runtime"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func TestConsistencyWithGoMap(t *testing.T) {
+ const maxKey = 16
+ var vals [4]*testValue
+ for i := 1; /* leave vals[0] nil */ i < len(vals); i++ {
+ vals[i] = new(testValue)
+ }
+ var (
+ m = make(map[int64]*testValue)
+ apm testAtomicPtrMap
+ )
+ for i := 0; i < 100000; i++ {
+ // Apply a random operation to both m and apm and expect them to have
+ // the same result. Bias toward CompareAndSwap, which has the most
+ // cases; bias away from Range and RangeRepeatable, which are
+ // relatively expensive.
+ switch rand.Intn(10) {
+ case 0, 1: // Load
+ key := rand.Int63n(maxKey)
+ want := m[key]
+ got := apm.Load(key)
+ t.Logf("Load(%d) = %p", key, got)
+ if got != want {
+ t.Fatalf("got %p, wanted %p", got, want)
+ }
+ case 2, 3: // Swap
+ key := rand.Int63n(maxKey)
+ val := vals[rand.Intn(len(vals))]
+ want := m[key]
+ if val != nil {
+ m[key] = val
+ } else {
+ delete(m, key)
+ }
+ got := apm.Swap(key, val)
+ t.Logf("Swap(%d, %p) = %p", key, val, got)
+ if got != want {
+ t.Fatalf("got %p, wanted %p", got, want)
+ }
+ case 4, 5, 6, 7: // CompareAndSwap
+ key := rand.Int63n(maxKey)
+ oldVal := vals[rand.Intn(len(vals))]
+ newVal := vals[rand.Intn(len(vals))]
+ want := m[key]
+ if want == oldVal {
+ if newVal != nil {
+ m[key] = newVal
+ } else {
+ delete(m, key)
+ }
+ }
+ got := apm.CompareAndSwap(key, oldVal, newVal)
+ t.Logf("CompareAndSwap(%d, %p, %p) = %p", key, oldVal, newVal, got)
+ if got != want {
+ t.Fatalf("got %p, wanted %p", got, want)
+ }
+ case 8: // Range
+ got := make(map[int64]*testValue)
+ var (
+ haveDup = false
+ dup int64
+ )
+ apm.Range(func(key int64, val *testValue) bool {
+ if _, ok := got[key]; ok && !haveDup {
+ haveDup = true
+ dup = key
+ }
+ got[key] = val
+ return true
+ })
+ t.Logf("Range() = %v", got)
+ if !reflect.DeepEqual(got, m) {
+ t.Fatalf("got %v, wanted %v", got, m)
+ }
+ if haveDup {
+ t.Fatalf("got duplicate key %d", dup)
+ }
+ case 9: // RangeRepeatable
+ got := make(map[int64]*testValue)
+ apm.RangeRepeatable(func(key int64, val *testValue) bool {
+ got[key] = val
+ return true
+ })
+ t.Logf("RangeRepeatable() = %v", got)
+ if !reflect.DeepEqual(got, m) {
+ t.Fatalf("got %v, wanted %v", got, m)
+ }
+ }
+ }
+}
+
+func TestConcurrentHeterogeneous(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ var (
+ apm testAtomicPtrMap
+ wg sync.WaitGroup
+ )
+ defer func() {
+ cancel()
+ wg.Wait()
+ }()
+
+ possibleKeyValuePairs := make(map[int64]map[*testValue]struct{})
+ addKeyValuePair := func(key int64, val *testValue) {
+ values := possibleKeyValuePairs[key]
+ if values == nil {
+ values = make(map[*testValue]struct{})
+ possibleKeyValuePairs[key] = values
+ }
+ values[val] = struct{}{}
+ }
+
+ const numValuesPerKey = 4
+
+ // These goroutines use keys not used by any other goroutine.
+ const numPrivateKeys = 3
+ for i := 0; i < numPrivateKeys; i++ {
+ key := int64(i)
+ var vals [numValuesPerKey]*testValue
+ for i := 1; /* leave vals[0] nil */ i < len(vals); i++ {
+ val := new(testValue)
+ vals[i] = val
+ addKeyValuePair(key, val)
+ }
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ r := rand.New(rand.NewSource(rand.Int63()))
+ var stored *testValue
+ for ctx.Err() == nil {
+ switch r.Intn(4) {
+ case 0:
+ got := apm.Load(key)
+ if got != stored {
+ t.Errorf("Load(%d): got %p, wanted %p", key, got, stored)
+ return
+ }
+ case 1:
+ val := vals[r.Intn(len(vals))]
+ want := stored
+ stored = val
+ got := apm.Swap(key, val)
+ if got != want {
+ t.Errorf("Swap(%d, %p): got %p, wanted %p", key, val, got, want)
+ return
+ }
+ case 2, 3:
+ oldVal := vals[r.Intn(len(vals))]
+ newVal := vals[r.Intn(len(vals))]
+ want := stored
+ if stored == oldVal {
+ stored = newVal
+ }
+ got := apm.CompareAndSwap(key, oldVal, newVal)
+ if got != want {
+ t.Errorf("CompareAndSwap(%d, %p, %p): got %p, wanted %p", key, oldVal, newVal, got, want)
+ return
+ }
+ }
+ }
+ }()
+ }
+
+ // These goroutines share a small set of keys.
+ const numSharedKeys = 2
+ var (
+ sharedKeys [numSharedKeys]int64
+ sharedValues = make(map[int64][]*testValue)
+ sharedValuesSet = make(map[int64]map[*testValue]struct{})
+ )
+ for i := range sharedKeys {
+ key := int64(numPrivateKeys + i)
+ sharedKeys[i] = key
+ vals := make([]*testValue, numValuesPerKey)
+ valsSet := make(map[*testValue]struct{})
+ for j := range vals {
+ val := new(testValue)
+ vals[j] = val
+ valsSet[val] = struct{}{}
+ addKeyValuePair(key, val)
+ }
+ sharedValues[key] = vals
+ sharedValuesSet[key] = valsSet
+ }
+ randSharedValue := func(r *rand.Rand, key int64) *testValue {
+ vals := sharedValues[key]
+ return vals[r.Intn(len(vals))]
+ }
+ for i := 0; i < 3; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ r := rand.New(rand.NewSource(rand.Int63()))
+ for ctx.Err() == nil {
+ keyIndex := r.Intn(len(sharedKeys))
+ key := sharedKeys[keyIndex]
+ var (
+ op string
+ got *testValue
+ )
+ switch r.Intn(4) {
+ case 0:
+ op = "Load"
+ got = apm.Load(key)
+ case 1:
+ op = "Swap"
+ got = apm.Swap(key, randSharedValue(r, key))
+ case 2, 3:
+ op = "CompareAndSwap"
+ got = apm.CompareAndSwap(key, randSharedValue(r, key), randSharedValue(r, key))
+ }
+ if got != nil {
+ valsSet := sharedValuesSet[key]
+ if _, ok := valsSet[got]; !ok {
+ t.Errorf("%s: got key %d, value %p; expected value in %v", op, key, got, valsSet)
+ return
+ }
+ }
+ }
+ }()
+ }
+
+ // This goroutine repeatedly searches for unused keys.
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ r := rand.New(rand.NewSource(rand.Int63()))
+ for ctx.Err() == nil {
+ key := -1 - r.Int63()
+ if got := apm.Load(key); got != nil {
+ t.Errorf("Load(%d): got %p, wanted nil", key, got)
+ }
+ }
+ }()
+
+ // This goroutine repeatedly calls RangeRepeatable() and checks that each
+ // key corresponds to an expected value.
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ abort := false
+ for !abort && ctx.Err() == nil {
+ apm.RangeRepeatable(func(key int64, val *testValue) bool {
+ values, ok := possibleKeyValuePairs[key]
+ if !ok {
+ t.Errorf("RangeRepeatable: got invalid key %d", key)
+ abort = true
+ return false
+ }
+ if _, ok := values[val]; !ok {
+ t.Errorf("RangeRepeatable: got key %d, value %p; expected one of %v", key, val, values)
+ abort = true
+ return false
+ }
+ return true
+ })
+ }
+ }()
+
+ // Finally, the main goroutine spins for the length of the test calling
+ // Range() and checking that each key that it observes is unique and
+ // corresponds to an expected value.
+ seenKeys := make(map[int64]struct{})
+ const testDuration = 5 * time.Second
+ end := time.Now().Add(testDuration)
+ abort := false
+ for time.Now().Before(end) {
+ apm.Range(func(key int64, val *testValue) bool {
+ values, ok := possibleKeyValuePairs[key]
+ if !ok {
+ t.Errorf("Range: got invalid key %d", key)
+ abort = true
+ return false
+ }
+ if _, ok := values[val]; !ok {
+ t.Errorf("Range: got key %d, value %p; expected one of %v", key, val, values)
+ abort = true
+ return false
+ }
+ if _, ok := seenKeys[key]; ok {
+ t.Errorf("Range: got duplicate key %d", key)
+ abort = true
+ return false
+ }
+ seenKeys[key] = struct{}{}
+ return true
+ })
+ if abort {
+ break
+ }
+ for k := range seenKeys {
+ delete(seenKeys, k)
+ }
+ }
+}
+
+type benchmarkableMap interface {
+ Load(key int64) *testValue
+ Store(key int64, val *testValue)
+ LoadOrStore(key int64, val *testValue) (*testValue, bool)
+ Delete(key int64)
+}
+
+// rwMutexMap implements benchmarkableMap for a RWMutex-protected Go map.
+type rwMutexMap struct {
+ mu sync.RWMutex
+ m map[int64]*testValue
+}
+
+func (m *rwMutexMap) Load(key int64) *testValue {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ return m.m[key]
+}
+
+func (m *rwMutexMap) Store(key int64, val *testValue) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if m.m == nil {
+ m.m = make(map[int64]*testValue)
+ }
+ m.m[key] = val
+}
+
+func (m *rwMutexMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if m.m == nil {
+ m.m = make(map[int64]*testValue)
+ }
+ if oldVal, ok := m.m[key]; ok {
+ return oldVal, true
+ }
+ m.m[key] = val
+ return val, false
+}
+
+func (m *rwMutexMap) Delete(key int64) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ delete(m.m, key)
+}
+
+// syncMap implements benchmarkableMap for a sync.Map.
+type syncMap struct {
+ m sync.Map
+}
+
+func (m *syncMap) Load(key int64) *testValue {
+ val, ok := m.m.Load(key)
+ if !ok {
+ return nil
+ }
+ return val.(*testValue)
+}
+
+func (m *syncMap) Store(key int64, val *testValue) {
+ m.m.Store(key, val)
+}
+
+func (m *syncMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
+ actual, loaded := m.m.LoadOrStore(key, val)
+ return actual.(*testValue), loaded
+}
+
+func (m *syncMap) Delete(key int64) {
+ m.m.Delete(key)
+}
+
+// benchmarkableAtomicPtrMap implements benchmarkableMap for testAtomicPtrMap.
+type benchmarkableAtomicPtrMap struct {
+ m testAtomicPtrMap
+}
+
+func (m *benchmarkableAtomicPtrMap) Load(key int64) *testValue {
+ return m.m.Load(key)
+}
+
+func (m *benchmarkableAtomicPtrMap) Store(key int64, val *testValue) {
+ m.m.Store(key, val)
+}
+
+func (m *benchmarkableAtomicPtrMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
+ if prev := m.m.CompareAndSwap(key, nil, val); prev != nil {
+ return prev, true
+ }
+ return val, false
+}
+
+func (m *benchmarkableAtomicPtrMap) Delete(key int64) {
+ m.m.Store(key, nil)
+}
+
+// benchmarkableAtomicPtrMapSharded implements benchmarkableMap for testAtomicPtrMapSharded.
+type benchmarkableAtomicPtrMapSharded struct {
+ m testAtomicPtrMapSharded
+}
+
+func (m *benchmarkableAtomicPtrMapSharded) Load(key int64) *testValue {
+ return m.m.Load(key)
+}
+
+func (m *benchmarkableAtomicPtrMapSharded) Store(key int64, val *testValue) {
+ m.m.Store(key, val)
+}
+
+func (m *benchmarkableAtomicPtrMapSharded) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
+ if prev := m.m.CompareAndSwap(key, nil, val); prev != nil {
+ return prev, true
+ }
+ return val, false
+}
+
+func (m *benchmarkableAtomicPtrMapSharded) Delete(key int64) {
+ m.m.Store(key, nil)
+}
+
+var mapImpls = [...]struct {
+ name string
+ ctor func() benchmarkableMap
+}{
+ {
+ name: "RWMutexMap",
+ ctor: func() benchmarkableMap {
+ return new(rwMutexMap)
+ },
+ },
+ {
+ name: "SyncMap",
+ ctor: func() benchmarkableMap {
+ return new(syncMap)
+ },
+ },
+ {
+ name: "AtomicPtrMap",
+ ctor: func() benchmarkableMap {
+ return new(benchmarkableAtomicPtrMap)
+ },
+ },
+ {
+ name: "AtomicPtrMapSharded",
+ ctor: func() benchmarkableMap {
+ return new(benchmarkableAtomicPtrMapSharded)
+ },
+ },
+}
+
+func benchmarkStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) {
+ m := mapCtor()
+ val := &testValue{}
+ for i := 0; i < b.N; i++ {
+ m.Store(int64(i), val)
+ }
+ for i := 0; i < b.N; i++ {
+ m.Delete(int64(i))
+ }
+}
+
+func BenchmarkStoreDelete(b *testing.B) {
+ for _, mapImpl := range mapImpls {
+ b.Run(mapImpl.name, func(b *testing.B) {
+ benchmarkStoreDelete(b, mapImpl.ctor)
+ })
+ }
+}
+
+func benchmarkLoadOrStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) {
+ m := mapCtor()
+ val := &testValue{}
+ for i := 0; i < b.N; i++ {
+ m.LoadOrStore(int64(i), val)
+ }
+ for i := 0; i < b.N; i++ {
+ m.Delete(int64(i))
+ }
+}
+
+func BenchmarkLoadOrStoreDelete(b *testing.B) {
+ for _, mapImpl := range mapImpls {
+ b.Run(mapImpl.name, func(b *testing.B) {
+ benchmarkLoadOrStoreDelete(b, mapImpl.ctor)
+ })
+ }
+}
+
+func benchmarkLookupPositive(b *testing.B, mapCtor func() benchmarkableMap) {
+ m := mapCtor()
+ val := &testValue{}
+ for i := 0; i < b.N; i++ {
+ m.Store(int64(i), val)
+ }
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ m.Load(int64(i))
+ }
+}
+
+func BenchmarkLookupPositive(b *testing.B) {
+ for _, mapImpl := range mapImpls {
+ b.Run(mapImpl.name, func(b *testing.B) {
+ benchmarkLookupPositive(b, mapImpl.ctor)
+ })
+ }
+}
+
+func benchmarkLookupNegative(b *testing.B, mapCtor func() benchmarkableMap) {
+ m := mapCtor()
+ val := &testValue{}
+ for i := 0; i < b.N; i++ {
+ m.Store(int64(i), val)
+ }
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ m.Load(int64(-1 - i))
+ }
+}
+
+func BenchmarkLookupNegative(b *testing.B) {
+ for _, mapImpl := range mapImpls {
+ b.Run(mapImpl.name, func(b *testing.B) {
+ benchmarkLookupNegative(b, mapImpl.ctor)
+ })
+ }
+}
+
+type benchmarkConcurrentOptions struct {
+ // loadsPerMutationPair is the number of map lookups between each
+ // insertion/deletion pair.
+ loadsPerMutationPair int
+
+ // If changeKeys is true, the keys used by each goroutine change between
+ // iterations of the test.
+ changeKeys bool
+}
+
+func benchmarkConcurrent(b *testing.B, mapCtor func() benchmarkableMap, opts benchmarkConcurrentOptions) {
+ var (
+ started sync.WaitGroup
+ workers sync.WaitGroup
+ )
+ started.Add(1)
+
+ m := mapCtor()
+ val := &testValue{}
+ // Insert a large number of unused elements into the map so that used
+ // elements are distributed throughout memory.
+ for i := 0; i < 10000; i++ {
+ m.Store(int64(-1-i), val)
+ }
+ // n := ceil(b.N / (opts.loadsPerMutationPair + 2))
+ n := (b.N + opts.loadsPerMutationPair + 1) / (opts.loadsPerMutationPair + 2)
+ for i, procs := 0, runtime.GOMAXPROCS(0); i < procs; i++ {
+ workerID := i
+ workers.Add(1)
+ go func() {
+ defer workers.Done()
+ started.Wait()
+ for i := 0; i < n; i++ {
+ var key int64
+ if opts.changeKeys {
+ key = int64(workerID*n + i)
+ } else {
+ key = int64(workerID)
+ }
+ m.LoadOrStore(key, val)
+ for j := 0; j < opts.loadsPerMutationPair; j++ {
+ m.Load(key)
+ }
+ m.Delete(key)
+ }
+ }()
+ }
+
+ b.ResetTimer()
+ started.Done()
+ workers.Wait()
+}
+
+func BenchmarkConcurrent(b *testing.B) {
+ changeKeysChoices := [...]struct {
+ name string
+ val bool
+ }{
+ {"FixedKeys", false},
+ {"ChangingKeys", true},
+ }
+ writePcts := [...]struct {
+ name string
+ loadsPerMutationPair int
+ }{
+ {"1PercentWrites", 198},
+ {"10PercentWrites", 18},
+ {"50PercentWrites", 2},
+ }
+ for _, changeKeys := range changeKeysChoices {
+ for _, writePct := range writePcts {
+ for _, mapImpl := range mapImpls {
+ name := fmt.Sprintf("%s_%s_%s", changeKeys.name, writePct.name, mapImpl.name)
+ b.Run(name, func(b *testing.B) {
+ benchmarkConcurrent(b, mapImpl.ctor, benchmarkConcurrentOptions{
+ loadsPerMutationPair: writePct.loadsPerMutationPair,
+ changeKeys: changeKeys.val,
+ })
+ })
+ }
+ }
+ }
+}
diff --git a/pkg/sync/checklocks_off_unsafe.go b/pkg/sync/checklocks_off_unsafe.go
new file mode 100644
index 000000000..62c81b149
--- /dev/null
+++ b/pkg/sync/checklocks_off_unsafe.go
@@ -0,0 +1,18 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !checklocks
+
+package sync
+
+import (
+ "unsafe"
+)
+
+func noteLock(l unsafe.Pointer) {
+}
+
+func noteUnlock(l unsafe.Pointer) {
+}
diff --git a/pkg/sync/checklocks_on_unsafe.go b/pkg/sync/checklocks_on_unsafe.go
new file mode 100644
index 000000000..24f933ed1
--- /dev/null
+++ b/pkg/sync/checklocks_on_unsafe.go
@@ -0,0 +1,108 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build checklocks
+
+package sync
+
+import (
+ "fmt"
+ "strings"
+ "sync"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/goid"
+)
+
+// gLocks contains metadata about the locks held by a goroutine.
+type gLocks struct {
+ locksHeld []unsafe.Pointer
+}
+
+// map[goid int]*gLocks
+//
+// Each key may only be written by the G with the goid it refers to.
+//
+// Note that entries are not evicted when a G exit, causing unbounded growth
+// with new G creation / destruction. If this proves problematic, entries could
+// be evicted when no locks are held at the expense of more allocations when
+// taking top-level locks.
+var locksHeld sync.Map
+
+func getGLocks() *gLocks {
+ id := goid.Get()
+
+ var locks *gLocks
+ if l, ok := locksHeld.Load(id); ok {
+ locks = l.(*gLocks)
+ } else {
+ locks = &gLocks{
+ // Initialize space for a few locks.
+ locksHeld: make([]unsafe.Pointer, 0, 8),
+ }
+ locksHeld.Store(id, locks)
+ }
+
+ return locks
+}
+
+func noteLock(l unsafe.Pointer) {
+ locks := getGLocks()
+
+ for _, lock := range locks.locksHeld {
+ if lock == l {
+ panic(fmt.Sprintf("Deadlock on goroutine %d! Double lock of %p: %+v", goid.Get(), l, locks))
+ }
+ }
+
+ // Commit only after checking for panic conditions so that this lock
+ // isn't on the list if the above panic is recovered.
+ locks.locksHeld = append(locks.locksHeld, l)
+}
+
+func noteUnlock(l unsafe.Pointer) {
+ locks := getGLocks()
+
+ if len(locks.locksHeld) == 0 {
+ panic(fmt.Sprintf("Unlock of %p on goroutine %d without any locks held! All locks:\n%s", l, goid.Get(), dumpLocks()))
+ }
+
+ // Search backwards since callers are most likely to unlock in LIFO order.
+ length := len(locks.locksHeld)
+ for i := length - 1; i >= 0; i-- {
+ if l == locks.locksHeld[i] {
+ copy(locks.locksHeld[i:length-1], locks.locksHeld[i+1:length])
+ // Clear last entry to ensure addr can be GC'd.
+ locks.locksHeld[length-1] = nil
+ locks.locksHeld = locks.locksHeld[:length-1]
+ return
+ }
+ }
+
+ panic(fmt.Sprintf("Unlock of %p on goroutine %d without matching lock! All locks:\n%s", l, goid.Get(), dumpLocks()))
+}
+
+func dumpLocks() string {
+ var s strings.Builder
+ locksHeld.Range(func(key, value interface{}) bool {
+ goid := key.(int64)
+ locks := value.(*gLocks)
+
+ // N.B. accessing gLocks of another G is fundamentally racy.
+
+ fmt.Fprintf(&s, "goroutine %d:\n", goid)
+ if len(locks.locksHeld) == 0 {
+ fmt.Fprintf(&s, "\t<none>\n")
+ }
+ for _, lock := range locks.locksHeld {
+ fmt.Fprintf(&s, "\t%p\n", lock)
+ }
+ fmt.Fprintf(&s, "\n")
+
+ return true
+ })
+
+ return s.String()
+}
diff --git a/pkg/sync/atomicptr_unsafe.go b/pkg/sync/generic_atomicptr_unsafe.go
index 525c4beed..82b6df18c 100644
--- a/pkg/sync/atomicptr_unsafe.go
+++ b/pkg/sync/generic_atomicptr_unsafe.go
@@ -3,9 +3,9 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Package template doesn't exist. This file must be instantiated using the
+// Package seqatomic doesn't exist. This file must be instantiated using the
// go_template_instance rule in tools/go_generics/defs.bzl.
-package template
+package seqatomic
import (
"sync/atomic"
diff --git a/pkg/sync/generic_atomicptrmap_unsafe.go b/pkg/sync/generic_atomicptrmap_unsafe.go
new file mode 100644
index 000000000..c70dda6dd
--- /dev/null
+++ b/pkg/sync/generic_atomicptrmap_unsafe.go
@@ -0,0 +1,503 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package atomicptrmap doesn't exist. This file must be instantiated using the
+// go_template_instance rule in tools/go_generics/defs.bzl.
+package atomicptrmap
+
+import (
+ "reflect"
+ "runtime"
+ "sync/atomic"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Key is a required type parameter.
+type Key struct{}
+
+// Value is a required type parameter.
+type Value struct{}
+
+const (
+ // ShardOrder is an optional parameter specifying the base-2 log of the
+ // number of shards per AtomicPtrMap. Higher values of ShardOrder reduce
+ // unnecessary synchronization between unrelated concurrent operations,
+ // improving performance for write-heavy workloads, but increase memory
+ // usage for small maps.
+ ShardOrder = 0
+)
+
+// Hasher is an optional type parameter. If Hasher is provided, it must define
+// the Init and Hash methods. One Hasher will be shared by all AtomicPtrMaps.
+type Hasher struct {
+ defaultHasher
+}
+
+// defaultHasher is the default Hasher. This indirection exists because
+// defaultHasher must exist even if a custom Hasher is provided, to prevent the
+// Go compiler from complaining about defaultHasher's unused imports.
+type defaultHasher struct {
+ fn func(unsafe.Pointer, uintptr) uintptr
+ seed uintptr
+}
+
+// Init initializes the Hasher.
+func (h *defaultHasher) Init() {
+ h.fn = sync.MapKeyHasher(map[Key]*Value(nil))
+ h.seed = sync.RandUintptr()
+}
+
+// Hash returns the hash value for the given Key.
+func (h *defaultHasher) Hash(key Key) uintptr {
+ return h.fn(gohacks.Noescape(unsafe.Pointer(&key)), h.seed)
+}
+
+var hasher Hasher
+
+func init() {
+ hasher.Init()
+}
+
+// An AtomicPtrMap maps Keys to non-nil pointers to Values. AtomicPtrMap are
+// safe for concurrent use from multiple goroutines without additional
+// synchronization.
+//
+// The zero value of AtomicPtrMap is empty (maps all Keys to nil) and ready for
+// use. AtomicPtrMaps must not be copied after first use.
+//
+// sync.Map may be faster than AtomicPtrMap if most operations on the map are
+// concurrent writes to a fixed set of keys. AtomicPtrMap is usually faster in
+// other circumstances.
+type AtomicPtrMap struct {
+ // AtomicPtrMap is implemented as a hash table with the following
+ // properties:
+ //
+ // * Collisions are resolved with quadratic probing. Of the two major
+ // alternatives, Robin Hood linear probing makes it difficult for writers
+ // to execute in parallel, and bucketing is less effective in Go due to
+ // lack of SIMD.
+ //
+ // * The table is optionally divided into shards indexed by hash to further
+ // reduce unnecessary synchronization.
+
+ shards [1 << ShardOrder]apmShard
+}
+
+func (m *AtomicPtrMap) shard(hash uintptr) *apmShard {
+ // Go defines right shifts >= width of shifted unsigned operand as 0, so
+ // this is correct even if ShardOrder is 0 (although nogo complains because
+ // nogo is dumb).
+ const indexLSB = unsafe.Sizeof(uintptr(0))*8 - ShardOrder
+ index := hash >> indexLSB
+ return (*apmShard)(unsafe.Pointer(uintptr(unsafe.Pointer(&m.shards)) + (index * unsafe.Sizeof(apmShard{}))))
+}
+
+type apmShard struct {
+ apmShardMutationData
+ _ [apmShardMutationDataPadding]byte
+ apmShardLookupData
+ _ [apmShardLookupDataPadding]byte
+}
+
+type apmShardMutationData struct {
+ dirtyMu sync.Mutex // serializes slot transitions out of empty
+ dirty uintptr // # slots with val != nil
+ count uintptr // # slots with val != nil and val != tombstone()
+ rehashMu sync.Mutex // serializes rehashing
+}
+
+type apmShardLookupData struct {
+ seq sync.SeqCount // allows atomic reads of slots+mask
+ slots unsafe.Pointer // [mask+1]slot or nil; protected by rehashMu/seq
+ mask uintptr // always (a power of 2) - 1; protected by rehashMu/seq
+}
+
+const (
+ cacheLineBytes = 64
+ // Cache line padding is enabled if sharding is.
+ apmEnablePadding = (ShardOrder + 63) >> 6 // 0 if ShardOrder == 0, 1 otherwise
+ // The -1 and +1 below are required to ensure that if unsafe.Sizeof(T) %
+ // cacheLineBytes == 0, then padding is 0 (rather than cacheLineBytes).
+ apmShardMutationDataRequiredPadding = cacheLineBytes - (((unsafe.Sizeof(apmShardMutationData{}) - 1) % cacheLineBytes) + 1)
+ apmShardMutationDataPadding = apmEnablePadding * apmShardMutationDataRequiredPadding
+ apmShardLookupDataRequiredPadding = cacheLineBytes - (((unsafe.Sizeof(apmShardLookupData{}) - 1) % cacheLineBytes) + 1)
+ apmShardLookupDataPadding = apmEnablePadding * apmShardLookupDataRequiredPadding
+
+ // These define fractional thresholds for when apmShard.rehash() is called
+ // (i.e. the load factor) and when it rehases to a larger table
+ // respectively. They are chosen such that the rehash threshold = the
+ // expansion threshold + 1/2, so that when reuse of deleted slots is rare
+ // or non-existent, rehashing occurs after the insertion of at least 1/2
+ // the table's size in new entries, which is acceptably infrequent.
+ apmRehashThresholdNum = 2
+ apmRehashThresholdDen = 3
+ apmExpansionThresholdNum = 1
+ apmExpansionThresholdDen = 6
+)
+
+type apmSlot struct {
+ // slot states are indicated by val:
+ //
+ // * Empty: val == nil; key is meaningless. May transition to full or
+ // evacuated with dirtyMu locked.
+ //
+ // * Full: val != nil, tombstone(), or evacuated(); key is immutable. val
+ // is the Value mapped to key. May transition to deleted or evacuated.
+ //
+ // * Deleted: val == tombstone(); key is still immutable. key is mapped to
+ // no Value. May transition to full or evacuated.
+ //
+ // * Evacuated: val == evacuated(); key is immutable. Set by rehashing on
+ // slots that have already been moved, requiring readers to wait for
+ // rehashing to complete and use the new table. Terminal state.
+ //
+ // Note that once val is non-nil, it cannot become nil again. That is, the
+ // transition from empty to non-empty is irreversible for a given slot;
+ // the only way to create more empty slots is by rehashing.
+ val unsafe.Pointer
+ key Key
+}
+
+func apmSlotAt(slots unsafe.Pointer, pos uintptr) *apmSlot {
+ return (*apmSlot)(unsafe.Pointer(uintptr(slots) + pos*unsafe.Sizeof(apmSlot{})))
+}
+
+var tombstoneObj byte
+
+func tombstone() unsafe.Pointer {
+ return unsafe.Pointer(&tombstoneObj)
+}
+
+var evacuatedObj byte
+
+func evacuated() unsafe.Pointer {
+ return unsafe.Pointer(&evacuatedObj)
+}
+
+// Load returns the Value stored in m for key.
+func (m *AtomicPtrMap) Load(key Key) *Value {
+ hash := hasher.Hash(key)
+ shard := m.shard(hash)
+
+retry:
+ epoch := shard.seq.BeginRead()
+ slots := atomic.LoadPointer(&shard.slots)
+ mask := atomic.LoadUintptr(&shard.mask)
+ if !shard.seq.ReadOk(epoch) {
+ goto retry
+ }
+ if slots == nil {
+ return nil
+ }
+
+ i := hash & mask
+ inc := uintptr(1)
+ for {
+ slot := apmSlotAt(slots, i)
+ slotVal := atomic.LoadPointer(&slot.val)
+ if slotVal == nil {
+ // Empty slot; end of probe sequence.
+ return nil
+ }
+ if slotVal == evacuated() {
+ // Racing with rehashing.
+ goto retry
+ }
+ if slot.key == key {
+ if slotVal == tombstone() {
+ return nil
+ }
+ return (*Value)(slotVal)
+ }
+ i = (i + inc) & mask
+ inc++
+ }
+}
+
+// Store stores the Value val for key.
+func (m *AtomicPtrMap) Store(key Key, val *Value) {
+ m.maybeCompareAndSwap(key, false, nil, val)
+}
+
+// Swap stores the Value val for key and returns the previously-mapped Value.
+func (m *AtomicPtrMap) Swap(key Key, val *Value) *Value {
+ return m.maybeCompareAndSwap(key, false, nil, val)
+}
+
+// CompareAndSwap checks that the Value stored for key is oldVal; if it is, it
+// stores the Value newVal for key. CompareAndSwap returns the previous Value
+// stored for key, whether or not it stores newVal.
+func (m *AtomicPtrMap) CompareAndSwap(key Key, oldVal, newVal *Value) *Value {
+ return m.maybeCompareAndSwap(key, true, oldVal, newVal)
+}
+
+func (m *AtomicPtrMap) maybeCompareAndSwap(key Key, compare bool, typedOldVal, typedNewVal *Value) *Value {
+ hash := hasher.Hash(key)
+ shard := m.shard(hash)
+ oldVal := tombstone()
+ if typedOldVal != nil {
+ oldVal = unsafe.Pointer(typedOldVal)
+ }
+ newVal := tombstone()
+ if typedNewVal != nil {
+ newVal = unsafe.Pointer(typedNewVal)
+ }
+
+retry:
+ epoch := shard.seq.BeginRead()
+ slots := atomic.LoadPointer(&shard.slots)
+ mask := atomic.LoadUintptr(&shard.mask)
+ if !shard.seq.ReadOk(epoch) {
+ goto retry
+ }
+ if slots == nil {
+ if (compare && oldVal != tombstone()) || newVal == tombstone() {
+ return nil
+ }
+ // Need to allocate a table before insertion.
+ shard.rehash(nil)
+ goto retry
+ }
+
+ i := hash & mask
+ inc := uintptr(1)
+ for {
+ slot := apmSlotAt(slots, i)
+ slotVal := atomic.LoadPointer(&slot.val)
+ if slotVal == nil {
+ if (compare && oldVal != tombstone()) || newVal == tombstone() {
+ return nil
+ }
+ // Try to grab this slot for ourselves.
+ shard.dirtyMu.Lock()
+ slotVal = atomic.LoadPointer(&slot.val)
+ if slotVal == nil {
+ // Check if we need to rehash before dirtying a slot.
+ if dirty, capacity := shard.dirty+1, mask+1; dirty*apmRehashThresholdDen >= capacity*apmRehashThresholdNum {
+ shard.dirtyMu.Unlock()
+ shard.rehash(slots)
+ goto retry
+ }
+ slot.key = key
+ atomic.StorePointer(&slot.val, newVal) // transitions slot to full
+ shard.dirty++
+ atomic.AddUintptr(&shard.count, 1)
+ shard.dirtyMu.Unlock()
+ return nil
+ }
+ // Raced with another store; the slot is no longer empty. Continue
+ // with the new value of slotVal since we may have raced with
+ // another store of key.
+ shard.dirtyMu.Unlock()
+ }
+ if slotVal == evacuated() {
+ // Racing with rehashing.
+ goto retry
+ }
+ if slot.key == key {
+ // We're reusing an existing slot, so rehashing isn't necessary.
+ for {
+ if (compare && oldVal != slotVal) || newVal == slotVal {
+ if slotVal == tombstone() {
+ return nil
+ }
+ return (*Value)(slotVal)
+ }
+ if atomic.CompareAndSwapPointer(&slot.val, slotVal, newVal) {
+ if slotVal == tombstone() {
+ atomic.AddUintptr(&shard.count, 1)
+ return nil
+ }
+ if newVal == tombstone() {
+ atomic.AddUintptr(&shard.count, ^uintptr(0) /* -1 */)
+ }
+ return (*Value)(slotVal)
+ }
+ slotVal = atomic.LoadPointer(&slot.val)
+ if slotVal == evacuated() {
+ goto retry
+ }
+ }
+ }
+ // This produces a triangular number sequence of offsets from the
+ // initially-probed position.
+ i = (i + inc) & mask
+ inc++
+ }
+}
+
+// rehash is marked nosplit to avoid preemption during table copying.
+//go:nosplit
+func (shard *apmShard) rehash(oldSlots unsafe.Pointer) {
+ shard.rehashMu.Lock()
+ defer shard.rehashMu.Unlock()
+
+ if shard.slots != oldSlots {
+ // Raced with another call to rehash().
+ return
+ }
+
+ // Determine the size of the new table. Constraints:
+ //
+ // * The size of the table must be a power of two to ensure that every slot
+ // is visitable by every probe sequence under quadratic probing with
+ // triangular numbers.
+ //
+ // * The size of the table cannot decrease because even if shard.count is
+ // currently smaller than shard.dirty, concurrent stores that reuse
+ // existing slots can drive shard.count back up to a maximum of
+ // shard.dirty.
+ newSize := uintptr(8) // arbitrary initial size
+ if oldSlots != nil {
+ oldSize := shard.mask + 1
+ newSize = oldSize
+ if count := atomic.LoadUintptr(&shard.count) + 1; count*apmExpansionThresholdDen > oldSize*apmExpansionThresholdNum {
+ newSize *= 2
+ }
+ }
+
+ // Allocate the new table.
+ newSlotsSlice := make([]apmSlot, newSize)
+ newSlotsReflect := (*reflect.SliceHeader)(unsafe.Pointer(&newSlotsSlice))
+ newSlots := unsafe.Pointer(newSlotsReflect.Data)
+ runtime.KeepAlive(newSlotsSlice)
+ newMask := newSize - 1
+
+ // Start a writer critical section now so that racing users of the old
+ // table that observe evacuated() wait for the new table. (But lock dirtyMu
+ // first since doing so may block, which we don't want to do during the
+ // writer critical section.)
+ shard.dirtyMu.Lock()
+ shard.seq.BeginWrite()
+
+ if oldSlots != nil {
+ realCount := uintptr(0)
+ // Copy old entries to the new table.
+ oldMask := shard.mask
+ for i := uintptr(0); i <= oldMask; i++ {
+ oldSlot := apmSlotAt(oldSlots, i)
+ val := atomic.SwapPointer(&oldSlot.val, evacuated())
+ if val == nil || val == tombstone() {
+ continue
+ }
+ hash := hasher.Hash(oldSlot.key)
+ j := hash & newMask
+ inc := uintptr(1)
+ for {
+ newSlot := apmSlotAt(newSlots, j)
+ if newSlot.val == nil {
+ newSlot.val = val
+ newSlot.key = oldSlot.key
+ break
+ }
+ j = (j + inc) & newMask
+ inc++
+ }
+ realCount++
+ }
+ // Update dirty to reflect that tombstones were not copied to the new
+ // table. Use realCount since a concurrent mutator may not have updated
+ // shard.count yet.
+ shard.dirty = realCount
+ }
+
+ // Switch to the new table.
+ atomic.StorePointer(&shard.slots, newSlots)
+ atomic.StoreUintptr(&shard.mask, newMask)
+
+ shard.seq.EndWrite()
+ shard.dirtyMu.Unlock()
+}
+
+// Range invokes f on each Key-Value pair stored in m. If any call to f returns
+// false, Range stops iteration and returns.
+//
+// Range does not necessarily correspond to any consistent snapshot of the
+// Map's contents: no Key will be visited more than once, but if the Value for
+// any Key is stored or deleted concurrently, Range may reflect any mapping for
+// that Key from any point during the Range call.
+//
+// f must not call other methods on m.
+func (m *AtomicPtrMap) Range(f func(key Key, val *Value) bool) {
+ for si := 0; si < len(m.shards); si++ {
+ shard := &m.shards[si]
+ if !shard.doRange(f) {
+ return
+ }
+ }
+}
+
+func (shard *apmShard) doRange(f func(key Key, val *Value) bool) bool {
+ // We have to lock rehashMu because if we handled races with rehashing by
+ // retrying, f could see the same key twice.
+ shard.rehashMu.Lock()
+ defer shard.rehashMu.Unlock()
+ slots := shard.slots
+ if slots == nil {
+ return true
+ }
+ mask := shard.mask
+ for i := uintptr(0); i <= mask; i++ {
+ slot := apmSlotAt(slots, i)
+ slotVal := atomic.LoadPointer(&slot.val)
+ if slotVal == nil || slotVal == tombstone() {
+ continue
+ }
+ if !f(slot.key, (*Value)(slotVal)) {
+ return false
+ }
+ }
+ return true
+}
+
+// RangeRepeatable is like Range, but:
+//
+// * RangeRepeatable may visit the same Key multiple times in the presence of
+// concurrent mutators, possibly passing different Values to f in different
+// calls.
+//
+// * It is safe for f to call other methods on m.
+func (m *AtomicPtrMap) RangeRepeatable(f func(key Key, val *Value) bool) {
+ for si := 0; si < len(m.shards); si++ {
+ shard := &m.shards[si]
+
+ retry:
+ epoch := shard.seq.BeginRead()
+ slots := atomic.LoadPointer(&shard.slots)
+ mask := atomic.LoadUintptr(&shard.mask)
+ if !shard.seq.ReadOk(epoch) {
+ goto retry
+ }
+ if slots == nil {
+ continue
+ }
+
+ for i := uintptr(0); i <= mask; i++ {
+ slot := apmSlotAt(slots, i)
+ slotVal := atomic.LoadPointer(&slot.val)
+ if slotVal == evacuated() {
+ goto retry
+ }
+ if slotVal == nil || slotVal == tombstone() {
+ continue
+ }
+ if !f(slot.key, (*Value)(slotVal)) {
+ return
+ }
+ }
+ }
+}
diff --git a/pkg/sync/seqatomic_unsafe.go b/pkg/sync/generic_seqatomic_unsafe.go
index 2184cb5ab..82b676abf 100644
--- a/pkg/sync/seqatomic_unsafe.go
+++ b/pkg/sync/generic_seqatomic_unsafe.go
@@ -3,25 +3,17 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Package template doesn't exist. This file must be instantiated using the
+// Package seqatomic doesn't exist. This file must be instantiated using the
// go_template_instance rule in tools/go_generics/defs.bzl.
-package template
+package seqatomic
import (
- "fmt"
- "reflect"
- "strings"
"unsafe"
"gvisor.dev/gvisor/pkg/sync"
)
// Value is a required type parameter.
-//
-// Value must not contain any pointers, including interface objects, function
-// objects, slices, maps, channels, unsafe.Pointer, and arrays or structs
-// containing any of the above. An init() function will panic if this property
-// does not hold.
type Value struct{}
// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race
@@ -55,12 +47,3 @@ func SeqAtomicTryLoad(seq *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value)
ok = seq.ReadOk(epoch)
return
}
-
-func init() {
- var val Value
- typ := reflect.TypeOf(val)
- name := typ.Name()
- if ptrs := sync.PointersInType(typ, name); len(ptrs) != 0 {
- panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n")))
- }
-}
diff --git a/pkg/sync/goyield_go113_unsafe.go b/pkg/sync/goyield_go113_unsafe.go
new file mode 100644
index 000000000..8aee0d455
--- /dev/null
+++ b/pkg/sync/goyield_go113_unsafe.go
@@ -0,0 +1,18 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.13
+// +build !go1.14
+
+package sync
+
+import (
+ "runtime"
+)
+
+func goyield() {
+ // goyield is not available until Go 1.14.
+ runtime.Gosched()
+}
diff --git a/pkg/sync/spin_unsafe.go b/pkg/sync/goyield_unsafe.go
index cafb2d065..672ee274d 100644
--- a/pkg/sync/spin_unsafe.go
+++ b/pkg/sync/goyield_unsafe.go
@@ -3,7 +3,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build go1.13
+// +build go1.14
// +build !go1.17
// Check go:linkname function signatures when updating Go version.
@@ -14,11 +14,5 @@ import (
_ "unsafe" // for go:linkname
)
-//go:linkname canSpin sync.runtime_canSpin
-func canSpin(i int) bool
-
-//go:linkname doSpin sync.runtime_doSpin
-func doSpin()
-
//go:linkname goyield runtime.goyield
func goyield()
diff --git a/pkg/sync/memmove_unsafe.go b/pkg/sync/memmove_unsafe.go
deleted file mode 100644
index f5e630009..000000000
--- a/pkg/sync/memmove_unsafe.go
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build go1.12
-// +build !go1.17
-
-// Check go:linkname function signatures when updating Go version.
-
-package sync
-
-import (
- "unsafe"
-)
-
-//go:linkname memmove runtime.memmove
-//go:noescape
-func memmove(to, from unsafe.Pointer, n uintptr)
-
-// Memmove is exported for SeqAtomicLoad/SeqAtomicTryLoad<T>, which can't
-// define it because go_generics can't update the go:linkname annotation.
-// Furthermore, go:linkname silently doesn't work if the local name is exported
-// (this is of course undocumented), which is why this indirection is
-// necessary.
-func Memmove(to, from unsafe.Pointer, n uintptr) {
- memmove(to, from, n)
-}
diff --git a/pkg/sync/mutex_test.go b/pkg/sync/mutex_test.go
index 0838248b4..4fb51a8ab 100644
--- a/pkg/sync/mutex_test.go
+++ b/pkg/sync/mutex_test.go
@@ -32,11 +32,11 @@ func TestStructSize(t *testing.T) {
func TestFieldValues(t *testing.T) {
var m Mutex
m.Lock()
- if got := *m.state(); got != mutexLocked {
+ if got := *m.m.state(); got != mutexLocked {
t.Errorf("got locked sync.Mutex.state = %d, want = %d", got, mutexLocked)
}
m.Unlock()
- if got := *m.state(); got != mutexUnlocked {
+ if got := *m.m.state(); got != mutexUnlocked {
t.Errorf("got unlocked sync.Mutex.state = %d, want = %d", got, mutexUnlocked)
}
}
diff --git a/pkg/sync/mutex_unsafe.go b/pkg/sync/mutex_unsafe.go
index f4c2e9642..21084b857 100644
--- a/pkg/sync/mutex_unsafe.go
+++ b/pkg/sync/mutex_unsafe.go
@@ -17,8 +17,9 @@ import (
"unsafe"
)
-// Mutex is a try lock.
-type Mutex struct {
+// CrossGoroutineMutex is equivalent to Mutex, but it need not be unlocked by a
+// the same goroutine that locked the mutex.
+type CrossGoroutineMutex struct {
sync.Mutex
}
@@ -27,7 +28,7 @@ type syncMutex struct {
sema uint32
}
-func (m *Mutex) state() *int32 {
+func (m *CrossGoroutineMutex) state() *int32 {
return &(*syncMutex)(unsafe.Pointer(&m.Mutex)).state
}
@@ -36,9 +37,9 @@ const (
mutexLocked = 1
)
-// TryLock tries to aquire the mutex. It returns true if it succeeds and false
+// TryLock tries to acquire the mutex. It returns true if it succeeds and false
// otherwise. TryLock does not block.
-func (m *Mutex) TryLock() bool {
+func (m *CrossGoroutineMutex) TryLock() bool {
if atomic.CompareAndSwapInt32(m.state(), mutexUnlocked, mutexLocked) {
if RaceEnabled {
RaceAcquire(unsafe.Pointer(&m.Mutex))
@@ -47,3 +48,43 @@ func (m *Mutex) TryLock() bool {
}
return false
}
+
+// Mutex is a mutual exclusion lock. The zero value for a Mutex is an unlocked
+// mutex.
+//
+// A Mutex must not be copied after first use.
+//
+// A Mutex must be unlocked by the same goroutine that locked it. This
+// invariant is enforced with the 'checklocks' build tag.
+type Mutex struct {
+ m CrossGoroutineMutex
+}
+
+// Lock locks m. If the lock is already in use, the calling goroutine blocks
+// until the mutex is available.
+func (m *Mutex) Lock() {
+ noteLock(unsafe.Pointer(m))
+ m.m.Lock()
+}
+
+// Unlock unlocks m.
+//
+// Preconditions:
+// * m is locked.
+// * m was locked by this goroutine.
+func (m *Mutex) Unlock() {
+ noteUnlock(unsafe.Pointer(m))
+ m.m.Unlock()
+}
+
+// TryLock tries to acquire the mutex. It returns true if it succeeds and false
+// otherwise. TryLock does not block.
+func (m *Mutex) TryLock() bool {
+ // Note lock first to enforce proper locking even if unsuccessful.
+ noteLock(unsafe.Pointer(m))
+ locked := m.m.TryLock()
+ if !locked {
+ noteUnlock(unsafe.Pointer(m))
+ }
+ return locked
+}
diff --git a/pkg/sync/norace_unsafe.go b/pkg/sync/norace_unsafe.go
index 006055dd6..70b5f3a5e 100644
--- a/pkg/sync/norace_unsafe.go
+++ b/pkg/sync/norace_unsafe.go
@@ -8,6 +8,7 @@
package sync
import (
+ "sync/atomic"
"unsafe"
)
@@ -33,3 +34,13 @@ func RaceRelease(addr unsafe.Pointer) {
// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge.
func RaceReleaseMerge(addr unsafe.Pointer) {
}
+
+// RaceUncheckedAtomicCompareAndSwapUintptr is equivalent to
+// sync/atomic.CompareAndSwapUintptr, but is not checked by the race detector.
+// This is necessary when implementing gopark callbacks, since no race context
+// is available during their execution.
+func RaceUncheckedAtomicCompareAndSwapUintptr(ptr *uintptr, old, new uintptr) bool {
+ // Use atomic.CompareAndSwapUintptr outside of race builds for
+ // inlinability.
+ return atomic.CompareAndSwapUintptr(ptr, old, new)
+}
diff --git a/pkg/syncevent/waiter_amd64.s b/pkg/sync/race_amd64.s
index 5e216b045..57bc0ec79 100644
--- a/pkg/syncevent/waiter_amd64.s
+++ b/pkg/sync/race_amd64.s
@@ -12,21 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build race
+// +build amd64
+
#include "textflag.h"
-// See waiter_noasm_unsafe.go for a description of waiterUnlock.
-//
-// func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool
-TEXT ·waiterUnlock(SB),NOSPLIT,$0-24
+// func RaceUncheckedAtomicCompareAndSwapUintptr(ptr *uintptr, old, new uintptr) bool
+TEXT ·RaceUncheckedAtomicCompareAndSwapUintptr(SB),NOSPLIT,$0-25
MOVQ ptr+0(FP), DI
- MOVQ wg+8(FP), SI
+ MOVQ old+8(FP), AX
+ MOVQ new+16(FP), SI
- MOVQ $·preparingG(SB), AX
LOCK
- CMPXCHGQ DI, 0(SI)
+ CMPXCHGQ SI, 0(DI)
SETEQ AX
- MOVB AX, ret+16(FP)
+ MOVB AX, ret+24(FP)
RET
diff --git a/pkg/syncevent/waiter_arm64.s b/pkg/sync/race_arm64.s
index f4c06f194..88f091fda 100644
--- a/pkg/syncevent/waiter_arm64.s
+++ b/pkg/sync/race_arm64.s
@@ -12,15 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build race
+// +build arm64
+
#include "textflag.h"
-// See waiter_noasm_unsafe.go for a description of waiterUnlock.
-//
-// func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool
-TEXT ·waiterUnlock(SB),NOSPLIT,$0-24
- MOVD wg+8(FP), R0
- MOVD $·preparingG(SB), R1
- MOVD ptr+0(FP), R2
+// func RaceUncheckedAtomicCompareAndSwapUintptr(ptr *uintptr, old, new uintptr) bool
+TEXT ·RaceUncheckedAtomicCompareAndSwapUintptr(SB),NOSPLIT,$0-25
+ MOVD ptr+0(FP), R0
+ MOVD old+8(FP), R1
+ MOVD new+16(FP), R1
again:
LDAXR (R0), R3
CMP R1, R3
@@ -29,6 +30,6 @@ again:
CBNZ R3, again
ok:
CSET EQ, R0
- MOVB R0, ret+16(FP)
+ MOVB R0, ret+24(FP)
RET
diff --git a/pkg/sync/race_unsafe.go b/pkg/sync/race_unsafe.go
index 31d8fa9a6..59985c270 100644
--- a/pkg/sync/race_unsafe.go
+++ b/pkg/sync/race_unsafe.go
@@ -39,3 +39,9 @@ func RaceRelease(addr unsafe.Pointer) {
func RaceReleaseMerge(addr unsafe.Pointer) {
runtime.RaceReleaseMerge(addr)
}
+
+// RaceUncheckedAtomicCompareAndSwapUintptr is equivalent to
+// sync/atomic.CompareAndSwapUintptr, but is not checked by the race detector.
+// This is necessary when implementing gopark callbacks, since no race context
+// is available during their execution.
+func RaceUncheckedAtomicCompareAndSwapUintptr(ptr *uintptr, old, new uintptr) bool
diff --git a/pkg/sync/runtime_unsafe.go b/pkg/sync/runtime_unsafe.go
new file mode 100644
index 000000000..e925e2e5b
--- /dev/null
+++ b/pkg/sync/runtime_unsafe.go
@@ -0,0 +1,129 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.13
+// +build !go1.17
+
+// Check function signatures and constants when updating Go version.
+
+package sync
+
+import (
+ "fmt"
+ "reflect"
+ "unsafe"
+)
+
+// Note that go:linkname silently doesn't work if the local name is exported,
+// necessitating an indirection for exported functions.
+
+// Memmove is runtime.memmove, exported for SeqAtomicLoad/SeqAtomicTryLoad<T>.
+//
+//go:nosplit
+func Memmove(to, from unsafe.Pointer, n uintptr) {
+ memmove(to, from, n)
+}
+
+//go:linkname memmove runtime.memmove
+//go:noescape
+func memmove(to, from unsafe.Pointer, n uintptr)
+
+// Gopark is runtime.gopark. Gopark calls unlockf(pointer to runtime.g, lock);
+// if unlockf returns true, Gopark blocks until Goready(pointer to runtime.g)
+// is called. unlockf and its callees must be nosplit and norace, since stack
+// splitting and race context are not available where it is called.
+//
+//go:nosplit
+func Gopark(unlockf func(uintptr, unsafe.Pointer) bool, lock unsafe.Pointer, reason uint8, traceEv byte, traceskip int) {
+ gopark(unlockf, lock, reason, traceEv, traceskip)
+}
+
+//go:linkname gopark runtime.gopark
+func gopark(unlockf func(uintptr, unsafe.Pointer) bool, lock unsafe.Pointer, reason uint8, traceEv byte, traceskip int)
+
+// Goready is runtime.goready.
+//
+//go:nosplit
+func Goready(gp uintptr, traceskip int) {
+ goready(gp, traceskip)
+}
+
+//go:linkname goready runtime.goready
+func goready(gp uintptr, traceskip int)
+
+// Values for the reason argument to gopark, from Go's src/runtime/runtime2.go.
+const (
+ WaitReasonSelect uint8 = 9
+)
+
+// Values for the traceEv argument to gopark, from Go's src/runtime/trace.go.
+const (
+ TraceEvGoBlockSelect byte = 24
+)
+
+// Rand32 returns a non-cryptographically-secure random uint32.
+func Rand32() uint32 {
+ return fastrand()
+}
+
+// Rand64 returns a non-cryptographically-secure random uint64.
+func Rand64() uint64 {
+ return uint64(fastrand())<<32 | uint64(fastrand())
+}
+
+//go:linkname fastrand runtime.fastrand
+func fastrand() uint32
+
+// RandUintptr returns a non-cryptographically-secure random uintptr.
+func RandUintptr() uintptr {
+ if unsafe.Sizeof(uintptr(0)) == 4 {
+ return uintptr(Rand32())
+ }
+ return uintptr(Rand64())
+}
+
+// MapKeyHasher returns a hash function for pointers of m's key type.
+//
+// Preconditions: m must be a map.
+func MapKeyHasher(m interface{}) func(unsafe.Pointer, uintptr) uintptr {
+ if rtyp := reflect.TypeOf(m); rtyp.Kind() != reflect.Map {
+ panic(fmt.Sprintf("sync.MapKeyHasher: m is %v, not map", rtyp))
+ }
+ mtyp := *(**maptype)(unsafe.Pointer(&m))
+ return mtyp.hasher
+}
+
+type maptype struct {
+ size uintptr
+ ptrdata uintptr
+ hash uint32
+ tflag uint8
+ align uint8
+ fieldAlign uint8
+ kind uint8
+ equal func(unsafe.Pointer, unsafe.Pointer) bool
+ gcdata *byte
+ str int32
+ ptrToThis int32
+ key unsafe.Pointer
+ elem unsafe.Pointer
+ bucket unsafe.Pointer
+ hasher func(unsafe.Pointer, uintptr) uintptr
+ // more fields
+}
+
+// These functions are only used within the sync package.
+
+//go:linkname semacquire sync.runtime_Semacquire
+func semacquire(s *uint32)
+
+//go:linkname semrelease sync.runtime_Semrelease
+func semrelease(s *uint32, handoff bool, skipframes int)
+
+//go:linkname canSpin sync.runtime_canSpin
+func canSpin(i int) bool
+
+//go:linkname doSpin sync.runtime_doSpin
+func doSpin()
diff --git a/pkg/sync/rwmutex_test.go b/pkg/sync/rwmutex_test.go
index ce667e825..5ca96d12b 100644
--- a/pkg/sync/rwmutex_test.go
+++ b/pkg/sync/rwmutex_test.go
@@ -102,7 +102,7 @@ func downgradingWriter(rwm *RWMutex, numIterations int, activity *int32, cdone c
}
for i := 0; i < 100; i++ {
}
- n = atomic.AddInt32(activity, -1)
+ atomic.AddInt32(activity, -1)
rwm.RUnlock()
}
cdone <- true
diff --git a/pkg/sync/rwmutex_unsafe.go b/pkg/sync/rwmutex_unsafe.go
index b3b4dee78..4cf3fcd6e 100644
--- a/pkg/sync/rwmutex_unsafe.go
+++ b/pkg/sync/rwmutex_unsafe.go
@@ -3,11 +3,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build go1.13
-// +build !go1.17
-
-// Check go:linkname function signatures when updating Go version.
-
// This is mostly copied from the standard library's sync/rwmutex.go.
//
// Happens-before relationships indicated to the race detector:
@@ -23,16 +18,15 @@ import (
"unsafe"
)
-//go:linkname runtimeSemacquire sync.runtime_Semacquire
-func runtimeSemacquire(s *uint32)
-
-//go:linkname runtimeSemrelease sync.runtime_Semrelease
-func runtimeSemrelease(s *uint32, handoff bool, skipframes int)
-
-// RWMutex is identical to sync.RWMutex, but adds the DowngradeLock,
-// TryLock and TryRLock methods.
-type RWMutex struct {
- w Mutex // held if there are pending writers
+// CrossGoroutineRWMutex is equivalent to RWMutex, but it need not be unlocked
+// by a the same goroutine that locked the mutex.
+type CrossGoroutineRWMutex struct {
+ // w is held if there are pending writers
+ //
+ // We use CrossGoroutineMutex rather than Mutex because the lock
+ // annotation instrumentation in Mutex will trigger false positives in
+ // the race detector when called inside of RaceDisable.
+ w CrossGoroutineMutex
writerSem uint32 // semaphore for writers to wait for completing readers
readerSem uint32 // semaphore for readers to wait for completing writers
readerCount int32 // number of pending readers
@@ -43,7 +37,7 @@ const rwmutexMaxReaders = 1 << 30
// TryRLock locks rw for reading. It returns true if it succeeds and false
// otherwise. It does not block.
-func (rw *RWMutex) TryRLock() bool {
+func (rw *CrossGoroutineRWMutex) TryRLock() bool {
if RaceEnabled {
RaceDisable()
}
@@ -67,13 +61,17 @@ func (rw *RWMutex) TryRLock() bool {
}
// RLock locks rw for reading.
-func (rw *RWMutex) RLock() {
+//
+// It should not be used for recursive read locking; a blocked Lock call
+// excludes new readers from acquiring the lock. See the documentation on the
+// RWMutex type.
+func (rw *CrossGoroutineRWMutex) RLock() {
if RaceEnabled {
RaceDisable()
}
if atomic.AddInt32(&rw.readerCount, 1) < 0 {
// A writer is pending, wait for it.
- runtimeSemacquire(&rw.readerSem)
+ semacquire(&rw.readerSem)
}
if RaceEnabled {
RaceEnable()
@@ -82,7 +80,10 @@ func (rw *RWMutex) RLock() {
}
// RUnlock undoes a single RLock call.
-func (rw *RWMutex) RUnlock() {
+//
+// Preconditions:
+// * rw is locked for reading.
+func (rw *CrossGoroutineRWMutex) RUnlock() {
if RaceEnabled {
RaceReleaseMerge(unsafe.Pointer(&rw.writerSem))
RaceDisable()
@@ -94,7 +95,7 @@ func (rw *RWMutex) RUnlock() {
// A writer is pending.
if atomic.AddInt32(&rw.readerWait, -1) == 0 {
// The last reader unblocks the writer.
- runtimeSemrelease(&rw.writerSem, false, 0)
+ semrelease(&rw.writerSem, false, 0)
}
}
if RaceEnabled {
@@ -104,7 +105,7 @@ func (rw *RWMutex) RUnlock() {
// TryLock locks rw for writing. It returns true if it succeeds and false
// otherwise. It does not block.
-func (rw *RWMutex) TryLock() bool {
+func (rw *CrossGoroutineRWMutex) TryLock() bool {
if RaceEnabled {
RaceDisable()
}
@@ -130,8 +131,9 @@ func (rw *RWMutex) TryLock() bool {
return true
}
-// Lock locks rw for writing.
-func (rw *RWMutex) Lock() {
+// Lock locks rw for writing. If the lock is already locked for reading or
+// writing, Lock blocks until the lock is available.
+func (rw *CrossGoroutineRWMutex) Lock() {
if RaceEnabled {
RaceDisable()
}
@@ -141,7 +143,7 @@ func (rw *RWMutex) Lock() {
r := atomic.AddInt32(&rw.readerCount, -rwmutexMaxReaders) + rwmutexMaxReaders
// Wait for active readers.
if r != 0 && atomic.AddInt32(&rw.readerWait, r) != 0 {
- runtimeSemacquire(&rw.writerSem)
+ semacquire(&rw.writerSem)
}
if RaceEnabled {
RaceEnable()
@@ -150,7 +152,10 @@ func (rw *RWMutex) Lock() {
}
// Unlock unlocks rw for writing.
-func (rw *RWMutex) Unlock() {
+//
+// Preconditions:
+// * rw is locked for writing.
+func (rw *CrossGoroutineRWMutex) Unlock() {
if RaceEnabled {
RaceRelease(unsafe.Pointer(&rw.writerSem))
RaceRelease(unsafe.Pointer(&rw.readerSem))
@@ -163,7 +168,7 @@ func (rw *RWMutex) Unlock() {
}
// Unblock blocked readers, if any.
for i := 0; i < int(r); i++ {
- runtimeSemrelease(&rw.readerSem, false, 0)
+ semrelease(&rw.readerSem, false, 0)
}
// Allow other writers to proceed.
rw.w.Unlock()
@@ -173,7 +178,10 @@ func (rw *RWMutex) Unlock() {
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
-func (rw *RWMutex) DowngradeLock() {
+//
+// Preconditions:
+// * rw is locked for writing.
+func (rw *CrossGoroutineRWMutex) DowngradeLock() {
if RaceEnabled {
RaceRelease(unsafe.Pointer(&rw.readerSem))
RaceDisable()
@@ -186,7 +194,7 @@ func (rw *RWMutex) DowngradeLock() {
// Unblock blocked readers, if any. Note that this loop starts as 1 since r
// includes this goroutine.
for i := 1; i < int(r); i++ {
- runtimeSemrelease(&rw.readerSem, false, 0)
+ semrelease(&rw.readerSem, false, 0)
}
// Allow other writers to proceed to rw.w.Lock(). Note that they will still
// block on rw.writerSem since at least this reader exists, such that
@@ -196,3 +204,91 @@ func (rw *RWMutex) DowngradeLock() {
RaceEnable()
}
}
+
+// A RWMutex is a reader/writer mutual exclusion lock. The lock can be held by
+// an arbitrary number of readers or a single writer. The zero value for a
+// RWMutex is an unlocked mutex.
+//
+// A RWMutex must not be copied after first use.
+//
+// If a goroutine holds a RWMutex for reading and another goroutine might call
+// Lock, no goroutine should expect to be able to acquire a read lock until the
+// initial read lock is released. In particular, this prohibits recursive read
+// locking. This is to ensure that the lock eventually becomes available; a
+// blocked Lock call excludes new readers from acquiring the lock.
+//
+// A Mutex must be unlocked by the same goroutine that locked it. This
+// invariant is enforced with the 'checklocks' build tag.
+type RWMutex struct {
+ m CrossGoroutineRWMutex
+}
+
+// TryRLock locks rw for reading. It returns true if it succeeds and false
+// otherwise. It does not block.
+func (rw *RWMutex) TryRLock() bool {
+ // Note lock first to enforce proper locking even if unsuccessful.
+ noteLock(unsafe.Pointer(rw))
+ locked := rw.m.TryRLock()
+ if !locked {
+ noteUnlock(unsafe.Pointer(rw))
+ }
+ return locked
+}
+
+// RLock locks rw for reading.
+//
+// It should not be used for recursive read locking; a blocked Lock call
+// excludes new readers from acquiring the lock. See the documentation on the
+// RWMutex type.
+func (rw *RWMutex) RLock() {
+ noteLock(unsafe.Pointer(rw))
+ rw.m.RLock()
+}
+
+// RUnlock undoes a single RLock call.
+//
+// Preconditions:
+// * rw is locked for reading.
+// * rw was locked by this goroutine.
+func (rw *RWMutex) RUnlock() {
+ rw.m.RUnlock()
+ noteUnlock(unsafe.Pointer(rw))
+}
+
+// TryLock locks rw for writing. It returns true if it succeeds and false
+// otherwise. It does not block.
+func (rw *RWMutex) TryLock() bool {
+ // Note lock first to enforce proper locking even if unsuccessful.
+ noteLock(unsafe.Pointer(rw))
+ locked := rw.m.TryLock()
+ if !locked {
+ noteUnlock(unsafe.Pointer(rw))
+ }
+ return locked
+}
+
+// Lock locks rw for writing. If the lock is already locked for reading or
+// writing, Lock blocks until the lock is available.
+func (rw *RWMutex) Lock() {
+ noteLock(unsafe.Pointer(rw))
+ rw.m.Lock()
+}
+
+// Unlock unlocks rw for writing.
+//
+// Preconditions:
+// * rw is locked for writing.
+// * rw was locked by this goroutine.
+func (rw *RWMutex) Unlock() {
+ rw.m.Unlock()
+ noteUnlock(unsafe.Pointer(rw))
+}
+
+// DowngradeLock atomically unlocks rw for writing and locks it for reading.
+//
+// Preconditions:
+// * rw is locked for writing.
+func (rw *RWMutex) DowngradeLock() {
+ // No note change for DowngradeLock.
+ rw.m.DowngradeLock()
+}
diff --git a/pkg/sync/seqcount.go b/pkg/sync/seqcount.go
index 2c5d3df99..1f025f33c 100644
--- a/pkg/sync/seqcount.go
+++ b/pkg/sync/seqcount.go
@@ -6,8 +6,6 @@
package sync
import (
- "fmt"
- "reflect"
"sync/atomic"
)
@@ -27,9 +25,6 @@ import (
// - SeqCount may be more flexible: correct use of SeqCount.ReadOk allows other
// operations to be made atomic with reads of SeqCount-protected data.
//
-// - SeqCount may be less flexible: as of this writing, SeqCount-protected data
-// cannot include pointers.
-//
// - SeqCount is more cumbersome to use; atomic reads of SeqCount-protected
// data require instantiating function templates using go_generics (see
// seqatomic.go).
@@ -128,32 +123,3 @@ func (s *SeqCount) EndWrite() {
panic("SeqCount.EndWrite outside writer critical section")
}
}
-
-// PointersInType returns a list of pointers reachable from values named
-// valName of the given type.
-//
-// PointersInType is not exhaustive, but it is guaranteed that if typ contains
-// at least one pointer, then PointersInTypeOf returns a non-empty list.
-func PointersInType(typ reflect.Type, valName string) []string {
- switch kind := typ.Kind(); kind {
- case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
- return nil
-
- case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.String, reflect.UnsafePointer:
- return []string{valName}
-
- case reflect.Array:
- return PointersInType(typ.Elem(), valName+"[]")
-
- case reflect.Struct:
- var ptrs []string
- for i, n := 0, typ.NumField(); i < n; i++ {
- field := typ.Field(i)
- ptrs = append(ptrs, PointersInType(field.Type, fmt.Sprintf("%s.%s", valName, field.Name))...)
- }
- return ptrs
-
- default:
- return []string{fmt.Sprintf("%s (of type %s with unknown kind %s)", valName, typ, kind)}
- }
-}
diff --git a/pkg/sync/seqcount_test.go b/pkg/sync/seqcount_test.go
index 6eb7b4b59..3f5592e3e 100644
--- a/pkg/sync/seqcount_test.go
+++ b/pkg/sync/seqcount_test.go
@@ -6,7 +6,6 @@
package sync
import (
- "reflect"
"testing"
"time"
)
@@ -99,55 +98,3 @@ func BenchmarkSeqCountReadUncontended(b *testing.B) {
}
})
}
-
-func TestPointersInType(t *testing.T) {
- for _, test := range []struct {
- name string // used for both test and value name
- val interface{}
- ptrs []string
- }{
- {
- name: "EmptyStruct",
- val: struct{}{},
- },
- {
- name: "Int",
- val: int(0),
- },
- {
- name: "MixedStruct",
- val: struct {
- b bool
- I int
- ExportedPtr *struct{}
- unexportedPtr *struct{}
- arr [2]int
- ptrArr [2]*int
- nestedStruct struct {
- nestedNonptr int
- nestedPtr *int
- }
- structArr [1]struct {
- nonptr int
- ptr *int
- }
- }{},
- ptrs: []string{
- "MixedStruct.ExportedPtr",
- "MixedStruct.unexportedPtr",
- "MixedStruct.ptrArr[]",
- "MixedStruct.nestedStruct.nestedPtr",
- "MixedStruct.structArr[].ptr",
- },
- },
- } {
- t.Run(test.name, func(t *testing.T) {
- typ := reflect.TypeOf(test.val)
- ptrs := PointersInType(typ, test.name)
- t.Logf("Found pointers: %v", ptrs)
- if (len(ptrs) != 0 || len(test.ptrs) != 0) && !reflect.DeepEqual(ptrs, test.ptrs) {
- t.Errorf("Got %v, wanted %v", ptrs, test.ptrs)
- }
- })
- }
-}
diff --git a/pkg/syncevent/BUILD b/pkg/syncevent/BUILD
index 0500a22cf..42c553308 100644
--- a/pkg/syncevent/BUILD
+++ b/pkg/syncevent/BUILD
@@ -9,10 +9,6 @@ go_library(
"receiver.go",
"source.go",
"syncevent.go",
- "waiter_amd64.s",
- "waiter_arm64.s",
- "waiter_asm_unsafe.go",
- "waiter_noasm_unsafe.go",
"waiter_unsafe.go",
],
visibility = ["//:sandbox"],
diff --git a/pkg/syncevent/waiter_noasm_unsafe.go b/pkg/syncevent/waiter_noasm_unsafe.go
deleted file mode 100644
index 0f74a689c..000000000
--- a/pkg/syncevent/waiter_noasm_unsafe.go
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// waiterUnlock is called from g0, so when the race detector is enabled,
-// waiterUnlock must be implemented in assembly since no race context is
-// available.
-//
-// +build !race
-// +build !amd64,!arm64
-
-package syncevent
-
-import (
- "sync/atomic"
- "unsafe"
-)
-
-// waiterUnlock is the "unlock function" passed to runtime.gopark by
-// Waiter.Wait*. wg is &Waiter.g, and g is a pointer to the calling runtime.g.
-// waiterUnlock returns true if Waiter.Wait should sleep and false if sleeping
-// should be aborted.
-//
-//go:nosplit
-func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool {
- // The only way this CAS can fail is if a call to Waiter.NotifyPending()
- // has replaced *wg with nil, in which case we should not sleep.
- return atomic.CompareAndSwapPointer(wg, (unsafe.Pointer)(&preparingG), ptr)
-}
diff --git a/pkg/syncevent/waiter_unsafe.go b/pkg/syncevent/waiter_unsafe.go
index 518f18479..b6ed2852d 100644
--- a/pkg/syncevent/waiter_unsafe.go
+++ b/pkg/syncevent/waiter_unsafe.go
@@ -12,11 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build go1.11
-// +build !go1.17
-
-// Check go:linkname function signatures when updating Go version.
-
package syncevent
import (
@@ -26,17 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/sync"
)
-//go:linkname gopark runtime.gopark
-func gopark(unlockf func(unsafe.Pointer, *unsafe.Pointer) bool, wg *unsafe.Pointer, reason uint8, traceEv byte, traceskip int)
-
-//go:linkname goready runtime.goready
-func goready(g unsafe.Pointer, traceskip int)
-
-const (
- waitReasonSelect = 9 // Go: src/runtime/runtime2.go
- traceEvGoBlockSelect = 24 // Go: src/runtime/trace.go
-)
-
// Waiter allows a goroutine to block on pending events received by a Receiver.
//
// Waiter.Init() must be called before first use.
@@ -45,20 +29,19 @@ type Waiter struct {
// g is one of:
//
- // - nil: No goroutine is blocking in Wait.
+ // - 0: No goroutine is blocking in Wait.
//
- // - &preparingG: A goroutine is in Wait preparing to sleep, but hasn't yet
+ // - preparingG: A goroutine is in Wait preparing to sleep, but hasn't yet
// completed waiterUnlock(). Thus the wait can only be interrupted by
- // replacing the value of g with nil (the G may not be in state Gwaiting
- // yet, so we can't call goready.)
+ // replacing the value of g with 0 (the G may not be in state Gwaiting yet,
+ // so we can't call goready.)
//
// - Otherwise: g is a pointer to the runtime.g in state Gwaiting for the
// goroutine blocked in Wait, which can only be woken by calling goready.
- g unsafe.Pointer `state:"zerovalue"`
+ g uintptr `state:"zerovalue"`
}
-// Sentinel object for Waiter.g.
-var preparingG struct{}
+const preparingG = 1
// Init must be called before first use of w.
func (w *Waiter) Init() {
@@ -99,21 +82,29 @@ func (w *Waiter) WaitFor(es Set) Set {
}
// Indicate that we're preparing to go to sleep.
- atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG))
+ atomic.StoreUintptr(&w.g, preparingG)
// If an event is pending, abort the sleep.
if p := w.r.Pending(); p&es != NoEvents {
- atomic.StorePointer(&w.g, nil)
+ atomic.StoreUintptr(&w.g, 0)
return p
}
// If w.g is still preparingG (i.e. w.NotifyPending() has not been
- // called or has not reached atomic.SwapPointer()), go to sleep until
+ // called or has not reached atomic.SwapUintptr()), go to sleep until
// w.NotifyPending() => goready().
- gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0)
+ sync.Gopark(waiterCommit, unsafe.Pointer(&w.g), sync.WaitReasonSelect, sync.TraceEvGoBlockSelect, 0)
}
}
+//go:norace
+//go:nosplit
+func waiterCommit(g uintptr, wg unsafe.Pointer) bool {
+ // The only way this CAS can fail is if a call to Waiter.NotifyPending()
+ // has replaced *wg with nil, in which case we should not sleep.
+ return sync.RaceUncheckedAtomicCompareAndSwapUintptr((*uintptr)(wg), preparingG, g)
+}
+
// Ack marks the given events as not pending.
func (w *Waiter) Ack(es Set) {
w.r.Ack(es)
@@ -135,20 +126,20 @@ func (w *Waiter) WaitAndAckAll() Set {
for {
// Indicate that we're preparing to go to sleep.
- atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG))
+ atomic.StoreUintptr(&w.g, preparingG)
// If an event is pending, abort the sleep.
if w.r.Pending() != NoEvents {
if p := w.r.PendingAndAckAll(); p != NoEvents {
- atomic.StorePointer(&w.g, nil)
+ atomic.StoreUintptr(&w.g, 0)
return p
}
}
// If w.g is still preparingG (i.e. w.NotifyPending() has not been
- // called or has not reached atomic.SwapPointer()), go to sleep until
+ // called or has not reached atomic.SwapUintptr()), go to sleep until
// w.NotifyPending() => goready().
- gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0)
+ sync.Gopark(waiterCommit, unsafe.Pointer(&w.g), sync.WaitReasonSelect, sync.TraceEvGoBlockSelect, 0)
// Check for pending events. We call PendingAndAckAll() directly now since
// we only expect to be woken after events become pending.
@@ -171,14 +162,14 @@ func (w *Waiter) NotifyPending() {
// goroutine. NotifyPending is called after w.r.Pending() is updated, so
// concurrent and future calls to w.Wait() will observe pending events and
// abort sleeping.
- if atomic.LoadPointer(&w.g) == nil {
+ if atomic.LoadUintptr(&w.g) == 0 {
return
}
// Wake a sleeping G, or prevent a G that is preparing to sleep from doing
// so. Swap is needed here to ensure that only one call to NotifyPending
// calls goready.
- if g := atomic.SwapPointer(&w.g, nil); g != nil && g != (unsafe.Pointer)(&preparingG) {
- goready(g, 0)
+ if g := atomic.SwapUintptr(&w.g, 0); g > preparingG {
+ sync.Goready(g, 0)
}
}
diff --git a/pkg/syserr/host_linux.go b/pkg/syserr/host_linux.go
index fc6ef60a1..77faa3670 100644
--- a/pkg/syserr/host_linux.go
+++ b/pkg/syserr/host_linux.go
@@ -32,7 +32,7 @@ var linuxHostTranslations [maxErrno]linuxHostTranslation
// FromHost translates a syscall.Errno to a corresponding Error value.
func FromHost(err syscall.Errno) *Error {
- if err < 0 || int(err) >= len(linuxHostTranslations) || !linuxHostTranslations[err].ok {
+ if int(err) >= len(linuxHostTranslations) || !linuxHostTranslations[err].ok {
panic(fmt.Sprintf("unknown host errno %q (%d)", err.Error(), err))
}
return linuxHostTranslations[err].err
diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go
index 5ae10939d..77c3c110c 100644
--- a/pkg/syserr/netstack.go
+++ b/pkg/syserr/netstack.go
@@ -15,6 +15,8 @@
package syserr
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -48,45 +50,56 @@ var (
ErrNotPermittedNet = New(tcpip.ErrNotPermitted.String(), linux.EPERM)
)
-var netstackErrorTranslations = map[*tcpip.Error]*Error{
- tcpip.ErrUnknownProtocol: ErrUnknownProtocol,
- tcpip.ErrUnknownNICID: ErrUnknownNICID,
- tcpip.ErrUnknownDevice: ErrUnknownDevice,
- tcpip.ErrUnknownProtocolOption: ErrUnknownProtocolOption,
- tcpip.ErrDuplicateNICID: ErrDuplicateNICID,
- tcpip.ErrDuplicateAddress: ErrDuplicateAddress,
- tcpip.ErrNoRoute: ErrNoRoute,
- tcpip.ErrBadLinkEndpoint: ErrBadLinkEndpoint,
- tcpip.ErrAlreadyBound: ErrAlreadyBound,
- tcpip.ErrInvalidEndpointState: ErrInvalidEndpointState,
- tcpip.ErrAlreadyConnecting: ErrAlreadyConnecting,
- tcpip.ErrAlreadyConnected: ErrAlreadyConnected,
- tcpip.ErrNoPortAvailable: ErrNoPortAvailable,
- tcpip.ErrPortInUse: ErrPortInUse,
- tcpip.ErrBadLocalAddress: ErrBadLocalAddress,
- tcpip.ErrClosedForSend: ErrClosedForSend,
- tcpip.ErrClosedForReceive: ErrClosedForReceive,
- tcpip.ErrWouldBlock: ErrWouldBlock,
- tcpip.ErrConnectionRefused: ErrConnectionRefused,
- tcpip.ErrTimeout: ErrTimeout,
- tcpip.ErrAborted: ErrAborted,
- tcpip.ErrConnectStarted: ErrConnectStarted,
- tcpip.ErrDestinationRequired: ErrDestinationRequired,
- tcpip.ErrNotSupported: ErrNotSupported,
- tcpip.ErrQueueSizeNotSupported: ErrQueueSizeNotSupported,
- tcpip.ErrNotConnected: ErrNotConnected,
- tcpip.ErrConnectionReset: ErrConnectionReset,
- tcpip.ErrConnectionAborted: ErrConnectionAborted,
- tcpip.ErrNoSuchFile: ErrNoSuchFile,
- tcpip.ErrInvalidOptionValue: ErrInvalidOptionValue,
- tcpip.ErrNoLinkAddress: ErrHostDown,
- tcpip.ErrBadAddress: ErrBadAddress,
- tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable,
- tcpip.ErrMessageTooLong: ErrMessageTooLong,
- tcpip.ErrNoBufferSpace: ErrNoBufferSpace,
- tcpip.ErrBroadcastDisabled: ErrBroadcastDisabled,
- tcpip.ErrNotPermitted: ErrNotPermittedNet,
- tcpip.ErrAddressFamilyNotSupported: ErrAddressFamilyNotSupported,
+var netstackErrorTranslations map[string]*Error
+
+func addErrMapping(tcpipErr *tcpip.Error, netstackErr *Error) {
+ key := tcpipErr.String()
+ if _, ok := netstackErrorTranslations[key]; ok {
+ panic(fmt.Sprintf("duplicate error key: %s", key))
+ }
+ netstackErrorTranslations[key] = netstackErr
+}
+
+func init() {
+ netstackErrorTranslations = make(map[string]*Error)
+ addErrMapping(tcpip.ErrUnknownProtocol, ErrUnknownProtocol)
+ addErrMapping(tcpip.ErrUnknownNICID, ErrUnknownNICID)
+ addErrMapping(tcpip.ErrUnknownDevice, ErrUnknownDevice)
+ addErrMapping(tcpip.ErrUnknownProtocolOption, ErrUnknownProtocolOption)
+ addErrMapping(tcpip.ErrDuplicateNICID, ErrDuplicateNICID)
+ addErrMapping(tcpip.ErrDuplicateAddress, ErrDuplicateAddress)
+ addErrMapping(tcpip.ErrNoRoute, ErrNoRoute)
+ addErrMapping(tcpip.ErrBadLinkEndpoint, ErrBadLinkEndpoint)
+ addErrMapping(tcpip.ErrAlreadyBound, ErrAlreadyBound)
+ addErrMapping(tcpip.ErrInvalidEndpointState, ErrInvalidEndpointState)
+ addErrMapping(tcpip.ErrAlreadyConnecting, ErrAlreadyConnecting)
+ addErrMapping(tcpip.ErrAlreadyConnected, ErrAlreadyConnected)
+ addErrMapping(tcpip.ErrNoPortAvailable, ErrNoPortAvailable)
+ addErrMapping(tcpip.ErrPortInUse, ErrPortInUse)
+ addErrMapping(tcpip.ErrBadLocalAddress, ErrBadLocalAddress)
+ addErrMapping(tcpip.ErrClosedForSend, ErrClosedForSend)
+ addErrMapping(tcpip.ErrClosedForReceive, ErrClosedForReceive)
+ addErrMapping(tcpip.ErrWouldBlock, ErrWouldBlock)
+ addErrMapping(tcpip.ErrConnectionRefused, ErrConnectionRefused)
+ addErrMapping(tcpip.ErrTimeout, ErrTimeout)
+ addErrMapping(tcpip.ErrAborted, ErrAborted)
+ addErrMapping(tcpip.ErrConnectStarted, ErrConnectStarted)
+ addErrMapping(tcpip.ErrDestinationRequired, ErrDestinationRequired)
+ addErrMapping(tcpip.ErrNotSupported, ErrNotSupported)
+ addErrMapping(tcpip.ErrQueueSizeNotSupported, ErrQueueSizeNotSupported)
+ addErrMapping(tcpip.ErrNotConnected, ErrNotConnected)
+ addErrMapping(tcpip.ErrConnectionReset, ErrConnectionReset)
+ addErrMapping(tcpip.ErrConnectionAborted, ErrConnectionAborted)
+ addErrMapping(tcpip.ErrNoSuchFile, ErrNoSuchFile)
+ addErrMapping(tcpip.ErrInvalidOptionValue, ErrInvalidOptionValue)
+ addErrMapping(tcpip.ErrNoLinkAddress, ErrHostDown)
+ addErrMapping(tcpip.ErrBadAddress, ErrBadAddress)
+ addErrMapping(tcpip.ErrNetworkUnreachable, ErrNetworkUnreachable)
+ addErrMapping(tcpip.ErrMessageTooLong, ErrMessageTooLong)
+ addErrMapping(tcpip.ErrNoBufferSpace, ErrNoBufferSpace)
+ addErrMapping(tcpip.ErrBroadcastDisabled, ErrBroadcastDisabled)
+ addErrMapping(tcpip.ErrNotPermitted, ErrNotPermittedNet)
+ addErrMapping(tcpip.ErrAddressFamilyNotSupported, ErrAddressFamilyNotSupported)
}
// TranslateNetstackError converts an error from the tcpip package to a sentry
@@ -95,7 +108,7 @@ func TranslateNetstackError(err *tcpip.Error) *Error {
if err == nil {
return nil
}
- se, ok := netstackErrorTranslations[err]
+ se, ok := netstackErrorTranslations[err.String()]
if !ok {
panic("Unknown error: " + err.String())
}
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 81f762e10..91971b687 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -20,6 +20,7 @@ import (
"encoding/binary"
"reflect"
"testing"
+ "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -116,6 +117,10 @@ func TTL(ttl uint8) NetworkChecker {
v = ip.TTL()
case header.IPv6:
v = ip.HopLimit()
+ case *ipv6HeaderWithExtHdr:
+ v = ip.HopLimit()
+ default:
+ t.Fatalf("unrecognized header type %T for TTL evaluation", ip)
}
if v != ttl {
t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl)
@@ -216,6 +221,42 @@ func IPv4Options(want header.IPv4Options) NetworkChecker {
}
}
+// IPv4RouterAlert returns a checker that checks that the RouterAlert option is
+// set in an IPv4 packet.
+func IPv4RouterAlert() NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+ ip, ok := h[0].(header.IPv4)
+ if !ok {
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
+ }
+ iterator := ip.Options().MakeIterator()
+ for {
+ opt, done, err := iterator.Next()
+ if err != nil {
+ t.Fatalf("error acquiring next IPv4 option %s", err)
+ }
+ if done {
+ break
+ }
+ if opt.Type() != header.IPv4OptionRouterAlertType {
+ continue
+ }
+ want := [header.IPv4OptionRouterAlertLength]byte{
+ byte(header.IPv4OptionRouterAlertType),
+ header.IPv4OptionRouterAlertLength,
+ header.IPv4OptionRouterAlertValue,
+ header.IPv4OptionRouterAlertValue,
+ }
+ if diff := cmp.Diff(want[:], opt.Contents()); diff != "" {
+ t.Errorf("router alert option mismatch (-want +got):\n%s", diff)
+ }
+ return
+ }
+ t.Errorf("failed to find router alert option in %v", ip.Options())
+ }
+}
+
// FragmentOffset creates a checker that checks the FragmentOffset field.
func FragmentOffset(offset uint16) NetworkChecker {
return func(t *testing.T, h []header.Network) {
@@ -284,6 +325,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
}
}
+// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress
+// field in ControlMessages.
+func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasOriginalDstAddress {
+ t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress)
+ } else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" {
+ t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff)
+ }
+ }
+}
+
// TOS creates a checker that checks the TOS field.
func TOS(tos uint8, label uint32) NetworkChecker {
return func(t *testing.T, h []header.Network) {
@@ -1012,6 +1066,74 @@ func ICMPv6Payload(want []byte) TransportChecker {
}
}
+// MLD creates a checker that checks that the packet contains a valid MLD
+// message for type of mldType, with potentially additional checks specified by
+// checkers.
+//
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// MLD message as far as the size of the message (minSize) is concerned. The
+// values within the message are up to checkers to validate.
+func MLD(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ // Check normal ICMPv6 first.
+ ICMPv6(
+ ICMPv6Type(msgType),
+ ICMPv6Code(0))(t, h)
+
+ last := h[len(h)-1]
+
+ icmp := header.ICMPv6(last.Payload())
+ if got := len(icmp.MessageBody()); got < minSize {
+ t.Fatalf("ICMPv6 MLD (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
+ }
+
+ for _, f := range checkers {
+ f(t, icmp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// MLDMaxRespDelay creates a checker that checks the Maximum Response Delay
+// field of a MLD message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid MLD message as far as the size is concerned.
+func MLDMaxRespDelay(want time.Duration) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ ns := header.MLD(icmp.MessageBody())
+
+ if got := ns.MaximumResponseDelay(); got != want {
+ t.Errorf("got %T.MaximumResponseDelay() = %s, want = %s", ns, got, want)
+ }
+ }
+}
+
+// MLDMulticastAddress creates a checker that checks the Multicast Address
+// field of a MLD message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid MLD message as far as the size is concerned.
+func MLDMulticastAddress(want tcpip.Address) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ ns := header.MLD(icmp.MessageBody())
+
+ if got := ns.MulticastAddress(); got != want {
+ t.Errorf("got %T.MulticastAddress() = %s, want = %s", ns, got, want)
+ }
+ }
+}
+
// NDP creates a checker that checks that the packet contains a valid NDP
// message for type of ty, with potentially additional checks specified by
// checkers.
@@ -1031,7 +1153,7 @@ func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) N
last := h[len(h)-1]
icmp := header.ICMPv6(last.Payload())
- if got := len(icmp.NDPPayload()); got < minSize {
+ if got := len(icmp.MessageBody()); got < minSize {
t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
}
@@ -1065,7 +1187,7 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
if got := ns.TargetAddress(); got != want {
t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want)
@@ -1094,7 +1216,7 @@ func NDPNATargetAddress(want tcpip.Address) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
if got := na.TargetAddress(); got != want {
t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want)
@@ -1112,7 +1234,7 @@ func NDPNASolicitedFlag(want bool) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
if got := na.SolicitedFlag(); got != want {
t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want)
@@ -1183,7 +1305,7 @@ func NDPNAOptions(opts []header.NDPOption) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
ndpOptions(t, na.Options(), opts)
}
}
@@ -1198,7 +1320,7 @@ func NDPNSOptions(opts []header.NDPOption) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ndpOptions(t, ns.Options(), opts)
}
}
@@ -1223,7 +1345,261 @@ func NDPRSOptions(opts []header.NDPOption) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- rs := header.NDPRouterSolicit(icmp.NDPPayload())
+ rs := header.NDPRouterSolicit(icmp.MessageBody())
ndpOptions(t, rs.Options(), opts)
}
}
+
+// IGMP checks the validity and properties of the given IGMP packet. It is
+// expected to be used in conjunction with other IGMP transport checkers for
+// specific properties.
+func IGMP(checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ last := h[len(h)-1]
+
+ if p := last.TransportProtocol(); p != header.IGMPProtocolNumber {
+ t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber)
+ }
+
+ igmp := header.IGMP(last.Payload())
+ for _, f := range checkers {
+ f(t, igmp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// IGMPType creates a checker that checks the IGMP Type field.
+func IGMPType(want header.IGMPType) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ igmp, ok := h.(header.IGMP)
+ if !ok {
+ t.Fatalf("got transport header = %T, want = header.IGMP", h)
+ }
+ if got := igmp.Type(); got != want {
+ t.Errorf("got igmp.Type() = %d, want = %d", got, want)
+ }
+ }
+}
+
+// IGMPMaxRespTime creates a checker that checks the IGMP Max Resp Time field.
+func IGMPMaxRespTime(want time.Duration) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ igmp, ok := h.(header.IGMP)
+ if !ok {
+ t.Fatalf("got transport header = %T, want = header.IGMP", h)
+ }
+ if got := igmp.MaxRespTime(); got != want {
+ t.Errorf("got igmp.MaxRespTime() = %s, want = %s", got, want)
+ }
+ }
+}
+
+// IGMPGroupAddress creates a checker that checks the IGMP Group Address field.
+func IGMPGroupAddress(want tcpip.Address) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ igmp, ok := h.(header.IGMP)
+ if !ok {
+ t.Fatalf("got transport header = %T, want = header.IGMP", h)
+ }
+ if got := igmp.GroupAddress(); got != want {
+ t.Errorf("got igmp.GroupAddress() = %s, want = %s", got, want)
+ }
+ }
+}
+
+// IPv6ExtHdrChecker is a function to check an extension header.
+type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader)
+
+// IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers.
+func IPv6WithExtHdr(t *testing.T, b []byte, checkers ...NetworkChecker) {
+ t.Helper()
+
+ ipv6 := header.IPv6(b)
+ if !ipv6.IsValid(len(b)) {
+ t.Error("not a valid IPv6 packet")
+ return
+ }
+
+ payloadIterator := header.MakeIPv6PayloadIterator(
+ header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()),
+ buffer.View(ipv6.Payload()).ToVectorisedView(),
+ )
+
+ var rawPayloadHeader header.IPv6RawPayloadHeader
+ for {
+ h, done, err := payloadIterator.Next()
+ if err != nil {
+ t.Errorf("payloadIterator.Next(): %s", err)
+ return
+ }
+ if done {
+ t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done)
+ return
+ }
+ r, ok := h.(header.IPv6RawPayloadHeader)
+ if ok {
+ rawPayloadHeader = r
+ break
+ }
+ }
+
+ networkHeader := ipv6HeaderWithExtHdr{
+ IPv6: ipv6,
+ transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier),
+ payload: rawPayloadHeader.Buf.ToView(),
+ }
+
+ for _, checker := range checkers {
+ checker(t, []header.Network{&networkHeader})
+ }
+}
+
+// IPv6ExtHdr checks for the presence of extension headers.
+//
+// All the extension headers in headers will be checked exhaustively in the
+// order provided.
+func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr)
+ if !ok {
+ t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0])
+ return
+ }
+
+ payloadIterator := header.MakeIPv6PayloadIterator(
+ header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()),
+ buffer.View(extHdrs.IPv6.Payload()).ToVectorisedView(),
+ )
+
+ for _, check := range headers {
+ h, done, err := payloadIterator.Next()
+ if err != nil {
+ t.Errorf("payloadIterator.Next(): %s", err)
+ return
+ }
+ if done {
+ t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done)
+ return
+ }
+ check(t, h)
+ }
+ // Validate we consumed all headers.
+ //
+ // The next one over should be a raw payload and then iterator should
+ // terminate.
+ wantDone := false
+ for {
+ h, done, err := payloadIterator.Next()
+ if err != nil {
+ t.Errorf("payloadIterator.Next(): %s", err)
+ return
+ }
+ if done != wantDone {
+ t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone)
+ return
+ }
+ if done {
+ break
+ }
+ if _, ok := h.(header.IPv6RawPayloadHeader); !ok {
+ t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h)
+ continue
+ }
+ wantDone = true
+ }
+ }
+}
+
+var _ header.Network = (*ipv6HeaderWithExtHdr)(nil)
+
+// ipv6HeaderWithExtHdr provides a header.Network implementation that takes
+// extension headers into consideration, which is not the case with vanilla
+// header.IPv6.
+type ipv6HeaderWithExtHdr struct {
+ header.IPv6
+ transport tcpip.TransportProtocolNumber
+ payload []byte
+}
+
+// TransportProtocol implements header.Network.
+func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber {
+ return h.transport
+}
+
+// Payload implements header.Network.
+func (h *ipv6HeaderWithExtHdr) Payload() []byte {
+ return h.payload
+}
+
+// IPv6ExtHdrOptionChecker is a function to check an extension header option.
+type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption)
+
+// IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop
+// extension header and validates the containing options with checkers.
+//
+// checkers must exhaustively contain all the expected options.
+func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker {
+ return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) {
+ t.Helper()
+
+ hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr)
+ if !ok {
+ t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader)
+ return
+ }
+ optionsIterator := hbh.Iter()
+ for _, f := range checkers {
+ opt, done, err := optionsIterator.Next()
+ if err != nil {
+ t.Errorf("optionsIterator.Next(): %s", err)
+ return
+ }
+ if done {
+ t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done)
+ }
+ f(t, opt)
+ }
+ // Validate all options were consumed.
+ for {
+ opt, done, err := optionsIterator.Next()
+ if err != nil {
+ t.Errorf("optionsIterator.Next(): %s", err)
+ return
+ }
+ if !done {
+ t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done)
+ }
+ if done {
+ break
+ }
+ }
+ }
+}
+
+// IPv6RouterAlert validates that an extension header option is the RouterAlert
+// option and matches on its value.
+func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker {
+ return func(t *testing.T, opt header.IPv6ExtHdrOption) {
+ routerAlert, ok := opt.(*header.IPv6RouterAlertOption)
+ if !ok {
+ t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt)
+ return
+ }
+ if routerAlert.Value != want {
+ t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index d87797617..0bdc12d53 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -11,11 +11,13 @@ go_library(
"gue.go",
"icmpv4.go",
"icmpv6.go",
+ "igmp.go",
"interfaces.go",
"ipv4.go",
"ipv6.go",
"ipv6_extension_headers.go",
"ipv6_fragment.go",
+ "mld.go",
"ndp_neighbor_advert.go",
"ndp_neighbor_solicit.go",
"ndp_options.go",
@@ -39,6 +41,8 @@ go_test(
size = "small",
srcs = [
"checksum_test.go",
+ "igmp_test.go",
+ "ipv4_test.go",
"ipv6_test.go",
"ipversion_test.go",
"tcp_test.go",
@@ -58,6 +62,7 @@ go_test(
srcs = [
"eth_test.go",
"ipv6_extension_headers_test.go",
+ "mld_test.go",
"ndp_test.go",
],
library = ":header",
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index 4303fc5d5..2eef64b4d 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -115,6 +115,12 @@ const (
ICMPv6NeighborSolicit ICMPv6Type = 135
ICMPv6NeighborAdvert ICMPv6Type = 136
ICMPv6RedirectMsg ICMPv6Type = 137
+
+ // Multicast Listener Discovery (MLD) messages, see RFC 2710.
+
+ ICMPv6MulticastListenerQuery ICMPv6Type = 130
+ ICMPv6MulticastListenerReport ICMPv6Type = 131
+ ICMPv6MulticastListenerDone ICMPv6Type = 132
)
// IsErrorType returns true if the receiver is an ICMP error type.
@@ -245,10 +251,9 @@ func (b ICMPv6) SetSequence(sequence uint16) {
binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence)
}
-// NDPPayload returns the NDP payload buffer. That is, it returns the ICMPv6
-// packet's message body as defined by RFC 4443 section 2.1; the portion of the
-// ICMPv6 buffer after the first ICMPv6HeaderSize bytes.
-func (b ICMPv6) NDPPayload() []byte {
+// MessageBody returns the message body as defined by RFC 4443 section 2.1; the
+// portion of the ICMPv6 buffer after the first ICMPv6HeaderSize bytes.
+func (b ICMPv6) MessageBody() []byte {
return b[ICMPv6HeaderSize:]
}
diff --git a/pkg/tcpip/header/igmp.go b/pkg/tcpip/header/igmp.go
new file mode 100644
index 000000000..5c5be1b9d
--- /dev/null
+++ b/pkg/tcpip/header/igmp.go
@@ -0,0 +1,181 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// IGMP represents an IGMP header stored in a byte array.
+type IGMP []byte
+
+// IGMP implements `Transport`.
+var _ Transport = (*IGMP)(nil)
+
+const (
+ // IGMPMinimumSize is the minimum size of a valid IGMP packet in bytes,
+ // as per RFC 2236, Section 2, Page 2.
+ IGMPMinimumSize = 8
+
+ // IGMPQueryMinimumSize is the minimum size of a valid Membership Query
+ // Message in bytes, as per RFC 2236, Section 2, Page 2.
+ IGMPQueryMinimumSize = 8
+
+ // IGMPReportMinimumSize is the minimum size of a valid Report Message in
+ // bytes, as per RFC 2236, Section 2, Page 2.
+ IGMPReportMinimumSize = 8
+
+ // IGMPLeaveMessageMinimumSize is the minimum size of a valid Leave Message
+ // in bytes, as per RFC 2236, Section 2, Page 2.
+ IGMPLeaveMessageMinimumSize = 8
+
+ // IGMPTTL is the TTL for all IGMP messages, as per RFC 2236, Section 3, Page
+ // 3.
+ IGMPTTL = 1
+
+ // igmpTypeOffset defines the offset of the type field in an IGMP message.
+ igmpTypeOffset = 0
+
+ // igmpMaxRespTimeOffset defines the offset of the MaxRespTime field in an
+ // IGMP message.
+ igmpMaxRespTimeOffset = 1
+
+ // igmpChecksumOffset defines the offset of the checksum field in an IGMP
+ // message.
+ igmpChecksumOffset = 2
+
+ // igmpGroupAddressOffset defines the offset of the Group Address field in an
+ // IGMP message.
+ igmpGroupAddressOffset = 4
+
+ // IGMPProtocolNumber is IGMP's transport protocol number.
+ IGMPProtocolNumber tcpip.TransportProtocolNumber = 2
+)
+
+// IGMPType is the IGMP type field as per RFC 2236.
+type IGMPType byte
+
+// Values for the IGMP Type described in RFC 2236 Section 2.1, Page 2.
+// Descriptions below come from there.
+const (
+ // IGMPMembershipQuery indicates that the message type is Membership Query.
+ // "There are two sub-types of Membership Query messages:
+ // - General Query, used to learn which groups have members on an
+ // attached network.
+ // - Group-Specific Query, used to learn if a particular group
+ // has any members on an attached network.
+ // These two messages are differentiated by the Group Address, as
+ // described in section 1.4 ."
+ IGMPMembershipQuery IGMPType = 0x11
+ // IGMPv1MembershipReport indicates that the message is a Membership Report
+ // generated by a host using the IGMPv1 protocol: "an additional type of
+ // message, for backwards-compatibility with IGMPv1"
+ IGMPv1MembershipReport IGMPType = 0x12
+ // IGMPv2MembershipReport indicates that the Message type is a Membership
+ // Report generated by a host using the IGMPv2 protocol.
+ IGMPv2MembershipReport IGMPType = 0x16
+ // IGMPLeaveGroup indicates that the message type is a Leave Group
+ // notification message.
+ IGMPLeaveGroup IGMPType = 0x17
+)
+
+// Type is the IGMP type field.
+func (b IGMP) Type() IGMPType { return IGMPType(b[igmpTypeOffset]) }
+
+// SetType sets the IGMP type field.
+func (b IGMP) SetType(t IGMPType) { b[igmpTypeOffset] = byte(t) }
+
+// MaxRespTime gets the MaxRespTimeField. This is meaningful only in Membership
+// Query messages, in other cases it is set to 0 by the sender and ignored by
+// the receiver.
+func (b IGMP) MaxRespTime() time.Duration {
+ // As per RFC 2236 section 2.2,
+ //
+ // The Max Response Time field is meaningful only in Membership Query
+ // messages, and specifies the maximum allowed time before sending a
+ // responding report in units of 1/10 second. In all other messages, it
+ // is set to zero by the sender and ignored by receivers.
+ return DecisecondToDuration(b[igmpMaxRespTimeOffset])
+}
+
+// SetMaxRespTime sets the MaxRespTimeField.
+func (b IGMP) SetMaxRespTime(m byte) { b[igmpMaxRespTimeOffset] = m }
+
+// Checksum is the IGMP checksum field.
+func (b IGMP) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[igmpChecksumOffset:])
+}
+
+// SetChecksum sets the IGMP checksum field.
+func (b IGMP) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[igmpChecksumOffset:], checksum)
+}
+
+// GroupAddress gets the Group Address field.
+func (b IGMP) GroupAddress() tcpip.Address {
+ return tcpip.Address(b[igmpGroupAddressOffset:][:IPv4AddressSize])
+}
+
+// SetGroupAddress sets the Group Address field.
+func (b IGMP) SetGroupAddress(address tcpip.Address) {
+ if n := copy(b[igmpGroupAddressOffset:], address); n != IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d", n, IPv4AddressSize))
+ }
+}
+
+// SourcePort implements Transport.SourcePort.
+func (IGMP) SourcePort() uint16 {
+ return 0
+}
+
+// DestinationPort implements Transport.DestinationPort.
+func (IGMP) DestinationPort() uint16 {
+ return 0
+}
+
+// SetSourcePort implements Transport.SetSourcePort.
+func (IGMP) SetSourcePort(uint16) {
+}
+
+// SetDestinationPort implements Transport.SetDestinationPort.
+func (IGMP) SetDestinationPort(uint16) {
+}
+
+// Payload implements Transport.Payload.
+func (IGMP) Payload() []byte {
+ return nil
+}
+
+// IGMPCalculateChecksum calculates the IGMP checksum over the provided IGMP
+// header.
+func IGMPCalculateChecksum(h IGMP) uint16 {
+ // The header contains a checksum itself, set it aside to avoid checksumming
+ // the checksum and replace it afterwards.
+ existingXsum := h.Checksum()
+ h.SetChecksum(0)
+ xsum := ^Checksum(h, 0)
+ h.SetChecksum(existingXsum)
+ return xsum
+}
+
+// DecisecondToDuration converts a value representing deci-seconds to a
+// time.Duration.
+func DecisecondToDuration(ds uint8) time.Duration {
+ return time.Duration(ds) * time.Second / 10
+}
diff --git a/pkg/tcpip/header/igmp_test.go b/pkg/tcpip/header/igmp_test.go
new file mode 100644
index 000000000..b6126d29a
--- /dev/null
+++ b/pkg/tcpip/header/igmp_test.go
@@ -0,0 +1,110 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// TestIGMPHeader tests the functions within header.igmp
+func TestIGMPHeader(t *testing.T) {
+ const maxRespTimeTenthSec = 0xF0
+ b := []byte{
+ 0x11, // IGMP Type, Membership Query
+ maxRespTimeTenthSec, // Maximum Response Time
+ 0xC0, 0xC0, // Checksum
+ 0x01, 0x02, 0x03, 0x04, // Group Address
+ }
+
+ igmpHeader := header.IGMP(b)
+
+ if got, want := igmpHeader.Type(), header.IGMPMembershipQuery; got != want {
+ t.Errorf("got igmpHeader.Type() = %x, want = %x", got, want)
+ }
+
+ if got, want := igmpHeader.MaxRespTime(), header.DecisecondToDuration(maxRespTimeTenthSec); got != want {
+ t.Errorf("got igmpHeader.MaxRespTime() = %s, want = %s", got, want)
+ }
+
+ if got, want := igmpHeader.Checksum(), uint16(0xC0C0); got != want {
+ t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, want)
+ }
+
+ if got, want := igmpHeader.GroupAddress(), tcpip.Address("\x01\x02\x03\x04"); got != want {
+ t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, want)
+ }
+
+ igmpType := header.IGMPv2MembershipReport
+ igmpHeader.SetType(igmpType)
+ if got := igmpHeader.Type(); got != igmpType {
+ t.Errorf("got igmpHeader.Type() = %x, want = %x", got, igmpType)
+ }
+ if got := header.IGMPType(b[0]); got != igmpType {
+ t.Errorf("got IGMPtype in backing buffer = %x, want %x", got, igmpType)
+ }
+
+ respTime := byte(0x02)
+ igmpHeader.SetMaxRespTime(respTime)
+ if got, want := igmpHeader.MaxRespTime(), header.DecisecondToDuration(respTime); got != want {
+ t.Errorf("got igmpHeader.MaxRespTime() = %s, want = %s", got, want)
+ }
+
+ checksum := uint16(0x0102)
+ igmpHeader.SetChecksum(checksum)
+ if got := igmpHeader.Checksum(); got != checksum {
+ t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, checksum)
+ }
+
+ groupAddress := tcpip.Address("\x04\x03\x02\x01")
+ igmpHeader.SetGroupAddress(groupAddress)
+ if got := igmpHeader.GroupAddress(); got != groupAddress {
+ t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, groupAddress)
+ }
+}
+
+// TestIGMPChecksum ensures that the checksum calculator produces the expected
+// checksum.
+func TestIGMPChecksum(t *testing.T) {
+ b := []byte{
+ 0x11, // IGMP Type, Membership Query
+ 0xF0, // Maximum Response Time
+ 0xC0, 0xC0, // Checksum
+ 0x01, 0x02, 0x03, 0x04, // Group Address
+ }
+
+ igmpHeader := header.IGMP(b)
+
+ // Calculate the initial checksum after setting the checksum temporarily to 0
+ // to avoid checksumming the checksum.
+ initialChecksum := igmpHeader.Checksum()
+ igmpHeader.SetChecksum(0)
+ checksum := ^header.Checksum(b, 0)
+ igmpHeader.SetChecksum(initialChecksum)
+
+ if got := header.IGMPCalculateChecksum(igmpHeader); got != checksum {
+ t.Errorf("got IGMPCalculateChecksum = %x, want %x", got, checksum)
+ }
+}
+
+func TestDecisecondToDuration(t *testing.T) {
+ const valueInDeciseconds = 5
+ if got, want := header.DecisecondToDuration(valueInDeciseconds), valueInDeciseconds*time.Second/10; got != want {
+ t.Fatalf("got header.DecisecondToDuration(%d) = %s, want = %s", valueInDeciseconds, got, want)
+ }
+}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 91fe7b6a5..e6103f4bc 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -100,7 +100,7 @@ type IPv4Fields struct {
//
// That leaves ten 32 bit (4 byte) fields for options. An attempt to encode
// more will fail.
- Options IPv4Options
+ Options IPv4OptionsSerializer
}
// IPv4 is an IPv4 header.
@@ -157,6 +157,9 @@ const (
// IPv4Any is the non-routable IPv4 "any" meta address.
IPv4Any tcpip.Address = "\x00\x00\x00\x00"
+ // IPv4AllRoutersGroup is a multicast address for all routers.
+ IPv4AllRoutersGroup tcpip.Address = "\xe0\x00\x00\x02"
+
// IPv4MinimumProcessableDatagramSize is the minimum size of an IP
// packet that every IPv4 capable host must be able to
// process/reassemble.
@@ -282,18 +285,17 @@ func (b IPv4) DestinationAddress() tcpip.Address {
return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
}
-// IPv4Options is a buffer that holds all the raw IP options.
-type IPv4Options []byte
-
-// SizeWithPadding implements stack.NetOptions.
-// It reports the size to allocate for the Options. RFC 791 page 23 (end of
-// section 3.1) says of the padding at the end of the options:
+// padIPv4OptionsLength returns the total length for IPv4 options of length l
+// after applying padding according to RFC 791:
// The internet header padding is used to ensure that the internet
// header ends on a 32 bit boundary.
-func (o IPv4Options) SizeWithPadding() int {
- return (len(o) + IPv4IHLStride - 1) & ^(IPv4IHLStride - 1)
+func padIPv4OptionsLength(length uint8) uint8 {
+ return (length + IPv4IHLStride - 1) & ^uint8(IPv4IHLStride-1)
}
+// IPv4Options is a buffer that holds all the raw IP options.
+type IPv4Options []byte
+
// Options returns a buffer holding the options.
func (b IPv4) Options() IPv4Options {
hdrLen := b.HeaderLength()
@@ -372,26 +374,16 @@ func (b IPv4) CalculateChecksum() uint16 {
func (b IPv4) Encode(i *IPv4Fields) {
// The size of the options defines the size of the whole header and thus the
// IHL field. Options are rare and this is a heavily used function so it is
- // worth a bit of optimisation here to keep the copy out of the fast path.
- hdrLen := IPv4MinimumSize
+ // worth a bit of optimisation here to keep the serializer out of the fast
+ // path.
+ hdrLen := uint8(IPv4MinimumSize)
if len(i.Options) != 0 {
- // SizeWithPadding is always >= len(i.Options).
- aLen := i.Options.SizeWithPadding()
- hdrLen += aLen
- if hdrLen > len(b) {
- panic(fmt.Sprintf("encode received %d bytes, wanted >= %d", len(b), hdrLen))
- }
- opts := b[options:]
- // This avoids bounds checks on the next line(s) which would happen even
- // if there's no work to do.
- if n := copy(opts, i.Options); n != aLen {
- padding := opts[n:][:aLen-n]
- for i := range padding {
- padding[i] = 0
- }
- }
+ hdrLen += i.Options.Serialize(b[options:])
+ }
+ if hdrLen > IPv4MaximumHeaderSize {
+ panic(fmt.Sprintf("%d is larger than maximum IPv4 header size of %d", hdrLen, IPv4MaximumHeaderSize))
}
- b.SetHeaderLength(uint8(hdrLen))
+ b.SetHeaderLength(hdrLen)
b[tos] = i.TOS
b.SetTotalLength(i.TotalLength)
binary.BigEndian.PutUint16(b[id:], i.ID)
@@ -471,6 +463,10 @@ const (
// options and may appear multiple times.
IPv4OptionNOPType IPv4OptionType = 1
+ // IPv4OptionRouterAlertType is the option type for the Router Alert option,
+ // defined in RFC 2113 Section 2.1.
+ IPv4OptionRouterAlertType IPv4OptionType = 20 | 0x80
+
// IPv4OptionRecordRouteType is used by each router on the path of the packet
// to record its path. It is carried over to an Echo Reply.
IPv4OptionRecordRouteType IPv4OptionType = 7
@@ -871,3 +867,162 @@ func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) }
// Contents implements IPv4Option.
func (rr *IPv4OptionRecordRoute) Contents() []byte { return []byte(*rr) }
+
+// Router Alert option specific related constants.
+//
+// from RFC 2113 section 2.1:
+//
+// +--------+--------+--------+--------+
+// |10010100|00000100| 2 octet value |
+// +--------+--------+--------+--------+
+//
+// Type:
+// Copied flag: 1 (all fragments must carry the option)
+// Option class: 0 (control)
+// Option number: 20 (decimal)
+//
+// Length: 4
+//
+// Value: A two octet code with the following values:
+// 0 - Router shall examine packet
+// 1-65535 - Reserved
+const (
+ // IPv4OptionRouterAlertLength is the length of a Router Alert option.
+ IPv4OptionRouterAlertLength = 4
+
+ // IPv4OptionRouterAlertValue is the only permissible value of the 16 bit
+ // payload of the router alert option.
+ IPv4OptionRouterAlertValue = 0
+
+ // iPv4OptionRouterAlertValueOffset is the offset for the value of a
+ // RouterAlert option.
+ iPv4OptionRouterAlertValueOffset = 2
+)
+
+// IPv4SerializableOption is an interface to represent serializable IPv4 option
+// types.
+type IPv4SerializableOption interface {
+ // optionType returns the type identifier of the option.
+ optionType() IPv4OptionType
+}
+
+// IPv4SerializableOptionPayload is an interface providing serialization of the
+// payload of an IPv4 option.
+type IPv4SerializableOptionPayload interface {
+ // length returns the size of the payload.
+ length() uint8
+
+ // serializeInto serializes the payload into the provided byte buffer.
+ //
+ // Note, the caller MUST provide a byte buffer with size of at least
+ // Length. Implementers of this function may assume that the byte buffer
+ // is of sufficient size. serializeInto MUST panic if the provided byte
+ // buffer is not of sufficient size.
+ //
+ // serializeInto will return the number of bytes that was used to
+ // serialize the receiver. Implementers must only use the number of
+ // bytes required to serialize the receiver. Callers MAY provide a
+ // larger buffer than required to serialize into.
+ serializeInto(buffer []byte) uint8
+}
+
+// IPv4OptionsSerializer is a serializer for IPv4 options.
+type IPv4OptionsSerializer []IPv4SerializableOption
+
+// Length returns the total number of bytes required to serialize the options.
+func (s IPv4OptionsSerializer) Length() uint8 {
+ var total uint8
+ for _, opt := range s {
+ total++
+ if withPayload, ok := opt.(IPv4SerializableOptionPayload); ok {
+ // Add 1 to reported length to account for the length byte.
+ total += 1 + withPayload.length()
+ }
+ }
+ return padIPv4OptionsLength(total)
+}
+
+// Serialize serializes the provided list of IPV4 options into b.
+//
+// Note, b must be of sufficient size to hold all the options in s. See
+// IPv4OptionsSerializer.Length for details on the getting the total size
+// of a serialized IPv4OptionsSerializer.
+//
+// Serialize panics if b is not of sufficient size to hold all the options in s.
+func (s IPv4OptionsSerializer) Serialize(b []byte) uint8 {
+ var total uint8
+ for _, opt := range s {
+ ty := opt.optionType()
+ if withPayload, ok := opt.(IPv4SerializableOptionPayload); ok {
+ // Serialize first to reduce bounds checks.
+ l := 2 + withPayload.serializeInto(b[2:])
+ b[0] = byte(ty)
+ b[1] = l
+ b = b[l:]
+ total += l
+ continue
+ }
+ // Options without payload consist only of the type field.
+ //
+ // NB: Repeating code from the branch above is intentional to minimize
+ // bounds checks.
+ b[0] = byte(ty)
+ b = b[1:]
+ total++
+ }
+
+ // According to RFC 791:
+ //
+ // The internet header padding is used to ensure that the internet
+ // header ends on a 32 bit boundary. The padding is zero.
+ padded := padIPv4OptionsLength(total)
+ b = b[:padded-total]
+ for i := range b {
+ b[i] = 0
+ }
+ return padded
+}
+
+var _ IPv4SerializableOptionPayload = (*IPv4SerializableRouterAlertOption)(nil)
+var _ IPv4SerializableOption = (*IPv4SerializableRouterAlertOption)(nil)
+
+// IPv4SerializableRouterAlertOption provides serialization of the Router Alert
+// IPv4 option according to RFC 2113.
+type IPv4SerializableRouterAlertOption struct{}
+
+// Type implements IPv4SerializableOption.
+func (*IPv4SerializableRouterAlertOption) optionType() IPv4OptionType {
+ return IPv4OptionRouterAlertType
+}
+
+// Length implements IPv4SerializableOption.
+func (*IPv4SerializableRouterAlertOption) length() uint8 {
+ return IPv4OptionRouterAlertLength - iPv4OptionRouterAlertValueOffset
+}
+
+// SerializeInto implements IPv4SerializableOption.
+func (o *IPv4SerializableRouterAlertOption) serializeInto(buffer []byte) uint8 {
+ binary.BigEndian.PutUint16(buffer, IPv4OptionRouterAlertValue)
+ return o.length()
+}
+
+var _ IPv4SerializableOption = (*IPv4SerializableNOPOption)(nil)
+
+// IPv4SerializableNOPOption provides serialization for the IPv4 no-op option.
+type IPv4SerializableNOPOption struct{}
+
+// Type implements IPv4SerializableOption.
+func (*IPv4SerializableNOPOption) optionType() IPv4OptionType {
+ return IPv4OptionNOPType
+}
+
+var _ IPv4SerializableOption = (*IPv4SerializableListEndOption)(nil)
+
+// IPv4SerializableListEndOption provides serialization for the IPv4 List End
+// option.
+type IPv4SerializableListEndOption struct{}
+
+// Type implements IPv4SerializableOption.
+func (*IPv4SerializableListEndOption) optionType() IPv4OptionType {
+ return IPv4OptionListEndType
+}
diff --git a/pkg/tcpip/header/ipv4_test.go b/pkg/tcpip/header/ipv4_test.go
new file mode 100644
index 000000000..6475cd694
--- /dev/null
+++ b/pkg/tcpip/header/ipv4_test.go
@@ -0,0 +1,179 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header_test
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+func TestIPv4OptionsSerializer(t *testing.T) {
+ optCases := []struct {
+ name string
+ option []header.IPv4SerializableOption
+ expect []byte
+ }{
+ {
+ name: "NOP",
+ option: []header.IPv4SerializableOption{
+ &header.IPv4SerializableNOPOption{},
+ },
+ expect: []byte{1, 0, 0, 0},
+ },
+ {
+ name: "ListEnd",
+ option: []header.IPv4SerializableOption{
+ &header.IPv4SerializableListEndOption{},
+ },
+ expect: []byte{0, 0, 0, 0},
+ },
+ {
+ name: "RouterAlert",
+ option: []header.IPv4SerializableOption{
+ &header.IPv4SerializableRouterAlertOption{},
+ },
+ expect: []byte{148, 4, 0, 0},
+ }, {
+ name: "NOP and RouterAlert",
+ option: []header.IPv4SerializableOption{
+ &header.IPv4SerializableNOPOption{},
+ &header.IPv4SerializableRouterAlertOption{},
+ },
+ expect: []byte{1, 148, 4, 0, 0, 0, 0, 0},
+ },
+ }
+
+ for _, opt := range optCases {
+ t.Run(opt.name, func(t *testing.T) {
+ s := header.IPv4OptionsSerializer(opt.option)
+ l := s.Length()
+ if got := len(opt.expect); got != int(l) {
+ t.Fatalf("s.Length() = %d, want = %d", got, l)
+ }
+ b := make([]byte, l)
+ for i := range b {
+ // Fill the buffer with full bytes to ensure padding is being set
+ // correctly.
+ b[i] = 0xFF
+ }
+ if serializedLength := s.Serialize(b); serializedLength != l {
+ t.Fatalf("s.Serialize(_) = %d, want %d", serializedLength, l)
+ }
+ if diff := cmp.Diff(opt.expect, b); diff != "" {
+ t.Errorf("mismatched serialized option (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+// TestIPv4Encode checks that ipv4.Encode correctly fills out the requested
+// fields when options are supplied.
+func TestIPv4EncodeOptions(t *testing.T) {
+ tests := []struct {
+ name string
+ numberOfNops int
+ encodedOptions header.IPv4Options // reply should look like this
+ wantIHL int
+ }{
+ {
+ name: "valid no options",
+ wantIHL: header.IPv4MinimumSize,
+ },
+ {
+ name: "one byte options",
+ numberOfNops: 1,
+ encodedOptions: header.IPv4Options{1, 0, 0, 0},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "two byte options",
+ numberOfNops: 2,
+ encodedOptions: header.IPv4Options{1, 1, 0, 0},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "three byte options",
+ numberOfNops: 3,
+ encodedOptions: header.IPv4Options{1, 1, 1, 0},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "four byte options",
+ numberOfNops: 4,
+ encodedOptions: header.IPv4Options{1, 1, 1, 1},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "five byte options",
+ numberOfNops: 5,
+ encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0},
+ wantIHL: header.IPv4MinimumSize + 8,
+ },
+ {
+ name: "thirty nine byte options",
+ numberOfNops: 39,
+ encodedOptions: header.IPv4Options{
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 0,
+ },
+ wantIHL: header.IPv4MinimumSize + 40,
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ serializeOpts := header.IPv4OptionsSerializer(make([]header.IPv4SerializableOption, test.numberOfNops))
+ for i := range serializeOpts {
+ serializeOpts[i] = &header.IPv4SerializableNOPOption{}
+ }
+ paddedOptionLength := serializeOpts.Length()
+ ipHeaderLength := int(header.IPv4MinimumSize + paddedOptionLength)
+ if ipHeaderLength > header.IPv4MaximumHeaderSize {
+ t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
+ }
+ totalLen := uint16(ipHeaderLength)
+ hdr := buffer.NewPrependable(int(totalLen))
+ ip := header.IPv4(hdr.Prepend(ipHeaderLength))
+ // To check the padding works, poison the last byte of the options space.
+ if paddedOptionLength != serializeOpts.Length() {
+ ip.SetHeaderLength(uint8(ipHeaderLength))
+ ip.Options()[paddedOptionLength-1] = 0xff
+ ip.SetHeaderLength(0)
+ }
+ ip.Encode(&header.IPv4Fields{
+ Options: serializeOpts,
+ })
+ options := ip.Options()
+ wantOptions := test.encodedOptions
+ if got, want := int(ip.HeaderLength()), test.wantIHL; got != want {
+ t.Errorf("got IHL of %d, want %d", got, want)
+ }
+
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(wantOptions) == 0 && len(options) == 0 {
+ return
+ }
+
+ if diff := cmp.Diff(wantOptions, options); diff != "" {
+ t.Errorf("options mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 55d09355a..d522e5f10 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -48,11 +48,13 @@ type IPv6Fields struct {
// FlowLabel is the "flow label" field of an IPv6 packet.
FlowLabel uint32
- // PayloadLength is the "payload length" field of an IPv6 packet.
+ // PayloadLength is the "payload length" field of an IPv6 packet, including
+ // the length of all extension headers.
PayloadLength uint16
- // NextHeader is the "next header" field of an IPv6 packet.
- NextHeader uint8
+ // TransportProtocol is the transport layer protocol number. Serialized in the
+ // last "next header" field of the IPv6 header + extension headers.
+ TransportProtocol tcpip.TransportProtocolNumber
// HopLimit is the "Hop Limit" field of an IPv6 packet.
HopLimit uint8
@@ -62,6 +64,9 @@ type IPv6Fields struct {
// DstAddr is the "destination ip address" of an IPv6 packet.
DstAddr tcpip.Address
+
+ // ExtensionHeaders are the extension headers following the IPv6 header.
+ ExtensionHeaders IPv6ExtHdrSerializer
}
// IPv6 represents an ipv6 header stored in a byte array.
@@ -253,12 +258,14 @@ func (IPv6) SetChecksum(uint16) {
// Encode encodes all the fields of the ipv6 header.
func (b IPv6) Encode(i *IPv6Fields) {
+ extHdr := b[IPv6MinimumSize:]
b.SetTOS(i.TrafficClass, i.FlowLabel)
b.SetPayloadLength(i.PayloadLength)
- b[IPv6NextHeaderOffset] = i.NextHeader
b[hopLimit] = i.HopLimit
b.SetSourceAddress(i.SrcAddr)
b.SetDestinationAddress(i.DstAddr)
+ nextHeader, _ := i.ExtensionHeaders.Serialize(i.TransportProtocol, extHdr)
+ b[IPv6NextHeaderOffset] = nextHeader
}
// IsValid performs basic validation on the packet.
diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go
index 583c2c5d3..f18981332 100644
--- a/pkg/tcpip/header/ipv6_extension_headers.go
+++ b/pkg/tcpip/header/ipv6_extension_headers.go
@@ -18,9 +18,12 @@ import (
"bufio"
"bytes"
"encoding/binary"
+ "errors"
"fmt"
"io"
+ "math"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -47,6 +50,11 @@ const (
// IPv6NoNextHeaderIdentifier is the header identifier used to signify the end
// of an IPv6 payload, as per RFC 8200 section 4.7.
IPv6NoNextHeaderIdentifier IPv6ExtensionHeaderIdentifier = 59
+
+ // IPv6UnknownExtHdrIdentifier is reserved by IANA.
+ // https://www.iana.org/assignments/ipv6-parameters/ipv6-parameters.xhtml#extension-header
+ // "254 Use for experimentation and testing [RFC3692][RFC4727]"
+ IPv6UnknownExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 254
)
const (
@@ -70,8 +78,8 @@ const (
// Fragment Offset field within an IPv6FragmentExtHdr.
ipv6FragmentExtHdrFragmentOffsetOffset = 0
- // ipv6FragmentExtHdrFragmentOffsetShift is the least significant bits to
- // discard from the Fragment Offset.
+ // ipv6FragmentExtHdrFragmentOffsetShift is the bit offset of the Fragment
+ // Offset field within an IPv6FragmentExtHdr.
ipv6FragmentExtHdrFragmentOffsetShift = 3
// ipv6FragmentExtHdrFlagsIdx is the index to the flags field within an
@@ -109,6 +117,37 @@ const (
IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8
)
+// padIPv6OptionsLength returns the total length for IPv6 options of length l
+// considering the 8-octet alignment as stated in RFC 8200 Section 4.2.
+func padIPv6OptionsLength(length int) int {
+ return (length + ipv6ExtHdrLenBytesPerUnit - 1) & ^(ipv6ExtHdrLenBytesPerUnit - 1)
+}
+
+// padIPv6Option fills b with the appropriate padding options depending on its
+// length.
+func padIPv6Option(b []byte) {
+ switch len(b) {
+ case 0: // No padding needed.
+ case 1: // Pad with Pad1.
+ b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6Pad1ExtHdrOptionIdentifier)
+ default: // Pad with PadN.
+ s := b[ipv6ExtHdrOptionPayloadOffset:]
+ for i := range s {
+ s[i] = 0
+ }
+ b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6PadNExtHdrOptionIdentifier)
+ b[ipv6ExtHdrOptionLengthOffset] = uint8(len(s))
+ }
+}
+
+// ipv6OptionsAlignmentPadding returns the number of padding bytes needed to
+// serialize an option at headerOffset with alignment requirements
+// [align]n + alignOffset.
+func ipv6OptionsAlignmentPadding(headerOffset int, align int, alignOffset int) int {
+ padLen := headerOffset - alignOffset
+ return ((padLen + align - 1) & ^(align - 1)) - padLen
+}
+
// IPv6PayloadHeader is implemented by the various headers that can be found
// in an IPv6 payload.
//
@@ -201,29 +240,55 @@ type IPv6ExtHdrOption interface {
isIPv6ExtHdrOption()
}
-// IPv6ExtHdrOptionIndentifier is an IPv6 extension header option identifier.
-type IPv6ExtHdrOptionIndentifier uint8
+// IPv6ExtHdrOptionIdentifier is an IPv6 extension header option identifier.
+type IPv6ExtHdrOptionIdentifier uint8
const (
// ipv6Pad1ExtHdrOptionIdentifier is the identifier for a padding option that
// provides 1 byte padding, as outlined in RFC 8200 section 4.2.
- ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 0
+ ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 0
// ipv6PadBExtHdrOptionIdentifier is the identifier for a padding option that
// provides variable length byte padding, as outlined in RFC 8200 section 4.2.
- ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 1
+ ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 1
+
+ // ipv6RouterAlertHopByHopOptionIdentifier is the identifier for the Router
+ // Alert Hop by Hop option as defined in RFC 2711 section 2.1.
+ ipv6RouterAlertHopByHopOptionIdentifier IPv6ExtHdrOptionIdentifier = 5
+
+ // ipv6ExtHdrOptionTypeOffset is the option type offset in an extension header
+ // option as defined in RFC 8200 section 4.2.
+ ipv6ExtHdrOptionTypeOffset = 0
+
+ // ipv6ExtHdrOptionLengthOffset is the option length offset in an extension
+ // header option as defined in RFC 8200 section 4.2.
+ ipv6ExtHdrOptionLengthOffset = 1
+
+ // ipv6ExtHdrOptionPayloadOffset is the option payload offset in an extension
+ // header option as defined in RFC 8200 section 4.2.
+ ipv6ExtHdrOptionPayloadOffset = 2
)
+// ipv6UnknownActionFromIdentifier maps an extension header option's
+// identifier's high bits to the action to take when the identifier is unknown.
+func ipv6UnknownActionFromIdentifier(id IPv6ExtHdrOptionIdentifier) IPv6OptionUnknownAction {
+ return IPv6OptionUnknownAction((id & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift)
+}
+
+// ErrMalformedIPv6ExtHdrOption indicates that an IPv6 extension header option
+// is malformed.
+var ErrMalformedIPv6ExtHdrOption = errors.New("malformed IPv6 extension header option")
+
// IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension
// header option that is unknown by the parsing utilities.
type IPv6UnknownExtHdrOption struct {
- Identifier IPv6ExtHdrOptionIndentifier
+ Identifier IPv6ExtHdrOptionIdentifier
Data []byte
}
// UnknownAction implements IPv6OptionUnknownAction.UnknownAction.
func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction {
- return IPv6OptionUnknownAction((o.Identifier & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift)
+ return ipv6UnknownActionFromIdentifier(o.Identifier)
}
// isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption.
@@ -246,7 +311,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
// options buffer has been exhausted and we are done iterating.
return nil, true, nil
}
- id := IPv6ExtHdrOptionIndentifier(temp)
+ id := IPv6ExtHdrOptionIdentifier(temp)
// If the option identifier indicates the option is a Pad1 option, then we
// know the option does not have Length and Data fields. End processing of
@@ -289,6 +354,19 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err))
}
continue
+ case ipv6RouterAlertHopByHopOptionIdentifier:
+ var routerAlertValue [ipv6RouterAlertPayloadLength]byte
+ if n, err := io.ReadFull(&i.reader, routerAlertValue[:]); err != nil {
+ switch err {
+ case io.EOF, io.ErrUnexpectedEOF:
+ return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
+ default:
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for router alert option: %w", n, ipv6RouterAlertPayloadLength, err)
+ }
+ } else if n != int(length) {
+ return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
+ }
+ return &IPv6RouterAlertOption{Value: IPv6RouterAlertValue(binary.BigEndian.Uint16(routerAlertValue[:]))}, false, nil
default:
bytes := make([]byte, length)
if n, err := io.ReadFull(&i.reader, bytes); err != nil {
@@ -452,9 +530,11 @@ func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader {
// Since we consume the iterator, we return the payload as is.
buf = i.payload
- // Mark i as done.
+ // Mark i as done, but keep track of where we were for error reporting.
*i = IPv6PayloadIterator{
nextHdrIdentifier: IPv6NoNextHeaderIdentifier,
+ headerOffset: i.headerOffset,
+ nextOffset: i.nextOffset,
}
} else {
buf = i.payload.Clone(nil)
@@ -602,3 +682,248 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP
return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), bytes, nil
}
+
+// IPv6SerializableExtHdr provides serialization for IPv6 extension
+// headers.
+type IPv6SerializableExtHdr interface {
+ // identifier returns the assigned IPv6 header identifier for this extension
+ // header.
+ identifier() IPv6ExtensionHeaderIdentifier
+
+ // length returns the total serialized length in bytes of this extension
+ // header, including the common next header and length fields.
+ length() int
+
+ // serializeInto serializes the receiver into the provided byte
+ // buffer and with the provided nextHeader value.
+ //
+ // Note, the caller MUST provide a byte buffer with size of at least
+ // length. Implementers of this function may assume that the byte buffer
+ // is of sufficient size. serializeInto MAY panic if the provided byte
+ // buffer is not of sufficient size.
+ //
+ // serializeInto returns the number of bytes that was used to serialize the
+ // receiver. Implementers must only use the number of bytes required to
+ // serialize the receiver. Callers MAY provide a larger buffer than required
+ // to serialize into.
+ serializeInto(nextHeader uint8, b []byte) int
+}
+
+var _ IPv6SerializableExtHdr = (*IPv6SerializableHopByHopExtHdr)(nil)
+
+// IPv6SerializableHopByHopExtHdr implements serialization of the Hop by Hop
+// options extension header.
+type IPv6SerializableHopByHopExtHdr []IPv6SerializableHopByHopOption
+
+const (
+ // ipv6HopByHopExtHdrNextHeaderOffset is the offset of the next header field
+ // in a hop by hop extension header as defined in RFC 8200 section 4.3.
+ ipv6HopByHopExtHdrNextHeaderOffset = 0
+
+ // ipv6HopByHopExtHdrLengthOffset is the offset of the length field in a hop
+ // by hop extension header as defined in RFC 8200 section 4.3.
+ ipv6HopByHopExtHdrLengthOffset = 1
+
+ // ipv6HopByHopExtHdrPayloadOffset is the offset of the options in a hop by
+ // hop extension header as defined in RFC 8200 section 4.3.
+ ipv6HopByHopExtHdrOptionsOffset = 2
+
+ // ipv6HopByHopExtHdrUnaccountedLenWords is the implicit number of 8-octet
+ // words in a hop by hop extension header's length field, as stated in RFC
+ // 8200 section 4.3:
+ // Length of the Hop-by-Hop Options header in 8-octet units,
+ // not including the first 8 octets.
+ ipv6HopByHopExtHdrUnaccountedLenWords = 1
+)
+
+// identifier implements IPv6SerializableExtHdr.
+func (IPv6SerializableHopByHopExtHdr) identifier() IPv6ExtensionHeaderIdentifier {
+ return IPv6HopByHopOptionsExtHdrIdentifier
+}
+
+// length implements IPv6SerializableExtHdr.
+func (h IPv6SerializableHopByHopExtHdr) length() int {
+ var total int
+ for _, opt := range h {
+ align, alignOffset := opt.alignment()
+ total += ipv6OptionsAlignmentPadding(total, align, alignOffset)
+ total += ipv6ExtHdrOptionPayloadOffset + int(opt.length())
+ }
+ // Account for next header and total length fields and add padding.
+ return padIPv6OptionsLength(ipv6HopByHopExtHdrOptionsOffset + total)
+}
+
+// serializeInto implements IPv6SerializableExtHdr.
+func (h IPv6SerializableHopByHopExtHdr) serializeInto(nextHeader uint8, b []byte) int {
+ optBuffer := b[ipv6HopByHopExtHdrOptionsOffset:]
+ totalLength := ipv6HopByHopExtHdrOptionsOffset
+ for _, opt := range h {
+ // Calculate alignment requirements and pad buffer if necessary.
+ align, alignOffset := opt.alignment()
+ padLen := ipv6OptionsAlignmentPadding(totalLength, align, alignOffset)
+ if padLen != 0 {
+ padIPv6Option(optBuffer[:padLen])
+ totalLength += padLen
+ optBuffer = optBuffer[padLen:]
+ }
+
+ l := opt.serializeInto(optBuffer[ipv6ExtHdrOptionPayloadOffset:])
+ optBuffer[ipv6ExtHdrOptionTypeOffset] = uint8(opt.identifier())
+ optBuffer[ipv6ExtHdrOptionLengthOffset] = l
+ l += ipv6ExtHdrOptionPayloadOffset
+ totalLength += int(l)
+ optBuffer = optBuffer[l:]
+ }
+ padded := padIPv6OptionsLength(totalLength)
+ if padded != totalLength {
+ padIPv6Option(optBuffer[:padded-totalLength])
+ totalLength = padded
+ }
+ wordsLen := totalLength/ipv6ExtHdrLenBytesPerUnit - ipv6HopByHopExtHdrUnaccountedLenWords
+ if wordsLen > math.MaxUint8 {
+ panic(fmt.Sprintf("IPv6 hop by hop options too large: %d+1 64-bit words", wordsLen))
+ }
+ b[ipv6HopByHopExtHdrNextHeaderOffset] = nextHeader
+ b[ipv6HopByHopExtHdrLengthOffset] = uint8(wordsLen)
+ return totalLength
+}
+
+// IPv6SerializableHopByHopOption provides serialization for hop by hop options.
+type IPv6SerializableHopByHopOption interface {
+ // identifier returns the option identifier of this Hop by Hop option.
+ identifier() IPv6ExtHdrOptionIdentifier
+
+ // length returns the *payload* size of the option (not considering the type
+ // and length fields).
+ length() uint8
+
+ // alignment returns the alignment requirements from this option.
+ //
+ // Alignment requirements take the form [align]n + offset as specified in
+ // RFC 8200 section 4.2. The alignment requirement is on the offset between
+ // the option type byte and the start of the hop by hop header.
+ //
+ // align must be a power of 2.
+ alignment() (align int, offset int)
+
+ // serializeInto serializes the receiver into the provided byte
+ // buffer.
+ //
+ // Note, the caller MUST provide a byte buffer with size of at least
+ // length. Implementers of this function may assume that the byte buffer
+ // is of sufficient size. serializeInto MAY panic if the provided byte
+ // buffer is not of sufficient size.
+ //
+ // serializeInto will return the number of bytes that was used to
+ // serialize the receiver. Implementers must only use the number of
+ // bytes required to serialize the receiver. Callers MAY provide a
+ // larger buffer than required to serialize into.
+ serializeInto([]byte) uint8
+}
+
+var _ IPv6SerializableHopByHopOption = (*IPv6RouterAlertOption)(nil)
+
+// IPv6RouterAlertOption is the IPv6 Router alert Hop by Hop option defined in
+// RFC 2711 section 2.1.
+type IPv6RouterAlertOption struct {
+ Value IPv6RouterAlertValue
+}
+
+// IPv6RouterAlertValue is the payload of an IPv6 Router Alert option.
+type IPv6RouterAlertValue uint16
+
+const (
+ // IPv6RouterAlertMLD indicates a datagram containing a Multicast Listener
+ // Discovery message as defined in RFC 2711 section 2.1.
+ IPv6RouterAlertMLD IPv6RouterAlertValue = 0
+ // IPv6RouterAlertRSVP indicates a datagram containing an RSVP message as
+ // defined in RFC 2711 section 2.1.
+ IPv6RouterAlertRSVP IPv6RouterAlertValue = 1
+ // IPv6RouterAlertActiveNetworks indicates a datagram containing an Active
+ // Networks message as defined in RFC 2711 section 2.1.
+ IPv6RouterAlertActiveNetworks IPv6RouterAlertValue = 2
+
+ // ipv6RouterAlertPayloadLength is the length of the Router Alert payload
+ // as defined in RFC 2711.
+ ipv6RouterAlertPayloadLength = 2
+
+ // ipv6RouterAlertAlignmentRequirement is the alignment requirement for the
+ // Router Alert option defined as 2n+0 in RFC 2711.
+ ipv6RouterAlertAlignmentRequirement = 2
+
+ // ipv6RouterAlertAlignmentOffsetRequirement is the alignment offset
+ // requirement for the Router Alert option defined as 2n+0 in RFC 2711 section
+ // 2.1.
+ ipv6RouterAlertAlignmentOffsetRequirement = 0
+)
+
+// UnknownAction implements IPv6ExtHdrOption.
+func (*IPv6RouterAlertOption) UnknownAction() IPv6OptionUnknownAction {
+ return ipv6UnknownActionFromIdentifier(ipv6RouterAlertHopByHopOptionIdentifier)
+}
+
+// isIPv6ExtHdrOption implements IPv6ExtHdrOption.
+func (*IPv6RouterAlertOption) isIPv6ExtHdrOption() {}
+
+// identifier implements IPv6SerializableHopByHopOption.
+func (*IPv6RouterAlertOption) identifier() IPv6ExtHdrOptionIdentifier {
+ return ipv6RouterAlertHopByHopOptionIdentifier
+}
+
+// length implements IPv6SerializableHopByHopOption.
+func (*IPv6RouterAlertOption) length() uint8 {
+ return ipv6RouterAlertPayloadLength
+}
+
+// alignment implements IPv6SerializableHopByHopOption.
+func (*IPv6RouterAlertOption) alignment() (int, int) {
+ // From RFC 2711 section 2.1:
+ // Alignment requirement: 2n+0.
+ return ipv6RouterAlertAlignmentRequirement, ipv6RouterAlertAlignmentOffsetRequirement
+}
+
+// serializeInto implements IPv6SerializableHopByHopOption.
+func (o *IPv6RouterAlertOption) serializeInto(b []byte) uint8 {
+ binary.BigEndian.PutUint16(b, uint16(o.Value))
+ return ipv6RouterAlertPayloadLength
+}
+
+// IPv6ExtHdrSerializer provides serialization of IPv6 extension headers.
+type IPv6ExtHdrSerializer []IPv6SerializableExtHdr
+
+// Serialize serializes the provided list of IPv6 extension headers into b.
+//
+// Note, b must be of sufficient size to hold all the headers in s. See
+// IPv6ExtHdrSerializer.Length for details on the getting the total size of a
+// serialized IPv6ExtHdrSerializer.
+//
+// Serialize may panic if b is not of sufficient size to hold all the options
+// in s.
+//
+// Serialize takes the transportProtocol value to be used as the last extension
+// header's Next Header value and returns the header identifier of the first
+// serialized extension header and the total serialized length.
+func (s IPv6ExtHdrSerializer) Serialize(transportProtocol tcpip.TransportProtocolNumber, b []byte) (uint8, int) {
+ nextHeader := uint8(transportProtocol)
+ if len(s) == 0 {
+ return nextHeader, 0
+ }
+ var totalLength int
+ for i, h := range s[:len(s)-1] {
+ length := h.serializeInto(uint8(s[i+1].identifier()), b)
+ b = b[length:]
+ totalLength += length
+ }
+ totalLength += s[len(s)-1].serializeInto(nextHeader, b)
+ return uint8(s[0].identifier()), totalLength
+}
+
+// Length returns the total number of bytes required to serialize the extension
+// headers.
+func (s IPv6ExtHdrSerializer) Length() int {
+ var totalLength int
+ for _, h := range s {
+ totalLength += h.length()
+ }
+ return totalLength
+}
diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go
index ab20c5f37..65adc6250 100644
--- a/pkg/tcpip/header/ipv6_extension_headers_test.go
+++ b/pkg/tcpip/header/ipv6_extension_headers_test.go
@@ -21,6 +21,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -59,7 +60,7 @@ func (a IPv6DestinationOptionsExtHdr) Equal(b IPv6DestinationOptionsExtHdr) bool
func TestIPv6UnknownExtHdrOption(t *testing.T) {
tests := []struct {
name string
- identifier IPv6ExtHdrOptionIndentifier
+ identifier IPv6ExtHdrOptionIdentifier
expectedUnknownAction IPv6OptionUnknownAction
}{
{
@@ -211,6 +212,31 @@ func TestIPv6OptionsExtHdrIterErr(t *testing.T) {
bytes: []byte{1, 3},
err: io.ErrUnexpectedEOF,
},
+ {
+ name: "Router alert without data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 0},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with partial data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with partial data and Pad1",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1, 0},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with extra data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 3, 1, 2, 3},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with missing data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1},
+ err: io.ErrUnexpectedEOF,
+ },
}
check := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expectedErr error) {
@@ -990,3 +1016,331 @@ func TestIPv6ExtHdrIter(t *testing.T) {
})
}
}
+
+var _ IPv6SerializableHopByHopOption = (*dummyHbHOptionSerializer)(nil)
+
+// dummyHbHOptionSerializer provides a generic implementation of
+// IPv6SerializableHopByHopOption for use in tests.
+type dummyHbHOptionSerializer struct {
+ id IPv6ExtHdrOptionIdentifier
+ payload []byte
+ align int
+ alignOffset int
+}
+
+// identifier implements IPv6SerializableHopByHopOption.
+func (s *dummyHbHOptionSerializer) identifier() IPv6ExtHdrOptionIdentifier {
+ return s.id
+}
+
+// length implements IPv6SerializableHopByHopOption.
+func (s *dummyHbHOptionSerializer) length() uint8 {
+ return uint8(len(s.payload))
+}
+
+// alignment implements IPv6SerializableHopByHopOption.
+func (s *dummyHbHOptionSerializer) alignment() (int, int) {
+ align := 1
+ if s.align != 0 {
+ align = s.align
+ }
+ return align, s.alignOffset
+}
+
+// serializeInto implements IPv6SerializableHopByHopOption.
+func (s *dummyHbHOptionSerializer) serializeInto(b []byte) uint8 {
+ return uint8(copy(b, s.payload))
+}
+
+func TestIPv6HopByHopSerializer(t *testing.T) {
+ validateDummies := func(t *testing.T, serializable IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) {
+ t.Helper()
+ dummy, ok := serializable.(*dummyHbHOptionSerializer)
+ if !ok {
+ t.Fatalf("got serializable = %T, want = *dummyHbHOptionSerializer", serializable)
+ }
+ unknown, ok := deserialized.(*IPv6UnknownExtHdrOption)
+ if !ok {
+ t.Fatalf("got deserialized = %T, want = %T", deserialized, &IPv6UnknownExtHdrOption{})
+ }
+ if dummy.id != unknown.Identifier {
+ t.Errorf("got deserialized identifier = %d, want = %d", unknown.Identifier, dummy.id)
+ }
+ if diff := cmp.Diff(dummy.payload, unknown.Data); diff != "" {
+ t.Errorf("option payload deserialization mismatch (-want +got):\n%s", diff)
+ }
+ }
+ tests := []struct {
+ name string
+ nextHeader uint8
+ options []IPv6SerializableHopByHopOption
+ expect []byte
+ validate func(*testing.T, IPv6SerializableHopByHopOption, IPv6ExtHdrOption)
+ }{
+ {
+ name: "single option",
+ nextHeader: 13,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 15,
+ payload: []byte{9, 8, 7, 6},
+ },
+ },
+ expect: []byte{13, 0, 15, 4, 9, 8, 7, 6},
+ validate: validateDummies,
+ },
+ {
+ name: "short option padN zero",
+ nextHeader: 88,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5},
+ },
+ },
+ expect: []byte{88, 0, 22, 2, 4, 5, 1, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "short option pad1",
+ nextHeader: 11,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 33,
+ payload: []byte{1, 2, 3},
+ },
+ },
+ expect: []byte{11, 0, 33, 3, 1, 2, 3, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "long option padN",
+ nextHeader: 55,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 77,
+ payload: []byte{1, 2, 3, 4, 5, 6, 7, 8},
+ },
+ },
+ expect: []byte{55, 1, 77, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 0, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "two options",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 11,
+ payload: []byte{1, 2, 3},
+ },
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5, 6},
+ },
+ },
+ expect: []byte{33, 1, 11, 3, 1, 2, 3, 22, 3, 4, 5, 6, 1, 2, 0, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "two options align 2n",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 11,
+ payload: []byte{1, 2, 3},
+ },
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5, 6},
+ align: 2,
+ },
+ },
+ expect: []byte{33, 1, 11, 3, 1, 2, 3, 0, 22, 3, 4, 5, 6, 1, 1, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "two options align 8n+1",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 11,
+ payload: []byte{1, 2},
+ },
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5, 6},
+ align: 8,
+ alignOffset: 1,
+ },
+ },
+ expect: []byte{33, 1, 11, 2, 1, 2, 1, 1, 0, 22, 3, 4, 5, 6, 1, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "no options",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{},
+ expect: []byte{33, 0, 1, 4, 0, 0, 0, 0},
+ },
+ {
+ name: "Router Alert",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{&IPv6RouterAlertOption{Value: IPv6RouterAlertMLD}},
+ expect: []byte{33, 0, 5, 2, 0, 0, 1, 0},
+ validate: func(t *testing.T, _ IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) {
+ t.Helper()
+ routerAlert, ok := deserialized.(*IPv6RouterAlertOption)
+ if !ok {
+ t.Fatalf("got deserialized = %T, want = *IPv6RouterAlertOption", deserialized)
+ }
+ if routerAlert.Value != IPv6RouterAlertMLD {
+ t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, IPv6RouterAlertMLD)
+ }
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := IPv6SerializableHopByHopExtHdr(test.options)
+ length := s.length()
+ if length != len(test.expect) {
+ t.Fatalf("got s.length() = %d, want = %d", length, len(test.expect))
+ }
+ b := make([]byte, length)
+ for i := range b {
+ // Fill the buffer with ones to ensure all padding is correctly set.
+ b[i] = 0xFF
+ }
+ if got := s.serializeInto(test.nextHeader, b); got != length {
+ t.Fatalf("got s.serializeInto(..) = %d, want = %d", got, length)
+ }
+ if diff := cmp.Diff(test.expect, b); diff != "" {
+ t.Fatalf("serialization mismatch (-want +got):\n%s", diff)
+ }
+
+ // Deserialize the options and verify them.
+ optLen := (b[ipv6HopByHopExtHdrLengthOffset] + ipv6HopByHopExtHdrUnaccountedLenWords) * ipv6ExtHdrLenBytesPerUnit
+ iter := ipv6OptionsExtHdr(b[ipv6HopByHopExtHdrOptionsOffset:optLen]).Iter()
+ for _, testOpt := range test.options {
+ opt, done, err := iter.Next()
+ if err != nil {
+ t.Fatalf("iter.Next(): %s", err)
+ }
+ if done {
+ t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, false, _)", opt, done)
+ }
+ test.validate(t, testOpt, opt)
+ }
+ opt, done, err := iter.Next()
+ if err != nil {
+ t.Fatalf("iter.Next(): %s", err)
+ }
+ if !done {
+ t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, true, _)", opt, done)
+ }
+ })
+ }
+}
+
+var _ IPv6SerializableExtHdr = (*dummyIPv6ExtHdrSerializer)(nil)
+
+// dummyIPv6ExtHdrSerializer provides a generic implementation of
+// IPv6SerializableExtHdr for use in tests.
+//
+// The dummy header always carries the nextHeader value in the first byte.
+type dummyIPv6ExtHdrSerializer struct {
+ id IPv6ExtensionHeaderIdentifier
+ headerContents []byte
+}
+
+// identifier implements IPv6SerializableExtHdr.
+func (s *dummyIPv6ExtHdrSerializer) identifier() IPv6ExtensionHeaderIdentifier {
+ return s.id
+}
+
+// length implements IPv6SerializableExtHdr.
+func (s *dummyIPv6ExtHdrSerializer) length() int {
+ return len(s.headerContents) + 1
+}
+
+// serializeInto implements IPv6SerializableExtHdr.
+func (s *dummyIPv6ExtHdrSerializer) serializeInto(nextHeader uint8, b []byte) int {
+ b[0] = nextHeader
+ return copy(b[1:], s.headerContents) + 1
+}
+
+func TestIPv6ExtHdrSerializer(t *testing.T) {
+ tests := []struct {
+ name string
+ headers []IPv6SerializableExtHdr
+ nextHeader tcpip.TransportProtocolNumber
+ expectSerialized []byte
+ expectNextHeader uint8
+ }{
+ {
+ name: "one header",
+ headers: []IPv6SerializableExtHdr{
+ &dummyIPv6ExtHdrSerializer{
+ id: 15,
+ headerContents: []byte{1, 2, 3, 4},
+ },
+ },
+ nextHeader: TCPProtocolNumber,
+ expectSerialized: []byte{byte(TCPProtocolNumber), 1, 2, 3, 4},
+ expectNextHeader: 15,
+ },
+ {
+ name: "two headers",
+ headers: []IPv6SerializableExtHdr{
+ &dummyIPv6ExtHdrSerializer{
+ id: 22,
+ headerContents: []byte{1, 2, 3},
+ },
+ &dummyIPv6ExtHdrSerializer{
+ id: 23,
+ headerContents: []byte{4, 5, 6},
+ },
+ },
+ nextHeader: ICMPv6ProtocolNumber,
+ expectSerialized: []byte{
+ 23, 1, 2, 3,
+ byte(ICMPv6ProtocolNumber), 4, 5, 6,
+ },
+ expectNextHeader: 22,
+ },
+ {
+ name: "no headers",
+ headers: []IPv6SerializableExtHdr{},
+ nextHeader: UDPProtocolNumber,
+ expectSerialized: []byte{},
+ expectNextHeader: byte(UDPProtocolNumber),
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := IPv6ExtHdrSerializer(test.headers)
+ l := s.Length()
+ if got, want := l, len(test.expectSerialized); got != want {
+ t.Fatalf("got serialized length = %d, want = %d", got, want)
+ }
+ b := make([]byte, l)
+ for i := range b {
+ // Fill the buffer with garbage to make sure we're writing to all bytes.
+ b[i] = 0xFF
+ }
+ nextHeader, serializedLen := s.Serialize(test.nextHeader, b)
+ if serializedLen != len(test.expectSerialized) || nextHeader != test.expectNextHeader {
+ t.Errorf(
+ "got s.Serialize(..) = (%d, %d), want = (%d, %d)",
+ nextHeader,
+ serializedLen,
+ test.expectNextHeader,
+ len(test.expectSerialized),
+ )
+ }
+ if diff := cmp.Diff(test.expectSerialized, b); diff != "" {
+ t.Errorf("serialization mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/ipv6_fragment.go b/pkg/tcpip/header/ipv6_fragment.go
index 018555a26..9d09f32eb 100644
--- a/pkg/tcpip/header/ipv6_fragment.go
+++ b/pkg/tcpip/header/ipv6_fragment.go
@@ -27,12 +27,11 @@ const (
idV6 = 4
)
-// IPv6FragmentFields contains the fields of an IPv6 fragment. It is used to describe the
-// fields of a packet that needs to be encoded.
-type IPv6FragmentFields struct {
- // NextHeader is the "next header" field of an IPv6 fragment.
- NextHeader uint8
+var _ IPv6SerializableExtHdr = (*IPv6SerializableFragmentExtHdr)(nil)
+// IPv6SerializableFragmentExtHdr is used to serialize an IPv6 fragment
+// extension header as defined in RFC 8200 section 4.5.
+type IPv6SerializableFragmentExtHdr struct {
// FragmentOffset is the "fragment offset" field of an IPv6 fragment.
FragmentOffset uint16
@@ -43,6 +42,29 @@ type IPv6FragmentFields struct {
Identification uint32
}
+// identifier implements IPv6SerializableFragmentExtHdr.
+func (h *IPv6SerializableFragmentExtHdr) identifier() IPv6ExtensionHeaderIdentifier {
+ return IPv6FragmentHeader
+}
+
+// length implements IPv6SerializableFragmentExtHdr.
+func (h *IPv6SerializableFragmentExtHdr) length() int {
+ return IPv6FragmentHeaderSize
+}
+
+// serializeInto implements IPv6SerializableFragmentExtHdr.
+func (h *IPv6SerializableFragmentExtHdr) serializeInto(nextHeader uint8, b []byte) int {
+ // Prevent too many bounds checks.
+ _ = b[IPv6FragmentHeaderSize:]
+ binary.BigEndian.PutUint32(b[idV6:], h.Identification)
+ binary.BigEndian.PutUint16(b[fragOff:], h.FragmentOffset<<ipv6FragmentExtHdrFragmentOffsetShift)
+ b[nextHdrFrag] = nextHeader
+ if h.M {
+ b[more] |= ipv6FragmentExtHdrMFlagMask
+ }
+ return IPv6FragmentHeaderSize
+}
+
// IPv6Fragment represents an ipv6 fragment header stored in a byte array.
// Most of the methods of IPv6Fragment access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
@@ -58,16 +80,6 @@ const (
IPv6FragmentHeaderSize = 8
)
-// Encode encodes all the fields of the ipv6 fragment.
-func (b IPv6Fragment) Encode(i *IPv6FragmentFields) {
- b[nextHdrFrag] = i.NextHeader
- binary.BigEndian.PutUint16(b[fragOff:], i.FragmentOffset<<3)
- if i.M {
- b[more] |= 1
- }
- binary.BigEndian.PutUint32(b[idV6:], i.Identification)
-}
-
// IsValid performs basic validation on the fragment header.
func (b IPv6Fragment) IsValid() bool {
return len(b) >= IPv6FragmentHeaderSize
diff --git a/pkg/tcpip/header/mld.go b/pkg/tcpip/header/mld.go
new file mode 100644
index 000000000..ffe03c76a
--- /dev/null
+++ b/pkg/tcpip/header/mld.go
@@ -0,0 +1,103 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // MLDMinimumSize is the minimum size for an MLD message.
+ MLDMinimumSize = 20
+
+ // MLDHopLimit is the Hop Limit for all IPv6 packets with an MLD message, as
+ // per RFC 2710 section 3.
+ MLDHopLimit = 1
+
+ // mldMaximumResponseDelayOffset is the offset to the Maximum Response Delay
+ // field within MLD.
+ mldMaximumResponseDelayOffset = 0
+
+ // mldMulticastAddressOffset is the offset to the Multicast Address field
+ // within MLD.
+ mldMulticastAddressOffset = 4
+)
+
+// MLD is a Multicast Listener Discovery message in an ICMPv6 packet.
+//
+// MLD will only contain the body of an ICMPv6 packet.
+//
+// As per RFC 2710 section 3, MLD messages have the following format (MLD only
+// holds the bytes after the first four bytes in the diagram below):
+//
+// 0 1 2 3
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Type | Code | Checksum |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Maximum Response Delay | Reserved |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | |
+// + +
+// | |
+// + Multicast Address +
+// | |
+// + +
+// | |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+type MLD []byte
+
+// MaximumResponseDelay returns the Maximum Response Delay.
+func (m MLD) MaximumResponseDelay() time.Duration {
+ // As per RFC 2710 section 3.4:
+ //
+ // The Maximum Response Delay field is meaningful only in Query
+ // messages, and specifies the maximum allowed delay before sending a
+ // responding Report, in units of milliseconds. In all other messages,
+ // it is set to zero by the sender and ignored by receivers.
+ return time.Duration(binary.BigEndian.Uint16(m[mldMaximumResponseDelayOffset:])) * time.Millisecond
+}
+
+// SetMaximumResponseDelay sets the Maximum Response Delay field.
+//
+// maxRespDelayMS is the value in milliseconds.
+func (m MLD) SetMaximumResponseDelay(maxRespDelayMS uint16) {
+ binary.BigEndian.PutUint16(m[mldMaximumResponseDelayOffset:], maxRespDelayMS)
+}
+
+// MulticastAddress returns the Multicast Address.
+func (m MLD) MulticastAddress() tcpip.Address {
+ // As per RFC 2710 section 3.5:
+ //
+ // In a Query message, the Multicast Address field is set to zero when
+ // sending a General Query, and set to a specific IPv6 multicast address
+ // when sending a Multicast-Address-Specific Query.
+ //
+ // In a Report or Done message, the Multicast Address field holds a
+ // specific IPv6 multicast address to which the message sender is
+ // listening or is ceasing to listen, respectively.
+ return tcpip.Address(m[mldMulticastAddressOffset:][:IPv6AddressSize])
+}
+
+// SetMulticastAddress sets the Multicast Address field.
+func (m MLD) SetMulticastAddress(multicastAddress tcpip.Address) {
+ if n := copy(m[mldMulticastAddressOffset:], multicastAddress); n != IPv6AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected to copy %d bytes", n, IPv6AddressSize))
+ }
+}
diff --git a/pkg/tcpip/header/mld_test.go b/pkg/tcpip/header/mld_test.go
new file mode 100644
index 000000000..0cecf10d4
--- /dev/null
+++ b/pkg/tcpip/header/mld_test.go
@@ -0,0 +1,61 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+func TestMLD(t *testing.T) {
+ b := []byte{
+ // Maximum Response Delay
+ 0, 0,
+
+ // Reserved
+ 0, 0,
+
+ // MulticastAddress
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6,
+ }
+
+ const maxRespDelay = 513
+ binary.BigEndian.PutUint16(b, maxRespDelay)
+
+ mld := MLD(b)
+
+ if got, want := mld.MaximumResponseDelay(), maxRespDelay*time.Millisecond; got != want {
+ t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want)
+ }
+
+ const newMaxRespDelay = 1234
+ mld.SetMaximumResponseDelay(newMaxRespDelay)
+ if got, want := mld.MaximumResponseDelay(), newMaxRespDelay*time.Millisecond; got != want {
+ t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want)
+ }
+
+ if got, want := mld.MulticastAddress(), tcpip.Address([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}); got != want {
+ t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, want)
+ }
+
+ multicastAddress := tcpip.Address([]byte{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0})
+ mld.SetMulticastAddress(multicastAddress)
+ if got := mld.MulticastAddress(); got != multicastAddress {
+ t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, multicastAddress)
+ }
+}
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index 5d3975c56..554242f0c 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -298,7 +298,7 @@ func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) {
return it, nil
}
-// Serialize serializes the provided list of NDP options into o.
+// Serialize serializes the provided list of NDP options into b.
//
// Note, b must be of sufficient size to hold all the options in s. See
// NDPOptionsSerializer.Length for details on the getting the total size
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index 98bdd29db..a6d4fcd59 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -36,10 +36,10 @@ const (
// UDPFields contains the fields of a UDP packet. It is used to describe the
// fields of a packet that needs to be encoded.
type UDPFields struct {
- // SrcPort is the "source port" field of a UDP packet.
+ // SrcPort is the "Source Port" field of a UDP packet.
SrcPort uint16
- // DstPort is the "destination port" field of a UDP packet.
+ // DstPort is the "Destination Port" field of a UDP packet.
DstPort uint16
// Length is the "length" field of a UDP packet.
@@ -64,52 +64,57 @@ const (
UDPProtocolNumber tcpip.TransportProtocolNumber = 17
)
-// SourcePort returns the "source port" field of the udp header.
+// SourcePort returns the "Source Port" field of the UDP header.
func (b UDP) SourcePort() uint16 {
return binary.BigEndian.Uint16(b[udpSrcPort:])
}
-// DestinationPort returns the "destination port" field of the udp header.
+// DestinationPort returns the "Destination Port" field of the UDP header.
func (b UDP) DestinationPort() uint16 {
return binary.BigEndian.Uint16(b[udpDstPort:])
}
-// Length returns the "length" field of the udp header.
+// Length returns the "Length" field of the UDP header.
func (b UDP) Length() uint16 {
return binary.BigEndian.Uint16(b[udpLength:])
}
// Payload returns the data contained in the UDP datagram.
func (b UDP) Payload() []byte {
- return b[UDPMinimumSize:]
+ return b[:b.Length()][UDPMinimumSize:]
}
-// Checksum returns the "checksum" field of the udp header.
+// Checksum returns the "checksum" field of the UDP header.
func (b UDP) Checksum() uint16 {
return binary.BigEndian.Uint16(b[udpChecksum:])
}
-// SetSourcePort sets the "source port" field of the udp header.
+// SetSourcePort sets the "source port" field of the UDP header.
func (b UDP) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(b[udpSrcPort:], port)
}
-// SetDestinationPort sets the "destination port" field of the udp header.
+// SetDestinationPort sets the "destination port" field of the UDP header.
func (b UDP) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(b[udpDstPort:], port)
}
-// SetChecksum sets the "checksum" field of the udp header.
+// SetChecksum sets the "checksum" field of the UDP header.
func (b UDP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
}
-// SetLength sets the "length" field of the udp header.
+// SetLength sets the "length" field of the UDP header.
func (b UDP) SetLength(length uint16) {
binary.BigEndian.PutUint16(b[udpLength:], length)
}
-// CalculateChecksum calculates the checksum of the udp packet, given the
+// PayloadLength returns the length of the payload following the UDP header.
+func (b UDP) PayloadLength() uint16 {
+ return b.Length() - UDPMinimumSize
+}
+
+// CalculateChecksum calculates the checksum of the UDP packet, given the
// checksum of the network-layer pseudo-header and the checksum of the payload.
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
// Calculate the rest of the checksum.
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
index 39ca774ef..973f06cbc 100644
--- a/pkg/tcpip/link/channel/BUILD
+++ b/pkg/tcpip/link/channel/BUILD
@@ -9,7 +9,6 @@ go_library(
deps = [
"//pkg/sync",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index c95aef63c..0efbfb22b 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -32,7 +31,7 @@ type PacketInfo struct {
Pkt *stack.PacketBuffer
Proto tcpip.NetworkProtocolNumber
GSO *stack.GSO
- Route stack.Route
+ Route *stack.Route
}
// Notification is the interface for receiving notification from the packet
@@ -271,21 +270,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return n, nil
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- p := PacketInfo{
- Pkt: stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }),
- Proto: 0,
- GSO: nil,
- }
-
- e.q.Write(p)
-
- return nil
-}
-
// Wait implements stack.LinkEndpoint.Wait.
func (*Endpoint) Wait() {}
diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go
index 3eef7cd56..beefcd008 100644
--- a/pkg/tcpip/link/ethernet/ethernet.go
+++ b/pkg/tcpip/link/ethernet/ethernet.go
@@ -62,7 +62,7 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
// WritePacket implements stack.LinkEndpoint.
func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt)
+ e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress(), proto, pkt)
return e.Endpoint.WritePacket(r, gso, proto, pkt)
}
@@ -71,7 +71,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
linkAddr := e.Endpoint.LinkAddress()
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt)
+ e.AddHeader(linkAddr, r.RemoteLinkAddress(), proto, pkt)
}
return e.Endpoint.WritePackets(r, gso, pkts, proto)
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 975309fc8..cb94cbea6 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -284,9 +284,12 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher
}
switch sa.(type) {
case *unix.SockaddrLinklayer:
- // enable PACKET_FANOUT mode is the underlying socket is
- // of type AF_PACKET.
- const fanoutType = 0x8000 // PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG
+ // Enable PACKET_FANOUT mode if the underlying socket is of type
+ // AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will
+ // prevent gvisor from receiving fragmented packets and the host does the
+ // reassembly on our behalf before delivering the fragments. This makes it
+ // hard to test fragmentation reassembly code in Netstack.
+ const fanoutType = unix.PACKET_FANOUT_HASH
fanoutArg := fanoutID | fanoutType<<16
if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil {
return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err)
@@ -410,7 +413,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
// currently writable, the packet is dropped.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
if e.hdrSize > 0 {
- e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress(), protocol, pkt)
}
var builder iovec.Builder
@@ -453,7 +456,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc
mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch))
for _, pkt := range batch {
if e.hdrSize > 0 {
- e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress(), pkt.NetworkProtocolNumber, pkt)
}
var vnetHdrBuf []byte
@@ -558,11 +561,6 @@ func viewsEqual(vs1, vs2 []buffer.View) bool {
return len(vs1) == len(vs2) && (len(vs1) == 0 || &vs1[0] == &vs2[0])
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- return rawfile.NonBlockingWrite(e.fds[0], vv.ToView())
-}
-
// InjectOutobund implements stack.InjectableEndpoint.InjectOutbound.
func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
return rawfile.NonBlockingWrite(e.fds[0], packet)
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 709f829c8..ce4da7230 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -183,9 +183,8 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u
c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize})
defer c.cleanup()
- r := &stack.Route{
- RemoteLinkAddress: raddr,
- }
+ var r stack.Route
+ r.ResolveWith(raddr)
// Build payload.
payload := buffer.NewView(plen)
@@ -220,7 +219,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u
L3HdrLen: header.IPv4MaximumHeaderSize,
}
}
- if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil {
+ if err := c.ep.WritePacket(&r, gso, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -325,9 +324,9 @@ func TestPreserveSrcAddress(t *testing.T) {
// Set LocalLinkAddress in route to the value of the bridged address.
r := &stack.Route{
- RemoteLinkAddress: raddr,
- LocalLinkAddress: baddr,
+ LocalLinkAddress: baddr,
}
+ r.ResolveWith(raddr)
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
// WritePacket panics given a prependable with anything less than
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 38aa694e4..edca57e4e 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -96,23 +96,6 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList
panic("not implemented")
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- })
- // There should be an ethernet header at the beginning of vv.
- hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
- if !ok {
- // Reject the packet if it's shorter than an ethernet header.
- return tcpip.ErrBadAddress
- }
- linkHeader := header.Ethernet(hdr)
- e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), pkt)
-
- return nil
-}
-
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (*endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareLoopback
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
index e7493e5c5..cbda59775 100644
--- a/pkg/tcpip/link/muxed/BUILD
+++ b/pkg/tcpip/link/muxed/BUILD
@@ -8,7 +8,6 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 56a611825..22e79ce3a 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -17,7 +17,6 @@ package muxed
import (
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -106,13 +105,6 @@ func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protoco
return tcpip.ErrNoRoute
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (m *InjectableEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
- // WriteRawPacket doesn't get a route or network address, so there's
- // nowhere to write this.
- return tcpip.ErrNoRoute
-}
-
// InjectOutbound writes outbound packets to the appropriate
// LinkInjectableEndpoint based on the dest address.
func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD
index 2cdb23475..00b42b924 100644
--- a/pkg/tcpip/link/nested/BUILD
+++ b/pkg/tcpip/link/nested/BUILD
@@ -11,7 +11,6 @@ go_library(
deps = [
"//pkg/sync",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index d40de54df..0ee54c3d5 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -19,7 +19,6 @@ package nested
import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -123,11 +122,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return e.child.WritePackets(r, gso, pkts, protocol)
}
-// WriteRawPacket implements stack.LinkEndpoint.
-func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- return e.child.WriteRawPacket(vv)
-}
-
// Wait implements stack.LinkEndpoint.
func (e *Endpoint) Wait() {
e.child.Wait()
diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go
index 3922c2a04..9a1b0c0c2 100644
--- a/pkg/tcpip/link/packetsocket/endpoint.go
+++ b/pkg/tcpip/link/packetsocket/endpoint.go
@@ -36,14 +36,14 @@ func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
// WritePacket implements stack.LinkEndpoint.WritePacket.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
+ e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress(), r.LocalLinkAddress, protocol, pkt)
return e.Endpoint.WritePacket(r, gso, protocol, pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress(), pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
}
return e.Endpoint.WritePackets(r, gso, pkts, proto)
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
index 523b0d24b..25c364391 100644
--- a/pkg/tcpip/link/pipe/pipe.go
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -55,7 +55,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.Network
// remote address from the perspective of the other end of the pipe
// (e.linked). Similarly, the remote address from the perspective of this
// endpoint is the local address on the other end.
- e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress() /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
}))
@@ -67,11 +67,6 @@ func (*Endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList,
panic("not implemented")
}
-// WriteRawPacket implements stack.LinkEndpoint.
-func (*Endpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
- panic("not implemented")
-}
-
// Attach implements stack.LinkEndpoint.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD
index 1d0079bd6..5bea598eb 100644
--- a/pkg/tcpip/link/qdisc/fifo/BUILD
+++ b/pkg/tcpip/link/qdisc/fifo/BUILD
@@ -13,7 +13,6 @@ go_library(
"//pkg/sleep",
"//pkg/sync",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index fc1e34fc7..27667f5f0 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -156,7 +155,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
// WritePacket caller's do not set the following fields in PacketBuffer
// so we populate them here.
newRoute := r.Clone()
- pkt.EgressRoute = &newRoute
+ pkt.EgressRoute = newRoute
pkt.GSOOptions = gso
pkt.NetworkProtocolNumber = protocol
d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
@@ -183,7 +182,7 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB
// the route here to ensure it doesn't get released while the
// packet is still in our queue.
newRoute := pkt.EgressRoute.Clone()
- pkt.EgressRoute = &newRoute
+ pkt.EgressRoute = newRoute
if !d.q.enqueue(pkt) {
if enqueued > 0 {
d.newPacketWaker.Assert()
@@ -197,13 +196,6 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB
return enqueued, nil
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- // TODO(gvisor.dev/issue/3267): Queue these packets as well once
- // WriteRawPacket takes PacketBuffer instead of VectorisedView.
- return e.lower.WriteRawPacket(vv)
-}
-
// Wait implements stack.LinkEndpoint.Wait.
func (e *endpoint) Wait() {
e.lower.Wait()
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 7fb8a6c49..5660418fa 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -204,7 +204,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress(), protocol, pkt)
views := pkt.Views()
// Transmit the packet.
@@ -224,21 +224,6 @@ func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketB
panic("not implemented")
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- views := vv.Views()
- // Transmit the packet.
- e.mu.Lock()
- ok := e.tx.transmit(views...)
- e.mu.Unlock()
-
- if !ok {
- return tcpip.ErrWouldBlock
- }
-
- return nil
-}
-
// dispatchLoop reads packets from the rx queue in a loop and dispatches them
// to the network stack.
func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 22d5c97f1..7131392cc 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -260,9 +260,8 @@ func TestSimpleSend(t *testing.T) {
defer c.cleanup()
// Prepare route.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
for iters := 1000; iters > 0; iters-- {
func() {
@@ -342,9 +341,9 @@ func TestPreserveSrcAddressInSend(t *testing.T) {
newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("0xFE", 6))
// Set both remote and local link address in route.
r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- LocalLinkAddress: newLocalLinkAddress,
+ LocalLinkAddress: newLocalLinkAddress,
}
+ r.ResolveWith(remoteLinkAddr)
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
// WritePacket panics given a prependable with anything less than
@@ -395,9 +394,8 @@ func TestFillTxQueue(t *testing.T) {
defer c.cleanup()
// Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
buf := buffer.NewView(100)
@@ -444,9 +442,8 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
c.txq.rx.Flush()
// Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
buf := buffer.NewView(100)
@@ -509,9 +506,8 @@ func TestFillTxMemory(t *testing.T) {
defer c.cleanup()
// Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
buf := buffer.NewView(100)
@@ -557,9 +553,8 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
defer c.cleanup()
// Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
buf := buffer.NewView(100)
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index b3e8c4b92..8d9a91020 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -53,16 +53,35 @@ type endpoint struct {
nested.Endpoint
writer io.Writer
maxPCAPLen uint32
+ logPrefix string
}
var _ stack.GSOEndpoint = (*endpoint)(nil)
var _ stack.LinkEndpoint = (*endpoint)(nil)
var _ stack.NetworkDispatcher = (*endpoint)(nil)
+type direction int
+
+const (
+ directionSend = iota
+ directionRecv
+)
+
// New creates a new sniffer link-layer endpoint. It wraps around another
// endpoint and logs packets and they traverse the endpoint.
func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
- sniffer := &endpoint{}
+ return NewWithPrefix(lower, "")
+}
+
+// NewWithPrefix creates a new sniffer link-layer endpoint. It wraps around
+// another endpoint and logs packets prefixed with logPrefix as they traverse
+// the endpoint.
+//
+// logPrefix is prepended to the log line without any separators.
+// E.g. logPrefix = "NIC:en0/" will produce log lines like
+// "NIC:en0/send udp [...]".
+func NewWithPrefix(lower stack.LinkEndpoint, logPrefix string) stack.LinkEndpoint {
+ sniffer := &endpoint{logPrefix: logPrefix}
sniffer.Endpoint.Init(lower, sniffer)
return sniffer
}
@@ -120,7 +139,7 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) (
// called by the link-layer endpoint being wrapped when a packet arrives, and
// logs the packet before forwarding to the actual dispatcher.
func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- e.dumpPacket("recv", nil, protocol, pkt)
+ e.dumpPacket(directionRecv, nil, protocol, pkt)
e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt)
}
@@ -129,10 +148,10 @@ func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protoc
e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt)
}
-func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
writer := e.writer
if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
- logPacket(prefix, protocol, pkt, gso)
+ logPacket(e.logPrefix, dir, protocol, pkt, gso)
}
if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 {
totalLength := pkt.Size()
@@ -169,7 +188,7 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- e.dumpPacket("send", gso, protocol, pkt)
+ e.dumpPacket(directionSend, gso, protocol, pkt)
return e.Endpoint.WritePacket(r, gso, protocol, pkt)
}
@@ -178,20 +197,12 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
// forwards the request to the lower endpoint.
func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.dumpPacket("send", gso, protocol, pkt)
+ e.dumpPacket(directionSend, gso, protocol, pkt)
}
return e.Endpoint.WritePackets(r, gso, pkts, protocol)
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- e.dumpPacket("send", nil, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }))
- return e.Endpoint.WriteRawPacket(vv)
-}
-
-func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) {
+func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
src := tcpip.Address("unknown")
@@ -201,6 +212,16 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
var fragmentOffset uint16
var moreFragments bool
+ var directionPrefix string
+ switch dir {
+ case directionSend:
+ directionPrefix = "send"
+ case directionRecv:
+ directionPrefix = "recv"
+ default:
+ panic(fmt.Sprintf("unrecognized direction: %d", dir))
+ }
+
// Clone the packet buffer to not modify the original.
//
// We don't clone the original packet buffer so that the new packet buffer
@@ -248,15 +269,16 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
arp := header.ARP(pkt.NetworkHeader().View())
log.Infof(
- "%s arp %s (%s) -> %s (%s) valid:%t",
+ "%s%s arp %s (%s) -> %s (%s) valid:%t",
prefix,
+ directionPrefix,
tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()),
tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()),
arp.IsValid(),
)
return
default:
- log.Infof("%s unknown network protocol", prefix)
+ log.Infof("%s%s unknown network protocol", prefix, directionPrefix)
return
}
@@ -300,7 +322,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
icmpType = "info reply"
}
}
- log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, directionPrefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
case header.ICMPv6ProtocolNumber:
@@ -335,7 +357,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.ICMPv6RedirectMsg:
icmpType = "redirect message"
}
- log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, directionPrefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
case header.UDPProtocolNumber:
@@ -391,7 +413,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
}
default:
- log.Infof("%s %s -> %s unknown transport protocol: %d", prefix, src, dst, transProto)
+ log.Infof("%s%s %s -> %s unknown transport protocol: %d", prefix, directionPrefix, src, dst, transProto)
return
}
@@ -399,5 +421,5 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
details += fmt.Sprintf(" gso: %+v", gso)
}
- log.Infof("%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details)
+ log.Infof("%s%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, directionPrefix, transName, src, srcPort, dst, dstPort, size, id, details)
}
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 9a76bdba7..a364c5801 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -264,7 +264,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
// If the packet does not already have link layer header, and the route
// does not exist, we can't compute it. This is possibly a raw packet, tun
// device doesn't support this at the moment.
- if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress == "" {
+ if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress() == "" {
return nil, false
}
@@ -272,7 +272,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
if d.hasFlags(linux.IFF_TAP) {
// Add ethernet header if not provided.
if info.Pkt.LinkHeader().View().IsEmpty() {
- d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt)
+ d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress(), info.Proto, info.Pkt)
}
vv.AppendView(info.Pkt.LinkHeader().View())
}
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
index ee84c3d96..9b4602c1b 100644
--- a/pkg/tcpip/link/waitable/BUILD
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -11,7 +11,6 @@ go_library(
deps = [
"//pkg/gate",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
@@ -25,7 +24,6 @@ go_test(
library = ":waitable",
deps = [
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index b152a0f26..cf0077f43 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -24,7 +24,6 @@ package waitable
import (
"gvisor.dev/gvisor/pkg/gate"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -132,17 +131,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return n, err
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- if !e.writeGate.Enter() {
- return nil
- }
-
- err := e.lower.WriteRawPacket(vv)
- e.writeGate.Leave()
- return err
-}
-
// WaitWrite prevents new calls to WritePacket from reaching the lower endpoint,
// and waits for inflight ones to finish before returning.
func (e *Endpoint) WaitWrite() {
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 94827fc56..cf7fb5126 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -18,7 +18,6 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -81,11 +80,6 @@ func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.
return pkts.Len(), nil
}
-func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
- e.writeCount++
- return nil
-}
-
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("unimplemented")
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index b38aff0b8..9ebf31b78 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -7,12 +7,14 @@ go_test(
size = "small",
srcs = [
"ip_test.go",
+ "multicast_group_test.go",
],
deps = [
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/header/parse",
"//pkg/tcpip/link/channel",
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index f462524c9..0fb373612 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -319,9 +319,9 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
copy(h.HardwareAddressSender(), test.senderLinkAddr)
copy(h.ProtocolAddressSender(), test.senderAddr)
copy(h.ProtocolAddressTarget(), test.targetAddr)
- c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{
+ c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: v.ToVectorisedView(),
- })
+ }))
if !test.isValid {
// No packets should be sent after receiving an invalid ARP request.
@@ -442,9 +442,9 @@ func (*testInterface) Promiscuous() bool {
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
r := stack.Route{
- NetProto: protocol,
- RemoteLinkAddress: remoteLinkAddr,
+ NetProto: protocol,
}
+ r.ResolveWith(remoteLinkAddr)
return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
}
@@ -557,8 +557,8 @@ func TestLinkAddressRequest(t *testing.T) {
t.Fatal("expected to send a link address request")
}
- if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
+ if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr)
}
rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index d8e4a3b54..429af69ee 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -18,7 +18,6 @@ go_template_instance(
go_library(
name = "fragmentation",
srcs = [
- "frag_heap.go",
"fragmentation.go",
"reassembler.go",
"reassembler_list.go",
@@ -38,7 +37,6 @@ go_test(
name = "fragmentation_test",
size = "small",
srcs = [
- "frag_heap_test.go",
"fragmentation_test.go",
"reassembler_test.go",
],
diff --git a/pkg/tcpip/network/fragmentation/frag_heap.go b/pkg/tcpip/network/fragmentation/frag_heap.go
deleted file mode 100644
index 0b570d25a..000000000
--- a/pkg/tcpip/network/fragmentation/frag_heap.go
+++ /dev/null
@@ -1,77 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fragmentation
-
-import (
- "container/heap"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
-)
-
-type fragment struct {
- offset uint16
- vv buffer.VectorisedView
-}
-
-type fragHeap []fragment
-
-func (h *fragHeap) Len() int {
- return len(*h)
-}
-
-func (h *fragHeap) Less(i, j int) bool {
- return (*h)[i].offset < (*h)[j].offset
-}
-
-func (h *fragHeap) Swap(i, j int) {
- (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
-}
-
-func (h *fragHeap) Push(x interface{}) {
- *h = append(*h, x.(fragment))
-}
-
-func (h *fragHeap) Pop() interface{} {
- old := *h
- n := len(old)
- x := old[n-1]
- *h = old[:n-1]
- return x
-}
-
-// reassamble empties the heap and returns a VectorisedView
-// containing a reassambled version of the fragments inside the heap.
-func (h *fragHeap) reassemble() (buffer.VectorisedView, error) {
- curr := heap.Pop(h).(fragment)
- views := curr.vv.Views()
- size := curr.vv.Size()
-
- if curr.offset != 0 {
- return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset)
- }
-
- for h.Len() > 0 {
- curr := heap.Pop(h).(fragment)
- if int(curr.offset) < size {
- curr.vv.TrimFront(size - int(curr.offset))
- } else if int(curr.offset) > size {
- return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset)
- }
- size += curr.vv.Size()
- views = append(views, curr.vv.Views()...)
- }
- return buffer.NewVectorisedView(size, views), nil
-}
diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go
deleted file mode 100644
index 9ececcb9f..000000000
--- a/pkg/tcpip/network/fragmentation/frag_heap_test.go
+++ /dev/null
@@ -1,126 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fragmentation
-
-import (
- "container/heap"
- "reflect"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
-)
-
-var reassambleTestCases = []struct {
- comment string
- in []fragment
- want buffer.VectorisedView
-}{
- {
- comment: "Non-overlapping in-order",
- in: []fragment{
- {offset: 0, vv: vv(1, "0")},
- {offset: 1, vv: vv(1, "1")},
- },
- want: vv(2, "0", "1"),
- },
- {
- comment: "Non-overlapping out-of-order",
- in: []fragment{
- {offset: 1, vv: vv(1, "1")},
- {offset: 0, vv: vv(1, "0")},
- },
- want: vv(2, "0", "1"),
- },
- {
- comment: "Duplicated packets",
- in: []fragment{
- {offset: 0, vv: vv(1, "0")},
- {offset: 0, vv: vv(1, "0")},
- },
- want: vv(1, "0"),
- },
- {
- comment: "Overlapping in-order",
- in: []fragment{
- {offset: 0, vv: vv(2, "01")},
- {offset: 1, vv: vv(2, "12")},
- },
- want: vv(3, "01", "2"),
- },
- {
- comment: "Overlapping out-of-order",
- in: []fragment{
- {offset: 1, vv: vv(2, "12")},
- {offset: 0, vv: vv(2, "01")},
- },
- want: vv(3, "01", "2"),
- },
- {
- comment: "Overlapping subset in-order",
- in: []fragment{
- {offset: 0, vv: vv(3, "012")},
- {offset: 1, vv: vv(1, "1")},
- },
- want: vv(3, "012"),
- },
- {
- comment: "Overlapping subset out-of-order",
- in: []fragment{
- {offset: 1, vv: vv(1, "1")},
- {offset: 0, vv: vv(3, "012")},
- },
- want: vv(3, "012"),
- },
-}
-
-func TestReassamble(t *testing.T) {
- for _, c := range reassambleTestCases {
- t.Run(c.comment, func(t *testing.T) {
- h := make(fragHeap, 0, 8)
- heap.Init(&h)
- for _, f := range c.in {
- heap.Push(&h, f)
- }
- got, err := h.reassemble()
- if err != nil {
- t.Fatal(err)
- }
- if !reflect.DeepEqual(got, c.want) {
- t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want)
- }
- })
- }
-}
-
-func TestReassambleFailsForNonZeroOffset(t *testing.T) {
- h := make(fragHeap, 0, 8)
- heap.Init(&h)
- heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")})
- _, err := h.reassemble()
- if err == nil {
- t.Errorf("reassemble() did not fail when the first packet had offset != 0")
- }
-}
-
-func TestReassambleFailsForHoles(t *testing.T) {
- h := make(fragHeap, 0, 8)
- heap.Init(&h)
- heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")})
- heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")})
- _, err := h.reassemble()
- if err == nil {
- t.Errorf("reassemble() did not fail when there was a hole in the packet")
- }
-}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index c75ca7d71..1af87d713 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -46,9 +46,17 @@ const (
)
var (
- // ErrInvalidArgs indicates to the caller that that an invalid argument was
+ // ErrInvalidArgs indicates to the caller that an invalid argument was
// provided.
ErrInvalidArgs = errors.New("invalid args")
+
+ // ErrFragmentOverlap indicates that, during reassembly, a fragment overlaps
+ // with another one.
+ ErrFragmentOverlap = errors.New("overlapping fragments")
+
+ // ErrFragmentConflict indicates that, during reassembly, some fragments are
+ // in conflict with one another.
+ ErrFragmentConflict = errors.New("conflicting fragments")
)
// FragmentID is the identifier for a fragment.
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 19f4920b3..9b20bb1d8 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -15,9 +15,8 @@
package fragmentation
import (
- "container/heap"
- "fmt"
"math"
+ "sort"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -26,9 +25,11 @@ import (
)
type hole struct {
- first uint16
- last uint16
- deleted bool
+ first uint16
+ last uint16
+ filled bool
+ final bool
+ data buffer.View
}
type reassembler struct {
@@ -38,8 +39,7 @@ type reassembler struct {
proto uint8
mu sync.Mutex
holes []hole
- deleted int
- heap fragHeap
+ filled int
done bool
creationTime int64
pkt *stack.PacketBuffer
@@ -48,49 +48,94 @@ type reassembler struct {
func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
r := &reassembler{
id: id,
- holes: make([]hole, 0, 16),
- heap: make(fragHeap, 0, 8),
creationTime: clock.NowMonotonic(),
}
r.holes = append(r.holes, hole{
- first: 0,
- last: math.MaxUint16,
- deleted: false})
+ first: 0,
+ last: math.MaxUint16,
+ filled: false,
+ final: true,
+ })
return r
}
-// updateHoles updates the list of holes for an incoming fragment and
-// returns true iff the fragment filled at least part of an existing hole.
-func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
- used := false
- for i := range r.holes {
- if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first {
- continue
- }
- used = true
- r.deleted++
- r.holes[i].deleted = true
- if first > r.holes[i].first {
- r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false})
- }
- if last < r.holes[i].last && more {
- r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false})
- }
- }
- return used
-}
-
func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) {
r.mu.Lock()
defer r.mu.Unlock()
- consumed := 0
if r.done {
// A concurrent goroutine might have already reassembled
// the packet and emptied the heap while this goroutine
// was waiting on the mutex. We don't have to do anything in this case.
- return buffer.VectorisedView{}, 0, false, consumed, nil
+ return buffer.VectorisedView{}, 0, false, 0, nil
}
- if r.updateHoles(first, last, more) {
+
+ var holeFound bool
+ var consumed int
+ for i := range r.holes {
+ currentHole := &r.holes[i]
+
+ if last < currentHole.first || currentHole.last < first {
+ continue
+ }
+ // For IPv6, overlaps with an existing fragment are explicitly forbidden by
+ // RFC 8200 section 4.5:
+ // If any of the fragments being reassembled overlap with any other
+ // fragments being reassembled for the same packet, reassembly of that
+ // packet must be abandoned and all the fragments that have been received
+ // for that packet must be discarded, and no ICMP error messages should be
+ // sent.
+ //
+ // It is not explicitly forbidden for IPv4, but to keep parity with Linux we
+ // disallow it as well:
+ // https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349
+ if first < currentHole.first || currentHole.last < last {
+ // Incoming fragment only partially fits in the free hole.
+ return buffer.VectorisedView{}, 0, false, 0, ErrFragmentOverlap
+ }
+ if !more {
+ if !currentHole.final || currentHole.filled && currentHole.last != last {
+ // We have another final fragment, which does not perfectly overlap.
+ return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict
+ }
+ }
+
+ holeFound = true
+ if currentHole.filled {
+ // Incoming fragment is a duplicate.
+ continue
+ }
+
+ // We are populating the current hole with the payload and creating a new
+ // hole for any unfilled ranges on either end.
+ if first > currentHole.first {
+ r.holes = append(r.holes, hole{
+ first: currentHole.first,
+ last: first - 1,
+ filled: false,
+ final: false,
+ })
+ }
+ if last < currentHole.last && more {
+ r.holes = append(r.holes, hole{
+ first: last + 1,
+ last: currentHole.last,
+ filled: false,
+ final: currentHole.final,
+ })
+ currentHole.final = false
+ }
+ v := pkt.Data.ToOwnedView()
+ consumed = v.Size()
+ r.size += consumed
+ // Update the current hole to precisely match the incoming fragment.
+ r.holes[i] = hole{
+ first: first,
+ last: last,
+ filled: true,
+ final: currentHole.final,
+ data: v,
+ }
+ r.filled++
// For IPv6, it is possible to have different Protocol values between
// fragments of a packet (because, unlike IPv4, the Protocol is not used to
// identify a fragment). In this case, only the Protocol of the first
@@ -103,21 +148,30 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
r.pkt = pkt
r.proto = proto
}
- vv := pkt.Data
- // We store the incoming packet only if it filled some holes.
- heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)})
- consumed = vv.Size()
- r.size += consumed
+
+ break
+ }
+ if !holeFound {
+ // Incoming fragment is beyond end.
+ return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict
}
- // Check if all the holes have been deleted and we are ready to reassamble.
- if r.deleted < len(r.holes) {
+
+ // Check if all the holes have been filled and we are ready to reassemble.
+ if r.filled < len(r.holes) {
return buffer.VectorisedView{}, 0, false, consumed, nil
}
- res, err := r.heap.reassemble()
- if err != nil {
- return buffer.VectorisedView{}, 0, false, consumed, fmt.Errorf("fragment reassembly failed: %w", err)
+
+ sort.Slice(r.holes, func(i, j int) bool {
+ return r.holes[i].first < r.holes[j].first
+ })
+
+ var size int
+ views := make([]buffer.View, 0, len(r.holes))
+ for _, hole := range r.holes {
+ views = append(views, hole.data)
+ size += hole.data.Size()
}
- return res, r.proto, true, consumed, nil
+ return buffer.NewVectorisedView(size, views), r.proto, true, consumed, nil
}
func (r *reassembler) checkDoneOrMark() bool {
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
index a0a04a027..2ff03eeeb 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -16,92 +16,175 @@ package fragmentation
import (
"math"
- "reflect"
"testing"
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
-type updateHolesInput struct {
- first uint16
- last uint16
- more bool
+type processParams struct {
+ first uint16
+ last uint16
+ more bool
+ pkt *stack.PacketBuffer
+ wantDone bool
+ wantError error
}
-var holesTestCases = []struct {
- comment string
- in []updateHolesInput
- want []hole
-}{
- {
- comment: "No fragments. Expected holes: {[0 -> inf]}.",
- in: []updateHolesInput{},
- want: []hole{{first: 0, last: math.MaxUint16, deleted: false}},
- },
- {
- comment: "One fragment at beginning. Expected holes: {[2, inf]}.",
- in: []updateHolesInput{{first: 0, last: 1, more: true}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 2, last: math.MaxUint16, deleted: false},
+func TestReassemblerProcess(t *testing.T) {
+ const proto = 99
+
+ v := func(size int) buffer.View {
+ payload := buffer.NewView(size)
+ for i := 1; i < size; i++ {
+ payload[i] = uint8(i) * 3
+ }
+ return payload
+ }
+
+ pkt := func(size int) *stack.PacketBuffer {
+ return stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: v(size).ToVectorisedView(),
+ })
+ }
+
+ var tests = []struct {
+ name string
+ params []processParams
+ want []hole
+ }{
+ {
+ name: "No fragments",
+ params: nil,
+ want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}},
},
- },
- {
- comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.",
- in: []updateHolesInput{{first: 1, last: 2, more: true}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 0, last: 0, deleted: false},
- {first: 3, last: math.MaxUint16, deleted: false},
+ {
+ name: "One fragment at beginning",
+ params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: false, data: v(2)},
+ {first: 2, last: math.MaxUint16, filled: false, final: true},
+ },
},
- },
- {
- comment: "One fragment at the end. Expected holes: {[0, 0]}.",
- in: []updateHolesInput{{first: 1, last: 2, more: false}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 0, last: 0, deleted: false},
+ {
+ name: "One fragment in the middle",
+ params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
+ want: []hole{
+ {first: 1, last: 2, filled: true, final: false, data: v(2)},
+ {first: 0, last: 0, filled: false, final: false},
+ {first: 3, last: math.MaxUint16, filled: false, final: true},
+ },
},
- },
- {
- comment: "One fragment completing a packet. Expected holes: {}.",
- in: []updateHolesInput{{first: 0, last: 1, more: false}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
+ {
+ name: "One fragment at the end",
+ params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}},
+ want: []hole{
+ {first: 1, last: 2, filled: true, final: true, data: v(2)},
+ {first: 0, last: 0, filled: false},
+ },
},
- },
- {
- comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.",
- in: []updateHolesInput{
- {first: 0, last: 1, more: true},
- {first: 2, last: 3, more: false},
+ {
+ name: "One fragment completing a packet",
+ params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}},
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: true, data: v(2)},
+ },
},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 2, last: math.MaxUint16, deleted: true},
+ {
+ name: "Two fragments completing a packet",
+ params: []processParams{
+ {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: false, data: v(2)},
+ {first: 2, last: 3, filled: true, final: true, data: v(2)},
+ },
},
- },
- {
- comment: "Two overlapping fragments completing a packet. Expected holes: {}.",
- in: []updateHolesInput{
- {first: 0, last: 2, more: true},
- {first: 2, last: 3, more: false},
+ {
+ name: "Two fragments completing a packet with a duplicate",
+ params: []processParams{
+ {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: false, data: v(2)},
+ {first: 2, last: 3, filled: true, final: true, data: v(2)},
+ },
},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 3, last: math.MaxUint16, deleted: true},
+ {
+ name: "Two fragments completing a packet with a partial duplicate",
+ params: []processParams{
+ {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil},
+ {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 3, filled: true, final: false, data: v(4)},
+ {first: 4, last: 5, filled: true, final: true, data: v(2)},
+ },
},
- },
-}
+ {
+ name: "Two overlapping fragments",
+ params: []processParams{
+ {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil},
+ {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap},
+ },
+ want: []hole{
+ {first: 0, last: 10, filled: true, final: false, data: v(11)},
+ {first: 11, last: math.MaxUint16, filled: false, final: true},
+ },
+ },
+ {
+ name: "Two final fragments with different ends",
+ params: []processParams{
+ {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
+ {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict},
+ },
+ want: []hole{
+ {first: 10, last: 14, filled: true, final: true, data: v(5)},
+ {first: 0, last: 9, filled: false, final: false},
+ },
+ },
+ {
+ name: "Two final fragments - duplicate",
+ params: []processParams{
+ {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil},
+ {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
+ },
+ want: []hole{
+ {first: 5, last: 14, filled: true, final: true, data: v(10)},
+ {first: 0, last: 4, filled: false, final: false},
+ },
+ },
+ {
+ name: "Two final fragments - duplicate, with different ends",
+ params: []processParams{
+ {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil},
+ {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict},
+ },
+ want: []hole{
+ {first: 5, last: 14, filled: true, final: true, data: v(10)},
+ {first: 0, last: 4, filled: false, final: false},
+ },
+ },
+ }
-func TestUpdateHoles(t *testing.T) {
- for _, c := range holesTestCases {
- r := newReassembler(FragmentID{}, &faketime.NullClock{})
- for _, i := range c.in {
- r.updateHoles(i.first, i.last, i.more)
- }
- if !reflect.DeepEqual(r.holes, c.want) {
- t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ r := newReassembler(FragmentID{}, &faketime.NullClock{})
+ for _, param := range test.params {
+ _, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt)
+ if done != param.wantDone || err != param.wantError {
+ t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError)
+ }
+ }
+ if diff := cmp.Diff(test.want, r.holes, cmp.AllowUnexported(hole{})); diff != "" {
+ t.Errorf("r.holes mismatch (-want +got):\n%s", diff)
+ }
+ })
}
}
diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD
new file mode 100644
index 000000000..ca1247c1e
--- /dev/null
+++ b/pkg/tcpip/network/ip/BUILD
@@ -0,0 +1,26 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "ip",
+ srcs = ["generic_multicast_protocol.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/sync",
+ "//pkg/tcpip",
+ ],
+)
+
+go_test(
+ name = "ip_test",
+ size = "small",
+ srcs = ["generic_multicast_protocol_test.go"],
+ deps = [
+ ":ip",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/faketime",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go
new file mode 100644
index 000000000..f2f0e069c
--- /dev/null
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go
@@ -0,0 +1,676 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ip holds IPv4/IPv6 common utilities.
+package ip
+
+import (
+ "fmt"
+ "math/rand"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// hostState is the state a host may be in for a multicast group.
+type hostState int
+
+// The states below are generic across IGMPv2 (RFC 2236 section 6) and MLDv1
+// (RFC 2710 section 5). Even though the states are generic across both IGMPv2
+// and MLDv1, IGMPv2 terminology will be used.
+//
+// ______________receive query______________
+// | |
+// | _____send or receive report_____ |
+// | | | |
+// V | V |
+// +-------+ +-----------+ +------------+ +-------------------+ +--------+ |
+// | Non-M | | Pending-M | | Delaying-M | | Queued Delaying-M | | Idle-M | -
+// +-------+ +-----------+ +------------+ +-------------------+ +--------+
+// | ^ | ^ | ^ | ^
+// | | | | | | | |
+// ---------- ------- ---------- -------------
+// initialize new send inital fail to send send or receive
+// group membership report delayed report report
+//
+// Not shown in the diagram above, but any state may transition into the non
+// member state when a group is left.
+const (
+ // nonMember is the "'Non-Member' state, when the host does not belong to the
+ // group on the interface. This is the initial state for all memberships on
+ // all network interfaces; it requires no storage in the host."
+ //
+ // 'Non-Listener' is the MLDv1 term used to describe this state.
+ //
+ // This state is used to keep track of groups that have been joined locally,
+ // but without advertising the membership to the network.
+ nonMember hostState = iota
+
+ // pendingMember is a newly joined member that is waiting to successfully send
+ // the initial set of reports.
+ //
+ // This is not an RFC defined state; it is an implementation specific state to
+ // track that the initial report needs to be sent.
+ //
+ // MAY NOT transition to the idle member state from this state.
+ pendingMember
+
+ // delayingMember is the "'Delaying Member' state, when the host belongs to
+ // the group on the interface and has a report delay timer running for that
+ // membership."
+ //
+ // 'Delaying Listener' is the MLDv1 term used to describe this state.
+ delayingMember
+
+ // queuedDelayingMember is a delayingMember that failed to send a report after
+ // its delayed report timer fired. Hosts in this state are waiting to attempt
+ // retransmission of the delayed report.
+ //
+ // This is not an RFC defined state; it is an implementation specific state to
+ // track that the delayed report needs to be sent.
+ //
+ // May transition to idle member if a report is received for a group.
+ queuedDelayingMember
+
+ // idleMember is the "Idle Member" state, when the host belongs to the group
+ // on the interface and does not have a report delay timer running for that
+ // membership.
+ //
+ // 'Idle Listener' is the MLDv1 term used to describe this state.
+ idleMember
+)
+
+func (s hostState) isDelayingMember() bool {
+ switch s {
+ case nonMember, pendingMember, idleMember:
+ return false
+ case delayingMember, queuedDelayingMember:
+ return true
+ default:
+ panic(fmt.Sprintf("unrecognized host state = %d", s))
+ }
+}
+
+// multicastGroupState holds the Generic Multicast Protocol state for a
+// multicast group.
+type multicastGroupState struct {
+ // joins is the number of times the group has been joined.
+ joins uint64
+
+ // state holds the host's state for the group.
+ state hostState
+
+ // lastToSendReport is true if we sent the last report for the group. It is
+ // used to track whether there are other hosts on the subnet that are also
+ // members of the group.
+ //
+ // Defined in RFC 2236 section 6 page 9 for IGMPv2 and RFC 2710 section 5 page
+ // 8 for MLDv1.
+ lastToSendReport bool
+
+ // delayedReportJob is used to delay sending responses to membership report
+ // messages in order to reduce duplicate reports from multiple hosts on the
+ // interface.
+ //
+ // Must not be nil.
+ delayedReportJob *tcpip.Job
+}
+
+// GenericMulticastProtocolOptions holds options for the generic multicast
+// protocol.
+type GenericMulticastProtocolOptions struct {
+ // Rand is the source of random numbers.
+ Rand *rand.Rand
+
+ // Clock is the clock used to create timers.
+ Clock tcpip.Clock
+
+ // Protocol is the implementation of the variant of multicast group protocol
+ // in use.
+ Protocol MulticastGroupProtocol
+
+ // MaxUnsolicitedReportDelay is the maximum amount of time to wait between
+ // transmitting unsolicited reports.
+ //
+ // Unsolicited reports are transmitted when a group is newly joined.
+ MaxUnsolicitedReportDelay time.Duration
+
+ // AllNodesAddress is a multicast address that all nodes on a network should
+ // be a member of.
+ //
+ // This address will not have the generic multicast protocol performed on it;
+ // it will be left in the non member/listener state, and packets will never
+ // be sent for it.
+ AllNodesAddress tcpip.Address
+}
+
+// MulticastGroupProtocol is a multicast group protocol whose core state machine
+// can be represented by GenericMulticastProtocolState.
+type MulticastGroupProtocol interface {
+ // Enabled indicates whether the generic multicast protocol will be
+ // performed.
+ //
+ // When enabled, the protocol may transmit report and leave messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // packets.
+ //
+ // When disabled, the protocol will still keep track of locally joined groups,
+ // it just won't transmit and handle packets, or update groups' state.
+ Enabled() bool
+
+ // SendReport sends a multicast report for the specified group address.
+ //
+ // Returns false if the caller should queue the report to be sent later. Note,
+ // returning false does not mean that the receiver hit an error.
+ SendReport(groupAddress tcpip.Address) (sent bool, err *tcpip.Error)
+
+ // SendLeave sends a multicast leave for the specified group address.
+ SendLeave(groupAddress tcpip.Address) *tcpip.Error
+}
+
+// GenericMulticastProtocolState is the per interface generic multicast protocol
+// state.
+//
+// There is actually no protocol named "Generic Multicast Protocol". Instead,
+// the term used to refer to a generic multicast protocol that applies to both
+// IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core state
+// machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC 2710.
+//
+// Callers must synchronize accesses to the generic multicast protocol state;
+// GenericMulticastProtocolState obtains no locks in any of its methods. The
+// only exception to this is GenericMulticastProtocolState's timer/job callbacks
+// which will obtain the lock provided to the GenericMulticastProtocolState when
+// it is initialized.
+//
+// GenericMulticastProtocolState.Init MUST be called before calling any of
+// the methods on GenericMulticastProtocolState.
+//
+// GenericMulticastProtocolState.MakeAllNonMemberLocked MUST be called when the
+// multicast group protocol is disabled so that leave messages may be sent.
+type GenericMulticastProtocolState struct {
+ // Do not allow overwriting this state.
+ _ sync.NoCopy
+
+ opts GenericMulticastProtocolOptions
+
+ // memberships holds group addresses and their associated state.
+ memberships map[tcpip.Address]multicastGroupState
+
+ // protocolMU is the mutex used to protect the protocol.
+ protocolMU *sync.RWMutex
+}
+
+// Init initializes the Generic Multicast Protocol state.
+//
+// Must only be called once for the lifetime of g; Init will panic if it is
+// called twice.
+//
+// The GenericMulticastProtocolState will only grab the lock when timers/jobs
+// fire.
+//
+// Note: the methods on opts.Protocol will always be called while protocolMU is
+// held.
+func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) {
+ if g.memberships != nil {
+ panic("attempted to initialize generic membership protocol state twice")
+ }
+
+ *g = GenericMulticastProtocolState{
+ opts: opts,
+ memberships: make(map[tcpip.Address]multicastGroupState),
+ protocolMU: protocolMU,
+ }
+}
+
+// MakeAllNonMemberLocked transitions all groups to the non-member state.
+//
+// The groups will still be considered joined locally.
+//
+// MUST be called when the multicast group protocol is disabled.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
+
+ for groupAddress, info := range g.memberships {
+ g.transitionToNonMemberLocked(groupAddress, &info)
+ g.memberships[groupAddress] = info
+ }
+}
+
+// InitializeGroupsLocked initializes each group, as if they were newly joined
+// but without affecting the groups' join count.
+//
+// Must only be called after calling MakeAllNonMember as a group should not be
+// initialized while it is not in the non-member state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) InitializeGroupsLocked() {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
+
+ for groupAddress, info := range g.memberships {
+ g.initializeNewMemberLocked(groupAddress, &info)
+ g.memberships[groupAddress] = info
+ }
+}
+
+// SendQueuedReportsLocked attempts to send reports for groups that failed to
+// send reports during their last attempt.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() {
+ for groupAddress, info := range g.memberships {
+ switch info.state {
+ case nonMember, delayingMember, idleMember:
+ case pendingMember:
+ // pendingMembers failed to send their initial unsolicited report so try
+ // to send the report and queue the extra unsolicited reports.
+ g.maybeSendInitialReportLocked(groupAddress, &info)
+ case queuedDelayingMember:
+ // queuedDelayingMembers failed to send their delayed reports so try to
+ // send the report and transition them to the idle state.
+ g.maybeSendDelayedReportLocked(groupAddress, &info)
+ default:
+ panic(fmt.Sprintf("unrecognized host state = %d", info.state))
+ }
+ g.memberships[groupAddress] = info
+ }
+}
+
+// JoinGroupLocked handles joining a new group.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address) {
+ if info, ok := g.memberships[groupAddress]; ok {
+ // The group has already been joined.
+ info.joins++
+ g.memberships[groupAddress] = info
+ return
+ }
+
+ info := multicastGroupState{
+ // Since we just joined the group, its count is 1.
+ joins: 1,
+ // The state will be updated below, if required.
+ state: nonMember,
+ lastToSendReport: false,
+ delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() {
+ if !g.opts.Protocol.Enabled() {
+ panic(fmt.Sprintf("delayed report job fired for group %s while the multicast group protocol is disabled", groupAddress))
+ }
+
+ info, ok := g.memberships[groupAddress]
+ if !ok {
+ panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress))
+ }
+
+ g.maybeSendDelayedReportLocked(groupAddress, &info)
+ g.memberships[groupAddress] = info
+ }),
+ }
+
+ if g.opts.Protocol.Enabled() {
+ g.initializeNewMemberLocked(groupAddress, &info)
+ }
+
+ g.memberships[groupAddress] = info
+}
+
+// IsLocallyJoinedRLocked returns true if the group is locally joined.
+//
+// Precondition: g.protocolMU must be read locked.
+func (g *GenericMulticastProtocolState) IsLocallyJoinedRLocked(groupAddress tcpip.Address) bool {
+ _, ok := g.memberships[groupAddress]
+ return ok
+}
+
+// LeaveGroupLocked handles leaving the group.
+//
+// Returns false if the group is not currently joined.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Address) bool {
+ info, ok := g.memberships[groupAddress]
+ if !ok {
+ return false
+ }
+
+ if info.joins == 0 {
+ panic(fmt.Sprintf("tried to leave group %s with a join count of 0", groupAddress))
+ }
+ info.joins--
+ if info.joins != 0 {
+ // If we still have outstanding joins, then do nothing further.
+ g.memberships[groupAddress] = info
+ return true
+ }
+
+ g.transitionToNonMemberLocked(groupAddress, &info)
+ delete(g.memberships, groupAddress)
+ return true
+}
+
+// HandleQueryLocked handles a query message with the specified maximum response
+// time.
+//
+// If the group address is unspecified, then reports will be scheduled for all
+// joined groups.
+//
+// Report(s) will be scheduled to be sent after a random duration between 0 and
+// the maximum response time.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
+
+ // As per RFC 2236 section 2.4 (for IGMPv2),
+ //
+ // In a Membership Query message, the group address field is set to zero
+ // when sending a General Query, and set to the group address being
+ // queried when sending a Group-Specific Query.
+ //
+ // As per RFC 2710 section 3.6 (for MLDv1),
+ //
+ // In a Query message, the Multicast Address field is set to zero when
+ // sending a General Query, and set to a specific IPv6 multicast address
+ // when sending a Multicast-Address-Specific Query.
+ if groupAddress.Unspecified() {
+ // This is a general query as the group address is unspecified.
+ for groupAddress, info := range g.memberships {
+ g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime)
+ g.memberships[groupAddress] = info
+ }
+ } else if info, ok := g.memberships[groupAddress]; ok {
+ g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime)
+ g.memberships[groupAddress] = info
+ }
+}
+
+// HandleReportLocked handles a report message.
+//
+// If the report is for a joined group, any active delayed report will be
+// cancelled and the host state for the group transitions to idle.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
+
+ // As per RFC 2236 section 3 pages 3-4 (for IGMPv2),
+ //
+ // If the host receives another host's Report (version 1 or 2) while it has
+ // a timer running, it stops its timer for the specified group and does not
+ // send a Report
+ //
+ // As per RFC 2710 section 4 page 6 (for MLDv1),
+ //
+ // If a node receives another node's Report from an interface for a
+ // multicast address while it has a timer running for that same address
+ // on that interface, it stops its timer and does not send a Report for
+ // that address, thus suppressing duplicate reports on the link.
+ if info, ok := g.memberships[groupAddress]; ok && info.state.isDelayingMember() {
+ info.delayedReportJob.Cancel()
+ info.lastToSendReport = false
+ info.state = idleMember
+ g.memberships[groupAddress] = info
+ }
+}
+
+// initializeNewMemberLocked initializes a new group membership.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state != nonMember {
+ panic(fmt.Sprintf("host must be in non-member state to be initialized; group = %s, state = %d", groupAddress, info.state))
+ }
+
+ info.lastToSendReport = false
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ info.state = idleMember
+ return
+ }
+
+ info.state = pendingMember
+ g.maybeSendInitialReportLocked(groupAddress, info)
+}
+
+// maybeSendInitialReportLocked attempts to start transmission of the initial
+// set of reports after newly joining a group.
+//
+// Host must be in pending member state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) maybeSendInitialReportLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state != pendingMember {
+ panic(fmt.Sprintf("host must be in pending member state to send initial reports; group = %s, state = %d", groupAddress, info.state))
+ }
+
+ // As per RFC 2236 section 3 page 5 (for IGMPv2),
+ //
+ // When a host joins a multicast group, it should immediately transmit an
+ // unsolicited Version 2 Membership Report for that group" ... "it is
+ // recommended that it be repeated".
+ //
+ // As per RFC 2710 section 4 page 6 (for MLDv1),
+ //
+ // When a node starts listening to a multicast address on an interface,
+ // it should immediately transmit an unsolicited Report for that address
+ // on that interface, in case it is the first listener on the link. To
+ // cover the possibility of the initial Report being lost or damaged, it
+ // is recommended that it be repeated once or twice after short delays
+ // [Unsolicited Report Interval].
+ //
+ // TODO(gvisor.dev/issue/4901): Support a configurable number of initial
+ // unsolicited reports.
+ sent, err := g.opts.Protocol.SendReport(groupAddress)
+ if err == nil && sent {
+ info.lastToSendReport = true
+ g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay)
+ }
+}
+
+// maybeSendDelayedReportLocked attempts to send the delayed report.
+//
+// Host must be in pending, delaying or queued delaying member state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) maybeSendDelayedReportLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if !info.state.isDelayingMember() {
+ panic(fmt.Sprintf("host must be in delaying or queued delaying member state to send delayed reports; group = %s, state = %d", groupAddress, info.state))
+ }
+
+ sent, err := g.opts.Protocol.SendReport(groupAddress)
+ if err == nil && sent {
+ info.lastToSendReport = true
+ info.state = idleMember
+ } else {
+ info.state = queuedDelayingMember
+ }
+}
+
+// maybeSendLeave attempts to send a leave message.
+func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) {
+ if !g.opts.Protocol.Enabled() || !lastToSendReport {
+ return
+ }
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ return
+ }
+
+ // Okay to ignore the error here as if packet write failed, the multicast
+ // routers will eventually drop our membership anyways. If the interface is
+ // being disabled or removed, the generic multicast protocol's should be
+ // cleared eventually.
+ //
+ // As per RFC 2236 section 3 page 5 (for IGMPv2),
+ //
+ // When a router receives a Report, it adds the group being reported to
+ // the list of multicast group memberships on the network on which it
+ // received the Report and sets the timer for the membership to the
+ // [Group Membership Interval]. Repeated Reports refresh the timer. If
+ // no Reports are received for a particular group before this timer has
+ // expired, the router assumes that the group has no local members and
+ // that it need not forward remotely-originated multicasts for that
+ // group onto the attached network.
+ //
+ // As per RFC 2710 section 4 page 5 (for MLDv1),
+ //
+ // When a router receives a Report from a link, if the reported address
+ // is not already present in the router's list of multicast address
+ // having listeners on that link, the reported address is added to the
+ // list, its timer is set to [Multicast Listener Interval], and its
+ // appearance is made known to the router's multicast routing component.
+ // If a Report is received for a multicast address that is already
+ // present in the router's list, the timer for that address is reset to
+ // [Multicast Listener Interval]. If an address's timer expires, it is
+ // assumed that there are no longer any listeners for that address
+ // present on the link, so it is deleted from the list and its
+ // disappearance is made known to the multicast routing component.
+ //
+ // The requirement to send a leave message is also optional (it MAY be
+ // skipped):
+ //
+ // As per RFC 2236 section 6 page 8 (for IGMPv2),
+ //
+ // "send leave" for the group on the interface. If the interface
+ // state says the Querier is running IGMPv1, this action SHOULD be
+ // skipped. If the flag saying we were the last host to report is
+ // cleared, this action MAY be skipped. The Leave Message is sent to
+ // the ALL-ROUTERS group (224.0.0.2).
+ //
+ // As per RFC 2710 section 5 page 8 (for MLDv1),
+ //
+ // "send done" for the address on the interface. If the flag saying
+ // we were the last node to report is cleared, this action MAY be
+ // skipped. The Done message is sent to the link-scope all-routers
+ // address (FF02::2).
+ _ = g.opts.Protocol.SendLeave(groupAddress)
+}
+
+// transitionToNonMemberLocked transitions the given multicast group the the
+// non-member/listener state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state == nonMember {
+ return
+ }
+
+ info.delayedReportJob.Cancel()
+ g.maybeSendLeave(groupAddress, info.lastToSendReport)
+ info.lastToSendReport = false
+ info.state = nonMember
+}
+
+// setDelayTimerForAddressRLocked sets timer to send a delay report.
+//
+// Precondition: g.protocolMU MUST be read locked.
+func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) {
+ if info.state == nonMember {
+ return
+ }
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ return
+ }
+
+ // As per RFC 2236 section 3 page 3 (for IGMPv2),
+ //
+ // If a timer for the group is already unning, it is reset to the random
+ // value only if the requested Max Response Time is less than the remaining
+ // value of the running timer.
+ //
+ // As per RFC 2710 section 4 page 5 (for MLDv1),
+ //
+ // If a timer for any address is already running, it is reset to the new
+ // random value only if the requested Maximum Response Delay is less than
+ // the remaining value of the running timer.
+ if info.state == delayingMember {
+ // TODO: Reset the timer if time remaining is greater than maxResponseTime.
+ return
+ }
+
+ info.state = delayingMember
+ info.delayedReportJob.Cancel()
+ info.delayedReportJob.Schedule(g.calculateDelayTimerDuration(maxResponseTime))
+}
+
+// calculateDelayTimerDuration returns a random time between (0, maxRespTime].
+func (g *GenericMulticastProtocolState) calculateDelayTimerDuration(maxRespTime time.Duration) time.Duration {
+ // As per RFC 2236 section 3 page 3 (for IGMPv2),
+ //
+ // When a host receives a Group-Specific Query, it sets a delay timer to a
+ // random value selected from the range (0, Max Response Time]...
+ //
+ // As per RFC 2710 section 4 page 6 (for MLDv1),
+ //
+ // When a node receives a Multicast-Address-Specific Query, if it is
+ // listening to the queried Multicast Address on the interface from
+ // which the Query was received, it sets a delay timer for that address
+ // to a random value selected from the range [0, Maximum Response Delay],
+ // as above.
+ if maxRespTime == 0 {
+ return 0
+ }
+ return time.Duration(g.opts.Rand.Int63n(int64(maxRespTime)))
+}
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
new file mode 100644
index 000000000..f56f7aa90
--- /dev/null
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
@@ -0,0 +1,877 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip_test
+
+import (
+ "math/rand"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+)
+
+const (
+ addr1 = tcpip.Address("\x01")
+ addr2 = tcpip.Address("\x02")
+ addr3 = tcpip.Address("\x03")
+ addr4 = tcpip.Address("\x04")
+
+ maxUnsolicitedReportDelay = time.Second
+)
+
+var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil)
+
+type mockMulticastGroupProtocol struct {
+ t *testing.T
+
+ mu sync.RWMutex
+
+ // Must only be accessed with mu held.
+ sendReportGroupAddrCount map[tcpip.Address]int
+
+ // Must only be accessed with mu held.
+ sendLeaveGroupAddrCount map[tcpip.Address]int
+
+ // Must only be accessed with mu held.
+ makeQueuePackets bool
+
+ // Must only be accessed with mu held.
+ disabled bool
+}
+
+func (m *mockMulticastGroupProtocol) init() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.initLocked()
+}
+
+func (m *mockMulticastGroupProtocol) initLocked() {
+ m.sendReportGroupAddrCount = make(map[tcpip.Address]int)
+ m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int)
+}
+
+func (m *mockMulticastGroupProtocol) setEnabled(v bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.disabled = !v
+}
+
+// Enabled implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be read locked.
+func (m *mockMulticastGroupProtocol) Enabled() bool {
+ return !m.disabled
+}
+
+// SendReport implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be locked.
+func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+ if m.mu.TryLock() {
+ m.mu.Unlock()
+ m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
+ }
+ if m.mu.TryRLock() {
+ m.mu.RUnlock()
+ m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
+ }
+
+ m.sendReportGroupAddrCount[groupAddress]++
+ return !m.makeQueuePackets, nil
+}
+
+// SendLeave implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be locked.
+func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+ if m.mu.TryLock() {
+ m.mu.Unlock()
+ m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
+ }
+ if m.mu.TryRLock() {
+ m.mu.RUnlock()
+ m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
+ }
+
+ m.sendLeaveGroupAddrCount[groupAddress]++
+ return nil
+}
+
+func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ sendReportGroupAddrCount := make(map[tcpip.Address]int)
+ for _, a := range sendReportGroupAddresses {
+ sendReportGroupAddrCount[a] = 1
+ }
+
+ sendLeaveGroupAddrCount := make(map[tcpip.Address]int)
+ for _, a := range sendLeaveGroupAddresses {
+ sendLeaveGroupAddrCount[a] = 1
+ }
+
+ diff := cmp.Diff(
+ &mockMulticastGroupProtocol{
+ sendReportGroupAddrCount: sendReportGroupAddrCount,
+ sendLeaveGroupAddrCount: sendLeaveGroupAddrCount,
+ },
+ m,
+ cmp.AllowUnexported(mockMulticastGroupProtocol{}),
+ // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t
+ cmp.FilterPath(
+ func(p cmp.Path) bool {
+ switch p.Last().String() {
+ case ".mu", ".t", ".makeQueuePackets", ".disabled":
+ return true
+ }
+ return false
+ },
+ cmp.Ignore(),
+ ),
+ )
+ m.initLocked()
+ return diff
+}
+
+func TestJoinGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ shouldSendReports bool
+ }{
+ {
+ name: "Normal group",
+ addr: addr1,
+ shouldSendReports: true,
+ },
+ {
+ name: "All-nodes group",
+ addr: addr2,
+ shouldSendReports: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(0)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr2,
+ })
+
+ // Joining a group should send a report immediately and another after
+ // a random interval between 0 and the maximum unsolicited report delay.
+ mgp.mu.Lock()
+ g.JoinGroupLocked(test.addr)
+ mgp.mu.Unlock()
+ if test.shouldSendReports {
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestLeaveGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ shouldSendMessages bool
+ }{
+ {
+ name: "Normal group",
+ addr: addr1,
+ shouldSendMessages: true,
+ },
+ {
+ name: "All-nodes group",
+ addr: addr2,
+ shouldSendMessages: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(1)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr2,
+ })
+
+ mgp.mu.Lock()
+ g.JoinGroupLocked(test.addr)
+ mgp.mu.Unlock()
+ if test.shouldSendMessages {
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Leaving a group should send a leave report immediately and cancel any
+ // delayed reports.
+ {
+ mgp.mu.Lock()
+ res := g.LeaveGroupLocked(test.addr)
+ mgp.mu.Unlock()
+ if !res {
+ t.Fatalf("got g.LeaveGroupLocked(%s) = false, want = true", test.addr)
+ }
+ }
+ if test.shouldSendMessages {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ //
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestHandleReport(t *testing.T) {
+ tests := []struct {
+ name string
+ reportAddr tcpip.Address
+ expectReportsFor []tcpip.Address
+ }{
+ {
+ name: "Unpecified empty",
+ reportAddr: "",
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Unpecified any",
+ reportAddr: "\x00",
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Specified",
+ reportAddr: addr1,
+ expectReportsFor: []tcpip.Address{addr2},
+ },
+ {
+ name: "Specified all-nodes",
+ reportAddr: addr3,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Specified other",
+ reportAddr: addr4,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(2)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr1)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr2)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr3)
+ mgp.mu.Unlock()
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a report for a group we have a timer scheduled for should
+ // cancel our delayed report timer for the group.
+ mgp.mu.Lock()
+ g.HandleReportLocked(test.reportAddr)
+ mgp.mu.Unlock()
+ if len(test.expectReportsFor) != 0 {
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestHandleQuery(t *testing.T) {
+ tests := []struct {
+ name string
+ queryAddr tcpip.Address
+ maxDelay time.Duration
+ expectReportsFor []tcpip.Address
+ }{
+ {
+ name: "Unpecified empty",
+ queryAddr: "",
+ maxDelay: 0,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Unpecified any",
+ queryAddr: "\x00",
+ maxDelay: 1,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Specified",
+ queryAddr: addr1,
+ maxDelay: 2,
+ expectReportsFor: []tcpip.Address{addr1},
+ },
+ {
+ name: "Specified all-nodes",
+ queryAddr: addr3,
+ maxDelay: 3,
+ expectReportsFor: nil,
+ },
+ {
+ name: "Specified other",
+ queryAddr: addr4,
+ maxDelay: 4,
+ expectReportsFor: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr1)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr2)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr3)
+ mgp.mu.Unlock()
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a query should make us schedule a new delayed report if it
+ // is a query directed at us or a general query.
+ mgp.mu.Lock()
+ g.HandleQueryLocked(test.queryAddr, test.maxDelay)
+ mgp.mu.Unlock()
+ if len(test.expectReportsFor) != 0 {
+ clock.Advance(test.maxDelay)
+ if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestJoinCount(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(4)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: time.Second,
+ })
+
+ // Set the join count to 2 for a group.
+ {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr1)
+ res := g.IsLocallyJoinedRLocked(addr1)
+ mgp.mu.Unlock()
+ if !res {
+ t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
+ }
+ }
+ // Only the first join should trigger a report to be sent.
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr1)
+ res := g.IsLocallyJoinedRLocked(addr1)
+ mgp.mu.Unlock()
+ if !res {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
+ }
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Group should still be considered joined after leaving once.
+ {
+ mgp.mu.Lock()
+ leaveGroupRes := g.LeaveGroupLocked(addr1)
+ isLocallyJoined := g.IsLocallyJoinedRLocked(addr1)
+ mgp.mu.Unlock()
+ if !leaveGroupRes {
+ t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr1)
+ }
+ if !isLocallyJoined {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
+ }
+ }
+ // A leave report should only be sent once the join count reaches 0.
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leaving once more should actually remove us from the group.
+ {
+ mgp.mu.Lock()
+ leaveGroupRes := g.LeaveGroupLocked(addr1)
+ isLocallyJoined := g.IsLocallyJoinedRLocked(addr1)
+ mgp.mu.Unlock()
+ if !leaveGroupRes {
+ t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr1)
+ }
+ if isLocallyJoined {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr1)
+ }
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Group should no longer be joined so we should not have anything to
+ // leave.
+ {
+ mgp.mu.Lock()
+ leaveGroupRes := g.LeaveGroupLocked(addr1)
+ isLocallyJoined := g.IsLocallyJoinedRLocked(addr1)
+ mgp.mu.Unlock()
+ if leaveGroupRes {
+ t.Errorf("got g.LeaveGroupLocked(%s) = true, want = false", addr1)
+ }
+ if isLocallyJoined {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr1)
+ }
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should have no more messages to send.
+ //
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
+
+func TestMakeAllNonMemberAndInitialize(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr1)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr2)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr3)
+ mgp.mu.Unlock()
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should send the leave reports for each but still consider them locally
+ // joined.
+ mgp.mu.Lock()
+ g.MakeAllNonMemberLocked()
+ mgp.mu.Unlock()
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ for _, group := range []tcpip.Address{addr1, addr2, addr3} {
+ mgp.mu.RLock()
+ res := g.IsLocallyJoinedRLocked(group)
+ mgp.mu.RUnlock()
+ if !res {
+ t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", group)
+ }
+ }
+
+ // Should send the initial set of unsolcited reports.
+ mgp.mu.Lock()
+ g.InitializeGroupsLocked()
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
+
+// TestGroupStateNonMember tests that groups do not send packets when in the
+// non-member state, but are still considered locally joined.
+func TestGroupStateNonMember(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init()
+ mgp.setEnabled(false)
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ })
+
+ // Joining groups should not send any reports.
+ {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr1)
+ res := g.IsLocallyJoinedRLocked(addr1)
+ mgp.mu.Unlock()
+ if !res {
+ t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
+ }
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr2)
+ res := g.IsLocallyJoinedRLocked(addr2)
+ mgp.mu.Unlock()
+ if !res {
+ t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr2)
+ }
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a query should not send any reports.
+ mgp.mu.Lock()
+ g.HandleQueryLocked(addr1, time.Nanosecond)
+ mgp.mu.Unlock()
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Leaving groups should not send any leave messages.
+ {
+ mgp.mu.Lock()
+ addr2LeaveRes := g.LeaveGroupLocked(addr2)
+ addr1IsJoined := g.IsLocallyJoinedRLocked(addr1)
+ addr2IsJoined := g.IsLocallyJoinedRLocked(addr2)
+ mgp.mu.Unlock()
+ if !addr2LeaveRes {
+ t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr2)
+ }
+ if !addr1IsJoined {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
+ }
+ if addr2IsJoined {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr2)
+ }
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
+
+func TestQueuedPackets(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ var mgp mockMulticastGroupProtocol
+ mgp.init()
+ clock := faketime.NewManualClock()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(4)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ })
+
+ // Joining should trigger a SendReport, but mgp should report that we did not
+ // send the packet.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = true
+ g.JoinGroupLocked(addr1)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // The delayed report timer should have been cancelled since we did not send
+ // the initial report earlier.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Mock being able to successfully send the report.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = false
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // The delayed report (sent after the initial report) should now be sent.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send (we should be idle).
+ mgp.mu.Lock()
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receive a query but mock being unable to send reports again.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = true
+ g.HandleQueryLocked(addr1, time.Nanosecond)
+ mgp.mu.Unlock()
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Mock being able to send reports again - we should have a packet queued to
+ // send.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = false
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send.
+ mgp.mu.Lock()
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receive a query again, but mock being unable to send reports.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = true
+ g.HandleQueryLocked(addr1, time.Nanosecond)
+ mgp.mu.Unlock()
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a report should should transition us into the idle member state,
+ // even if we had a packet queued. We should no longer have any packets to
+ // send.
+ mgp.mu.Lock()
+ g.HandleReportLocked(addr1)
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // When we fail to send the initial set of reports, incoming reports should
+ // not affect a newly joined group's reports from being sent.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = true
+ g.JoinGroupLocked(addr2)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.mu.Lock()
+ g.HandleReportLocked(addr2)
+ // Attempting to send queued reports while still unable to send reports should
+ // not change the host state.
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // Mock being able to successfully send the report.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = false
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // The delayed report (sent after the initial report) should now be sent.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send.
+ mgp.mu.Lock()
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index d49c44846..3005973d7 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -193,10 +193,6 @@ func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBu
panic("not implemented")
}
-func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
- return tcpip.ErrNotSupported
-}
-
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (*testObject) ARPHardwareType() header.ARPHardwareType {
panic("not implemented")
@@ -207,7 +203,7 @@ func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
panic("not implemented")
}
-func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
+func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
@@ -223,7 +219,7 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
}
-func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
+func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
@@ -348,11 +344,11 @@ func TestSourceAddressValidation(t *testing.T) {
pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: localIPv6Addr,
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: localIPv6Addr,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -554,7 +550,7 @@ func TestIPv4Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{
Protocol: 123,
TTL: 123,
TOS: stack.DefaultTOS,
@@ -623,11 +619,11 @@ func TestReceive(t *testing.T) {
view := buffer.NewView(header.IPv6MinimumSize + payloadLen)
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
- PayloadLength: payloadLen,
- NextHeader: 10,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: remoteIPv6Addr,
- DstAddr: localIPv6Addr,
+ PayloadLength: payloadLen,
+ TransportProtocol: 10,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: localIPv6Addr,
})
// Make payload be non-zero.
@@ -937,7 +933,7 @@ func TestIPv6Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{
Protocol: 123,
TTL: 123,
TOS: stack.DefaultTOS,
@@ -997,11 +993,11 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Create the outer IPv6 header.
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 20,
- SrcAddr: outerSrcAddr,
- DstAddr: localIPv6Addr,
+ PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 20,
+ SrcAddr: outerSrcAddr,
+ DstAddr: localIPv6Addr,
})
// Create the ICMP header.
@@ -1011,28 +1007,27 @@ func TestIPv6ReceiveControl(t *testing.T) {
icmp.SetIdent(0xdead)
icmp.SetSequence(0xbeef)
- // Create the inner IPv6 header.
- ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
- ip.Encode(&header.IPv6Fields{
- PayloadLength: 100,
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: localIPv6Addr,
- DstAddr: remoteIPv6Addr,
- })
-
+ var extHdrs header.IPv6ExtHdrSerializer
// Build the fragmentation header if needed.
if c.fragmentOffset != nil {
- ip.SetNextHeader(header.IPv6FragmentHeader)
- frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:])
- frag.Encode(&header.IPv6FragmentFields{
- NextHeader: 10,
+ extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{
FragmentOffset: *c.fragmentOffset,
M: true,
Identification: 0x12345678,
})
}
+ // Create the inner IPv6 header.
+ ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: 100,
+ TransportProtocol: 10,
+ HopLimit: 20,
+ SrcAddr: localIPv6Addr,
+ DstAddr: remoteIPv6Addr,
+ ExtensionHeaders: extHdrs,
+ })
+
// Make payload be non-zero.
for i := dataOffset; i < len(view); i++ {
view[i] = uint8(i)
@@ -1093,7 +1088,19 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
dataBuf := [dataLen]byte{1, 2, 3, 4}
data := dataBuf[:]
- ipv4Options := header.IPv4Options{0, 1, 0, 1}
+ ipv4Options := header.IPv4OptionsSerializer{
+ &header.IPv4SerializableListEndOption{},
+ &header.IPv4SerializableNOPOption{},
+ &header.IPv4SerializableListEndOption{},
+ &header.IPv4SerializableNOPOption{},
+ }
+
+ expectOptions := header.IPv4Options{
+ byte(header.IPv4OptionListEndType),
+ byte(header.IPv4OptionNOPType),
+ byte(header.IPv4OptionListEndType),
+ byte(header.IPv4OptionNOPType),
+ }
ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4}
ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:]
@@ -1243,7 +1250,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ipHdrLen := header.IPv4MinimumSize + ipv4Options.SizeWithPadding()
+ ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
totalLen := ipHdrLen + len(data)
hdr := buffer.NewPrependable(totalLen)
if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
@@ -1266,7 +1273,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
netHdr := pkt.NetworkHeader()
- hdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
if len(netHdr.View()) != hdrLen {
t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
}
@@ -1276,7 +1283,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
checker.DstAddr(remoteIPv4Addr),
checker.IPv4HeaderLength(hdrLen),
checker.IPFullLength(uint16(hdrLen+len(data))),
- checker.IPv4Options(ipv4Options),
+ checker.IPv4Options(expectOptions),
checker.IPPayload(data),
)
},
@@ -1288,7 +1295,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.SizeWithPadding()))
+ ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length()))
ip.Encode(&header.IPv4Fields{
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
@@ -1307,7 +1314,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
netHdr := pkt.NetworkHeader()
- hdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
if len(netHdr.View()) != hdrLen {
t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
}
@@ -1317,7 +1324,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
checker.DstAddr(remoteIPv4Addr),
checker.IPv4HeaderLength(hdrLen),
checker.IPFullLength(uint16(hdrLen+len(data))),
- checker.IPv4Options(ipv4Options),
+ checker.IPv4Options(expectOptions),
checker.IPPayload(data),
)
},
@@ -1336,10 +1343,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ TransportProtocol: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return hdr.View().ToVectorisedView()
},
@@ -1379,10 +1386,12 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier),
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ // NB: we're lying about transport protocol here to verify the raw
+ // fragment header bytes.
+ TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier),
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return hdr.View().ToVectorisedView()
},
@@ -1414,10 +1423,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ TransportProtocol: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return buffer.View(ip).ToVectorisedView()
},
@@ -1449,10 +1458,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ TransportProtocol: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 6252614ec..32f53f217 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -6,6 +6,7 @@ go_library(
name = "ipv4",
srcs = [
"icmp.go",
+ "igmp.go",
"ipv4.go",
],
visibility = ["//visibility:public"],
@@ -17,6 +18,7 @@ go_library(
"//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
+ "//pkg/tcpip/network/ip",
"//pkg/tcpip/stack",
],
)
@@ -24,7 +26,10 @@ go_library(
go_test(
name = "ipv4_test",
size = "small",
- srcs = ["ipv4_test.go"],
+ srcs = [
+ "igmp_test.go",
+ "ipv4_test.go",
+ ],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 488945226..8e392f86c 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -63,7 +63,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
stats := e.protocol.stack.Stats()
- received := stats.ICMP.V4PacketsReceived
+ received := stats.ICMP.V4.PacketsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
// full explanation.
@@ -130,7 +130,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4Echo:
received.Echo.Increment()
- sent := stats.ICMP.V4PacketsSent
+ sent := stats.ICMP.V4.PacketsSent
if !e.protocol.stack.AllowICMPMessage() {
sent.RateLimited.Increment()
return
@@ -379,7 +379,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
}
defer route.Release()
- sent := p.stack.Stats().ICMP.V4PacketsSent
+ sent := p.stack.Stats().ICMP.V4.PacketsSent
if !p.stack.AllowICMPMessage() {
sent.RateLimited.Increment()
return nil
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
new file mode 100644
index 000000000..da88d65d1
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -0,0 +1,345 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4
+
+import (
+ "fmt"
+ "sync/atomic"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // igmpV1PresentDefault is the initial state for igmpV1Present in the
+ // igmpState. As per RFC 2236 Page 9 says "No IGMPv1 Router Present ... is
+ // the initial state."
+ igmpV1PresentDefault = 0
+
+ // v1RouterPresentTimeout from RFC 2236 Section 8.11, Page 18
+ // See note on igmpState.igmpV1Present for more detail.
+ v1RouterPresentTimeout = 400 * time.Second
+
+ // v1MaxRespTime from RFC 2236 Section 4, Page 5. "The IGMPv1 router
+ // will send General Queries with the Max Response Time set to 0. This MUST
+ // be interpreted as a value of 100 (10 seconds)."
+ //
+ // Note that the Max Response Time field is a value in units of deciseconds.
+ v1MaxRespTime = 10 * time.Second
+
+ // UnsolicitedReportIntervalMax is the maximum delay between sending
+ // unsolicited IGMP reports.
+ //
+ // Obtained from RFC 2236 Section 8.10, Page 19.
+ UnsolicitedReportIntervalMax = 10 * time.Second
+)
+
+// IGMPOptions holds options for IGMP.
+type IGMPOptions struct {
+ // Enabled indicates whether IGMP will be performed.
+ //
+ // When enabled, IGMP may transmit IGMP report and leave messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // IGMP packets.
+ //
+ // This field is ignored and is always assumed to be false for interfaces
+ // without neighbouring nodes (e.g. loopback).
+ Enabled bool
+}
+
+var _ ip.MulticastGroupProtocol = (*igmpState)(nil)
+
+// igmpState is the per-interface IGMP state.
+//
+// igmpState.init() MUST be called after creating an IGMP state.
+type igmpState struct {
+ // The IPv4 endpoint this igmpState is for.
+ ep *endpoint
+
+ genericMulticastProtocol ip.GenericMulticastProtocolState
+
+ // igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from
+ // RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1
+ // Membership Reports in response to its Queries, and will not pay
+ // attention to Version 2 Membership Reports. Therefore, a state variable
+ // MUST be kept for each interface, describing whether the multicast
+ // Querier on that interface is running IGMPv1 or IGMPv2. This variable
+ // MUST be based upon whether or not an IGMPv1 query was heard in the last
+ // [Version 1 Router Present Timeout] seconds".
+ //
+ // Must be accessed with atomic operations. Holds a value of 1 when true, 0
+ // when false.
+ igmpV1Present uint32
+
+ // igmpV1Job is scheduled when this interface receives an IGMPv1 style
+ // message, upon expiration the igmpV1Present flag is cleared.
+ // igmpV1Job may not be nil once igmpState is initialized.
+ igmpV1Job *tcpip.Job
+}
+
+// Enabled implements ip.MulticastGroupProtocol.
+func (igmp *igmpState) Enabled() bool {
+ // No need to perform IGMP on loopback interfaces since they don't have
+ // neighbouring nodes.
+ return igmp.ep.protocol.options.IGMP.Enabled && !igmp.ep.nic.IsLoopback() && igmp.ep.Enabled()
+}
+
+// SendReport implements ip.MulticastGroupProtocol.
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+ igmpType := header.IGMPv2MembershipReport
+ if igmp.v1Present() {
+ igmpType = header.IGMPv1MembershipReport
+ }
+ return igmp.writePacket(groupAddress, groupAddress, igmpType)
+}
+
+// SendLeave implements ip.MulticastGroupProtocol.
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+ // As per RFC 2236 Section 6, Page 8: "If the interface state says the
+ // Querier is running IGMPv1, this action SHOULD be skipped. If the flag
+ // saying we were the last host to report is cleared, this action MAY be
+ // skipped."
+ if igmp.v1Present() {
+ return nil
+ }
+ _, err := igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup)
+ return err
+}
+
+// init sets up an igmpState struct, and is required to be called before using
+// a new igmpState.
+//
+// Must only be called once for the lifetime of igmp.
+func (igmp *igmpState) init(ep *endpoint) {
+ igmp.ep = ep
+ igmp.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{
+ Rand: ep.protocol.stack.Rand(),
+ Clock: ep.protocol.stack.Clock(),
+ Protocol: igmp,
+ MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
+ AllNodesAddress: header.IPv4AllSystems,
+ })
+ igmp.igmpV1Present = igmpV1PresentDefault
+ igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() {
+ igmp.setV1Present(false)
+ })
+}
+
+// handleIGMP handles an IGMP packet.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) {
+ stats := igmp.ep.protocol.stack.Stats()
+ received := stats.IGMP.PacketsReceived
+ headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
+ h := header.IGMP(headerView)
+
+ // Temporarily reset the checksum field to 0 in order to calculate the proper
+ // checksum.
+ wantChecksum := h.Checksum()
+ h.SetChecksum(0)
+ gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
+ h.SetChecksum(wantChecksum)
+
+ if gotChecksum != wantChecksum {
+ received.ChecksumErrors.Increment()
+ return
+ }
+
+ switch h.Type() {
+ case header.IGMPMembershipQuery:
+ received.MembershipQuery.Increment()
+ if len(headerView) < header.IGMPQueryMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ igmp.handleMembershipQuery(h.GroupAddress(), h.MaxRespTime())
+ case header.IGMPv1MembershipReport:
+ received.V1MembershipReport.Increment()
+ if len(headerView) < header.IGMPReportMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ igmp.handleMembershipReport(h.GroupAddress())
+ case header.IGMPv2MembershipReport:
+ received.V2MembershipReport.Increment()
+ if len(headerView) < header.IGMPReportMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ igmp.handleMembershipReport(h.GroupAddress())
+ case header.IGMPLeaveGroup:
+ received.LeaveGroup.Increment()
+ // As per RFC 2236 Section 6, Page 7: "IGMP messages other than Query or
+ // Report, are ignored in all states"
+
+ default:
+ // As per RFC 2236 Section 2.1 Page 3: "Unrecognized message types should
+ // be silently ignored. New message types may be used by newer versions of
+ // IGMP, by multicast routing protocols, or other uses."
+ received.Unrecognized.Increment()
+ }
+}
+
+func (igmp *igmpState) v1Present() bool {
+ return atomic.LoadUint32(&igmp.igmpV1Present) == 1
+}
+
+func (igmp *igmpState) setV1Present(v bool) {
+ if v {
+ atomic.StoreUint32(&igmp.igmpV1Present, 1)
+ } else {
+ atomic.StoreUint32(&igmp.igmpV1Present, 0)
+ }
+}
+
+// handleMembershipQuery handles a membership query.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) {
+ // As per RFC 2236 Section 6, Page 10: If the maximum response time is zero
+ // then change the state to note that an IGMPv1 router is present and
+ // schedule the query received Job.
+ if maxRespTime == 0 && igmp.Enabled() {
+ igmp.igmpV1Job.Cancel()
+ igmp.igmpV1Job.Schedule(v1RouterPresentTimeout)
+ igmp.setV1Present(true)
+ maxRespTime = v1MaxRespTime
+ }
+
+ igmp.genericMulticastProtocol.HandleQueryLocked(groupAddress, maxRespTime)
+}
+
+// handleMembershipReport handles a membership report.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) {
+ igmp.genericMulticastProtocol.HandleReportLocked(groupAddress)
+}
+
+// writePacket assembles and sends an IGMP packet.
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, *tcpip.Error) {
+ igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize))
+ igmpData.SetType(igmpType)
+ igmpData.SetGroupAddress(groupAddress)
+ igmpData.SetChecksum(header.IGMPCalculateChecksum(igmpData))
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(igmp.ep.MaxHeaderLength()),
+ Data: buffer.View(igmpData).ToVectorisedView(),
+ })
+
+ addressEndpoint := igmp.ep.acquireOutgoingPrimaryAddressRLocked(destAddress, false /* allowExpired */)
+ if addressEndpoint == nil {
+ return false, nil
+ }
+ localAddr := addressEndpoint.AddressWithPrefix().Address
+ addressEndpoint.DecRef()
+ addressEndpoint = nil
+ igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{
+ Protocol: header.IGMPProtocolNumber,
+ TTL: header.IGMPTTL,
+ TOS: stack.DefaultTOS,
+ }, header.IPv4OptionsSerializer{
+ &header.IPv4SerializableRouterAlertOption{},
+ })
+
+ sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent
+ if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
+ sentStats.Dropped.Increment()
+ return false, err
+ }
+ switch igmpType {
+ case header.IGMPv1MembershipReport:
+ sentStats.V1MembershipReport.Increment()
+ case header.IGMPv2MembershipReport:
+ sentStats.V2MembershipReport.Increment()
+ case header.IGMPLeaveGroup:
+ sentStats.LeaveGroup.Increment()
+ default:
+ panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType))
+ }
+ return true, nil
+}
+
+// joinGroup handles adding a new group to the membership map, setting up the
+// IGMP state for the group, and sending and scheduling the required
+// messages.
+//
+// If the group already exists in the membership map, returns
+// tcpip.ErrDuplicateAddress.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) {
+ igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress)
+}
+
+// isInGroup returns true if the specified group has been joined locally.
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool {
+ return igmp.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress)
+}
+
+// leaveGroup handles removing the group from the membership map, cancels any
+// delay timers associated with that group, and sends the Leave Group message
+// if required.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
+ // LeaveGroup returns false only if the group was not joined.
+ if igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress) {
+ return nil
+ }
+
+ return tcpip.ErrBadLocalAddress
+}
+
+// softLeaveAll leaves all groups from the perspective of IGMP, but remains
+// joined locally.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) softLeaveAll() {
+ igmp.genericMulticastProtocol.MakeAllNonMemberLocked()
+}
+
+// initializeAll attemps to initialize the IGMP state for each group that has
+// been joined locally.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) initializeAll() {
+ igmp.genericMulticastProtocol.InitializeGroupsLocked()
+}
+
+// sendQueuedReports attempts to send any reports that are queued for sending.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) sendQueuedReports() {
+ igmp.genericMulticastProtocol.SendQueuedReportsLocked()
+}
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
new file mode 100644
index 000000000..1ee573ac8
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/igmp_test.go
@@ -0,0 +1,215 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ addr = tcpip.Address("\x0a\x00\x00\x01")
+ multicastAddr = tcpip.Address("\xe0\x00\x00\x03")
+ nicID = 1
+)
+
+// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet
+// sent to the provided address with the passed fields set. Raises a t.Error if
+// any field does not match.
+func validateIgmpPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) {
+ t.Helper()
+
+ payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ checker.IPv4(t, payload,
+ checker.SrcAddr(addr),
+ checker.DstAddr(remoteAddress),
+ // TTL for an IGMP message must be 1 as per RFC 2236 section 2.
+ checker.TTL(1),
+ checker.IPv4RouterAlert(),
+ checker.IGMP(
+ checker.IGMPType(igmpType),
+ checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)),
+ checker.IGMPGroupAddress(groupAddress),
+ ),
+ )
+}
+
+func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
+ t.Helper()
+
+ // Create an endpoint of queue size 1, since no more than 1 packets are ever
+ // queued in the tests in this file.
+ e := channel.New(1, 1280, linkAddr)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{
+ IGMP: ipv4.IGMPOptions{
+ Enabled: igmpEnabled,
+ },
+ })},
+ Clock: clock,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ return e, s, clock
+}
+
+func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) {
+ buf := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize)
+
+ ip := header.IPv4(buf)
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(len(buf)),
+ TTL: 1,
+ Protocol: uint8(header.IGMPProtocolNumber),
+ SrcAddr: header.IPv4Any,
+ DstAddr: header.IPv4AllSystems,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ igmp := header.IGMP(buf[header.IPv4MinimumSize:])
+ igmp.SetType(igmpType)
+ igmp.SetMaxRespTime(maxRespTime)
+ igmp.SetGroupAddress(groupAddress)
+ igmp.SetChecksum(header.IGMPCalculateChecksum(igmp))
+
+ e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+}
+
+// TestIgmpV1Present tests the handling of the case where an IGMPv1 router is
+// present on the network. The IGMP stack will then send IGMPv1 Membership
+// reports for backwards compatibility.
+func TestIgmpV1Present(t *testing.T) {
+ e, s, clock := createStack(t, true)
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ }
+
+ if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
+ }
+
+ // This NIC will send an IGMPv2 report immediately, before this test can get
+ // the IGMPv1 General Membership Query in.
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
+ }
+ if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
+ t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
+ }
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Inject an IGMPv1 General Membership Query which is identical to a standard
+ // membership query except the Max Response Time is set to 0, which will tell
+ // the stack that this is a router using IGMPv1. Send it to the all systems
+ // group which is the only group this host belongs to.
+ createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, 0, header.IPv4AllSystems)
+ if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 1 {
+ t.Fatalf("got Membership Queries received = %d, want = 1", got)
+ }
+
+ // Before advancing the clock, verify that this host has not sent a
+ // V1MembershipReport yet.
+ if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 0 {
+ t.Fatalf("got V1MembershipReport messages sent = %d, want = 0", got)
+ }
+
+ // Verify the solicited Membership Report is sent. Now that this NIC has seen
+ // an IGMPv1 query, it should send an IGMPv1 Membership Report.
+ p, ok = e.Read()
+ if ok {
+ t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt)
+ }
+ clock.Advance(ipv4.UnsolicitedReportIntervalMax)
+ p, ok = e.Read()
+ if !ok {
+ t.Fatal("unable to Read IGMP packet, expected V1MembershipReport")
+ }
+ if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 {
+ t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got)
+ }
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr)
+}
+
+func TestSendQueuedIGMPReports(t *testing.T) {
+ e, s, clock := createStack(t, true)
+
+ // Joining a group without an assigned address should queue IGMP packets; none
+ // should be sent without an assigned address.
+ if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err)
+ }
+ reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport
+ if got := reportStat.Value(); got != 0 {
+ t.Errorf("got reportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("got unexpected packet = %#v", p)
+ }
+
+ // The initial set of IGMP reports that were queued should be sent once an
+ // address is assigned.
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ }
+ if got := reportStat.Value(); got != 1 {
+ t.Errorf("got reportStat.Value() = %d, want = 1", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Error("expected to send an IGMP membership report")
+ } else {
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ clock.Advance(ipv4.UnsolicitedReportIntervalMax)
+ if got := reportStat.Value(); got != 2 {
+ t.Errorf("got reportStat.Value() = %d, want = 2", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Error("expected to send an IGMP membership report")
+ } else {
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Should have no more packets to send after the initial set of unsolicited
+ // reports.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("got unexpected packet = %#v", p)
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 1efe6297a..e9ff70d04 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -83,6 +83,7 @@ type endpoint struct {
sync.RWMutex
addressableEndpointState stack.AddressableEndpointState
+ igmp igmpState
}
}
@@ -93,7 +94,10 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa
dispatcher: dispatcher,
protocol: p,
}
+ e.mu.Lock()
e.mu.addressableEndpointState.Init(e)
+ e.mu.igmp.init(e)
+ e.mu.Unlock()
return e
}
@@ -121,11 +125,22 @@ func (e *endpoint) Enable() *tcpip.Error {
// We have no need for the address endpoint.
ep.DecRef()
+ // Groups may have been joined while the endpoint was disabled, or the
+ // endpoint may have left groups from the perspective of IGMP when the
+ // endpoint was disabled. Either way, we need to let routers know to
+ // send us multicast traffic.
+ e.mu.igmp.initializeAll()
+
// As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
// multicast group. Note, the IANA calls the all-hosts multicast group the
// all-systems multicast group.
- _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems)
- return err
+ if err := e.joinGroupLocked(header.IPv4AllSystems); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv4 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv4AllSystems, err))
+ }
+
+ return nil
}
// Enabled implements stack.NetworkEndpoint.
@@ -157,19 +172,27 @@ func (e *endpoint) Disable() {
}
func (e *endpoint) disableLocked() {
- if !e.setEnabled(false) {
+ if !e.isEnabled() {
return
}
// The endpoint may have already left the multicast group.
- if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
+ if err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err))
}
+ // Leave groups from the perspective of IGMP so that routers know that
+ // we are no longer interested in the group.
+ e.mu.igmp.softLeaveAll()
+
// The address may have already been removed.
if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err))
}
+
+ if !e.setEnabled(false) {
+ panic("should have only done work to disable the endpoint if it was enabled")
+ }
}
// DefaultTTL is the default time-to-live value for this endpoint.
@@ -198,37 +221,34 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
-func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) {
hdrLen := header.IPv4MinimumSize
- var opts header.IPv4Options
- if params.Options != nil {
- var ok bool
- if opts, ok = params.Options.(header.IPv4Options); !ok {
- panic(fmt.Sprintf("want IPv4Options, got %T", params.Options))
- }
- hdrLen += opts.SizeWithPadding()
- if hdrLen > header.IPv4MaximumHeaderSize {
- // Since we have no way to report an error we must either panic or create
- // a packet which is different to what was requested. Choose panic as this
- // would be a programming error that should be caught in testing.
- panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", params.Options.SizeWithPadding(), header.IPv4MaximumOptionsSize))
- }
+ var optLen int
+ if options != nil {
+ optLen = int(options.Length())
+ }
+ hdrLen += optLen
+ if hdrLen > header.IPv4MaximumHeaderSize {
+ // Since we have no way to report an error we must either panic or create
+ // a packet which is different to what was requested. Choose panic as this
+ // would be a programming error that should be caught in testing.
+ panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", optLen, header.IPv4MaximumOptionsSize))
}
ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen))
length := uint16(pkt.Size())
// RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic
// datagrams. Since the DF bit is never being set here, all datagrams
// are non-atomic and need an ID.
- id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
+ id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1)
ip.Encode(&header.IPv4Fields{
TotalLength: length,
ID: uint16(id),
TTL: params.TTL,
TOS: params.TOS,
Protocol: uint8(params.Protocol),
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
- Options: opts,
+ SrcAddr: srcAddr,
+ DstAddr: dstAddr,
+ Options: options,
})
ip.SetChecksum(^ip.CalculateChecksum())
pkt.NetworkProtocolNumber = ProtocolNumber
@@ -259,7 +279,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- e.addIPHeader(r, pkt, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */)
// iptables filtering. All packets that reach here are locally
// generated.
@@ -347,7 +367,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.addIPHeader(r, pkt, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */)
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
@@ -461,7 +481,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// non-atomic datagrams, so assign an ID to all such datagrams
// according to the definition given in RFC 6864 section 4.
if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 {
- ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
+ ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress, r.RemoteAddress, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
}
}
@@ -566,21 +586,6 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
stats.IP.MalformedPacketsReceived.Increment()
return
}
- srcAddr := h.SourceAddress()
- dstAddr := h.DestinationAddress()
-
- addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
- if addressEndpoint == nil {
- if !e.protocol.Forwarding() {
- stats.IP.InvalidDestinationAddressesReceived.Increment()
- return
- }
-
- _ = e.forwardPacket(pkt)
- return
- }
- subnet := addressEndpoint.AddressWithPrefix().Subnet()
- addressEndpoint.DecRef()
// There has been some confusion regarding verifying checksums. We need
// just look for negative 0 (0xffff) as the checksum, as it's not possible to
@@ -608,16 +613,42 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
return
}
+ srcAddr := h.SourceAddress()
+ dstAddr := h.DestinationAddress()
+
// As per RFC 1122 section 3.2.1.3:
// When a host sends any datagram, the IP source address MUST
// be one of its own IP addresses (but not a broadcast or
// multicast address).
- if directedBroadcast := subnet.IsBroadcast(srcAddr); directedBroadcast || srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) {
+ if srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) {
stats.IP.InvalidSourceAddressesReceived.Increment()
return
}
+ // Make sure the source address is not a subnet-local broadcast address.
+ if addressEndpoint := e.AcquireAssignedAddress(srcAddr, false /* createTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil {
+ subnet := addressEndpoint.Subnet()
+ addressEndpoint.DecRef()
+ if subnet.IsBroadcast(srcAddr) {
+ stats.IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+ }
+
+ // The destination address should be an address we own or a group we joined
+ // for us to receive the packet. Otherwise, attempt to forward the packet.
+ if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil {
+ subnet := addressEndpoint.AddressWithPrefix().Subnet()
+ addressEndpoint.DecRef()
+ pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
+ } else if !e.IsInGroup(dstAddr) {
+ if !e.protocol.Forwarding() {
+ stats.IP.InvalidDestinationAddressesReceived.Increment()
+ return
+ }
- pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
+ _ = e.forwardPacket(pkt)
+ return
+ }
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
@@ -692,6 +723,12 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
e.handleICMP(pkt)
return
}
+ if p == header.IGMPProtocolNumber {
+ e.mu.Lock()
+ e.mu.igmp.handleIGMP(pkt)
+ e.mu.Unlock()
+ return
+ }
if opts := h.Options(); len(opts) != 0 {
// TODO(gvisor.dev/issue/4586):
// When we add forwarding support we should use the verified options
@@ -747,7 +784,12 @@ func (e *endpoint) Close() {
func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
e.mu.Lock()
defer e.mu.Unlock()
- return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+ if err == nil {
+ e.mu.igmp.sendQueuedReports()
+ }
+ return ep, err
}
// RemovePermanentAddress implements stack.AddressableEndpoint.
@@ -770,34 +812,26 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo
defer e.mu.Unlock()
loopback := e.nic.IsLoopback()
- addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool {
+ return e.mu.addressableEndpointState.AcquireAssignedAddressOrMatching(localAddr, func(addressEndpoint stack.AddressEndpoint) bool {
subnet := addressEndpoint.Subnet()
// IPv4 has a notion of a subnet broadcast address and considers the
// loopback interface bound to an address's whole subnet (on linux).
return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr))
- })
- if addressEndpoint != nil {
- return addressEndpoint
- }
-
- if !allowTemp {
- return nil
- }
-
- addr := localAddr.WithPrefix()
- addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(addr, tempPEB)
- if err != nil {
- // AddAddress only returns an error if the address is already assigned,
- // but we just checked above if the address exists so we expect no error.
- panic(fmt.Sprintf("e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(%s, %d): %s", addr, tempPEB, err))
- }
- return addressEndpoint
+ }, allowTemp, tempPEB)
}
// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint.
func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
e.mu.RLock()
defer e.mu.RUnlock()
+ return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired)
+}
+
+// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress
+// but with locking requirements
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
}
@@ -816,28 +850,43 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
}
// JoinGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.joinGroupLocked(addr)
+}
+
+// joinGroupLocked is like JoinGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
if !header.IsV4MulticastAddress(addr) {
- return false, tcpip.ErrBadAddress
+ return tcpip.ErrBadAddress
}
- e.mu.Lock()
- defer e.mu.Unlock()
- return e.mu.addressableEndpointState.JoinGroup(addr)
+ e.mu.igmp.joinGroup(addr)
+ return nil
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- return e.mu.addressableEndpointState.LeaveGroup(addr)
+ return e.leaveGroupLocked(addr)
+}
+
+// leaveGroupLocked is like LeaveGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+ return e.mu.igmp.leaveGroup(addr)
}
// IsInGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.mu.addressableEndpointState.IsInGroup(addr)
+ return e.mu.igmp.isInGroup(addr)
}
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
@@ -863,6 +912,8 @@ type protocol struct {
hashIV uint32
fragmentation *fragmentation.Fragmentation
+
+ options Options
}
// Number returns the ipv4 protocol number.
@@ -987,17 +1038,23 @@ func addressToUint32(addr tcpip.Address) uint32 {
return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24
}
-// hashRoute calculates a hash value for the given route. It uses the source &
-// destination address, the transport protocol number and a 32-bit number to
-// generate the hash.
-func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
- a := addressToUint32(r.LocalAddress)
- b := addressToUint32(r.RemoteAddress)
+// hashRoute calculates a hash value for the given source/destination pair using
+// the addresses, transport protocol number and a 32-bit number to generate the
+// hash.
+func hashRoute(srcAddr, dstAddr tcpip.Address, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
+ a := addressToUint32(srcAddr)
+ b := addressToUint32(dstAddr)
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
-// NewProtocol returns an IPv4 network protocol.
-func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
+// Options holds options to configure a new protocol.
+type Options struct {
+ // IGMP holds options for IGMP.
+ IGMP IGMPOptions
+}
+
+// NewProtocolWithOptions returns an IPv4 network protocol.
+func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
ids := make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
@@ -1007,14 +1064,22 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
}
hashIV := r[buckets]
- p := &protocol{
- stack: s,
- ids: ids,
- hashIV: hashIV,
- defaultTTL: DefaultTTL,
+ return func(s *stack.Stack) stack.NetworkProtocol {
+ p := &protocol{
+ stack: s,
+ ids: ids,
+ hashIV: hashIV,
+ defaultTTL: DefaultTTL,
+ options: opts,
+ }
+ p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
+ return p
}
- p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
- return p
+}
+
+// NewProtocol is equivalent to NewProtocolWithOptions with an empty Options.
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
+ return NewProtocolWithOptions(Options{})(s)
}
func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) {
@@ -1129,6 +1194,12 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres
}
pointer := tsOpt.Pointer()
+ // RFC 791 page 22 states: "The smallest legal value is 5."
+ // Since the pointer is 1 based, and the header is 4 bytes long the
+ // pointer must point beyond the header therefore 4 or less is bad.
+ if pointer <= header.IPv4OptionTimestampHdrLength {
+ return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer
+ }
// To simplify processing below, base further work on the array of timestamps
// beyond the header, rather than on the whole option. Also to aid
// calculations set 'nextSlot' to be 0 based as in the packet it is 1 based.
@@ -1215,7 +1286,15 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad
return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength
}
- nextSlot := rrOpt.Pointer() - 1 // Pointer is 1 based.
+ pointer := rrOpt.Pointer()
+ // RFC 791 page 20 states:
+ // The pointer is relative to this option, and the
+ // smallest legal value for the pointer is 4.
+ // Since the pointer is 1 based, and the header is 3 bytes long the
+ // pointer must point beyond the header therefore 3 or less is bad.
+ if pointer <= header.IPv4OptionRecordRouteHdrLength {
+ return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer
+ }
// RFC 791 page 21 says
// If the route data area is already full (the pointer exceeds the
@@ -1230,14 +1309,14 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad
// do this (as do most implementations). It is probable that the inclusion
// of these words is a copy/paste error from the timestamp option where
// there are two failure reasons given.
- if nextSlot >= optlen {
+ if pointer > optlen {
return 0, nil
}
// The data area isn't full but there isn't room for a new entry.
// Either Length or Pointer could be bad. We must select Pointer for Linux
- // compatibility, even if only the length is bad.
- if nextSlot+header.IPv4AddressSize > optlen {
+ // compatibility, even if only the length is bad. NB. pointer is 1 based.
+ if pointer+header.IPv4AddressSize > optlen+1 {
if false {
// This is what we would do if we were not being Linux compatible.
// Check for bad pointer or length value. Must be a multiple of 4 after
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 4e4e1f3b4..9e2d2cfd6 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -103,105 +103,6 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
-// TestIPv4Encode checks that ipv4.Encode correctly fills out the requested
-// fields when options are supplied.
-func TestIPv4EncodeOptions(t *testing.T) {
- tests := []struct {
- name string
- options header.IPv4Options
- encodedOptions header.IPv4Options // reply should look like this
- wantIHL int
- }{
- {
- name: "valid no options",
- wantIHL: header.IPv4MinimumSize,
- },
- {
- name: "one byte options",
- options: header.IPv4Options{1},
- encodedOptions: header.IPv4Options{1, 0, 0, 0},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "two byte options",
- options: header.IPv4Options{1, 1},
- encodedOptions: header.IPv4Options{1, 1, 0, 0},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "three byte options",
- options: header.IPv4Options{1, 1, 1},
- encodedOptions: header.IPv4Options{1, 1, 1, 0},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "four byte options",
- options: header.IPv4Options{1, 1, 1, 1},
- encodedOptions: header.IPv4Options{1, 1, 1, 1},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "five byte options",
- options: header.IPv4Options{1, 1, 1, 1, 1},
- encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0},
- wantIHL: header.IPv4MinimumSize + 8,
- },
- {
- name: "thirty nine byte options",
- options: header.IPv4Options{
- 1, 2, 3, 4, 5, 6, 7, 8,
- 9, 10, 11, 12, 13, 14, 15, 16,
- 17, 18, 19, 20, 21, 22, 23, 24,
- 25, 26, 27, 28, 29, 30, 31, 32,
- 33, 34, 35, 36, 37, 38, 39,
- },
- encodedOptions: header.IPv4Options{
- 1, 2, 3, 4, 5, 6, 7, 8,
- 9, 10, 11, 12, 13, 14, 15, 16,
- 17, 18, 19, 20, 21, 22, 23, 24,
- 25, 26, 27, 28, 29, 30, 31, 32,
- 33, 34, 35, 36, 37, 38, 39, 0,
- },
- wantIHL: header.IPv4MinimumSize + 40,
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- paddedOptionLength := test.options.SizeWithPadding()
- ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength
- if ipHeaderLength > header.IPv4MaximumHeaderSize {
- t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
- }
- totalLen := uint16(ipHeaderLength)
- hdr := buffer.NewPrependable(int(totalLen))
- ip := header.IPv4(hdr.Prepend(ipHeaderLength))
- // To check the padding works, poison the last byte of the options space.
- if paddedOptionLength != len(test.options) {
- ip.SetHeaderLength(uint8(ipHeaderLength))
- ip.Options()[paddedOptionLength-1] = 0xff
- ip.SetHeaderLength(0)
- }
- ip.Encode(&header.IPv4Fields{
- Options: test.options,
- })
- options := ip.Options()
- wantOptions := test.encodedOptions
- if got, want := int(ip.HeaderLength()), test.wantIHL; got != want {
- t.Errorf("got IHL of %d, want %d", got, want)
- }
-
- // cmp.Diff does not consider nil slices equal to empty slices, but we do.
- if len(wantOptions) == 0 && len(options) == 0 {
- return
- }
-
- if diff := cmp.Diff(wantOptions, options); diff != "" {
- t.Errorf("options mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
func TestForwarding(t *testing.T) {
const (
nicID1 = 1
@@ -453,14 +354,6 @@ func TestIPv4Sanity(t *testing.T) {
replyOptions: header.IPv4Options{1, 1, 0, 0},
},
{
- name: "Check option padding",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{1, 1, 1},
- replyOptions: header.IPv4Options{1, 1, 1, 0},
- },
- {
name: "bad header length",
headerLength: header.IPv4MinimumSize - 1,
maxTotalLength: ipv4.MaxTotalSize,
@@ -583,7 +476,7 @@ func TestIPv4Sanity(t *testing.T) {
68, 7, 5, 0,
// ^ ^ Linux points here which is wrong.
// | Not a multiple of 4
- 1, 2, 3,
+ 1, 2, 3, 0,
},
shouldFail: true,
expectErrorICMP: true,
@@ -662,6 +555,56 @@ func TestIPv4Sanity(t *testing.T) {
},
},
{
+ // Timestamp pointer uses one based counting so 0 is invalid.
+ name: "timestamp pointer invalid",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 68, 8, 0, 0x00,
+ // ^ 0 instead of 5 or more.
+ 0, 0, 0, 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ // Timestamp pointer cannot be less than 5. It must point past the header
+ // which is 4 bytes. (1 based counting)
+ name: "timestamp pointer too small by 1",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 68, 8, header.IPv4OptionTimestampHdrLength, 0x00,
+ // ^ header is 4 bytes, so 4 should fail.
+ 0, 0, 0, 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ name: "valid timestamp pointer",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 68, 8, header.IPv4OptionTimestampHdrLength + 1, 0x00,
+ // ^ header is 4 bytes, so 5 should succeed.
+ 0, 0, 0, 0,
+ },
+ replyOptions: header.IPv4Options{
+ 68, 8, 9, 0x00,
+ 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
+ },
+ },
+ {
// Needs 8 bytes for a type 1 timestamp but there are only 4 free.
name: "bad timer element alignment",
maxTotalLength: ipv4.MaxTotalSize,
@@ -792,7 +735,61 @@ func TestIPv4Sanity(t *testing.T) {
},
},
{
- // Confirm linux bug for bug compatibility.
+ // Pointer uses one based counting so 0 is invalid.
+ name: "record route pointer zero",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 7, 8, 0, // 3 byte header
+ 0, 0, 0, 0,
+ 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ // Pointer must be 4 or more as it must point past the 3 byte header
+ // using 1 based counting. 3 should fail.
+ name: "record route pointer too small by 1",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 7, 8, header.IPv4OptionRecordRouteHdrLength, // 3 byte header
+ 0, 0, 0, 0,
+ 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ // Pointer must be 4 or more as it must point past the 3 byte header
+ // using 1 based counting. Check 4 passes. (Duplicates "single
+ // record route with room")
+ name: "valid record route pointer",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 7, 7, header.IPv4OptionRecordRouteHdrLength + 1, // 3 byte header
+ 0, 0, 0, 0,
+ 0,
+ },
+ replyOptions: header.IPv4Options{
+ 7, 7, 8, // 3 byte header
+ 192, 168, 1, 58, // New IP Address.
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ // Confirm Linux bug for bug compatibility.
// Linux returns slot 22 but the error is in slot 21.
name: "multiple record route with not enough room",
maxTotalLength: ipv4.MaxTotalSize,
@@ -863,8 +860,10 @@ func TestIPv4Sanity(t *testing.T) {
},
})
- paddedOptionLength := test.options.SizeWithPadding()
- ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength
+ if len(test.options)%4 != 0 {
+ t.Fatalf("options must be aligned to 32 bits, invalid test options: %x (len=%d)", test.options, len(test.options))
+ }
+ ipHeaderLength := header.IPv4MinimumSize + len(test.options)
if ipHeaderLength > header.IPv4MaximumHeaderSize {
t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
}
@@ -883,11 +882,6 @@ func TestIPv4Sanity(t *testing.T) {
if test.maxTotalLength < totalLen {
totalLen = test.maxTotalLength
}
- // To check the padding works, poison the options space.
- if paddedOptionLength != len(test.options) {
- ip.SetHeaderLength(uint8(ipHeaderLength))
- ip.Options()[paddedOptionLength-1] = 0x01
- }
ip.Encode(&header.IPv4Fields{
TotalLength: totalLen,
@@ -895,10 +889,19 @@ func TestIPv4Sanity(t *testing.T) {
TTL: test.TTL,
SrcAddr: remoteIPv4Addr,
DstAddr: ipv4Addr.Address,
- Options: test.options,
})
if test.headerLength != 0 {
ip.SetHeaderLength(test.headerLength)
+ } else {
+ // Set the calculated header length, since we may manually add options.
+ ip.SetHeaderLength(uint8(ipHeaderLength))
+ }
+ if len(test.options) != 0 {
+ // Copy options manually. We do not use Encode for options so we can
+ // verify malformed options with handcrafted payloads.
+ if want, got := copy(ip.Options(), test.options), len(test.options); want != got {
+ t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want)
+ }
}
ip.SetChecksum(0)
ipHeaderChecksum := ip.CalculateChecksum()
@@ -1003,7 +1006,7 @@ func TestIPv4Sanity(t *testing.T) {
}
// If the IP options change size then the packet will change size, so
// some IP header fields will need to be adjusted for the checks.
- sizeChange := len(test.replyOptions) - paddedOptionLength
+ sizeChange := len(test.replyOptions) - len(test.options)
checker.IPv4(t, replyIPHeader,
checker.IPv4HeaderLength(ipHeaderLength+sizeChange),
@@ -2320,6 +2323,28 @@ func TestReceiveFragments(t *testing.T) {
},
expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
},
+ {
+ name: "Two fragments with MF flag reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload4Addr1ToAddr2[:65512],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 65512,
+ payload: ipv4Payload4Addr1ToAddr2[65512:],
+ },
+ },
+ expectedPayloads: nil,
+ },
}
for _, test := range tests {
@@ -2513,7 +2538,7 @@ func TestWriteStats(t *testing.T) {
test.setup(t, rt.Stack())
- nWritten, _ := writer.writePackets(&rt, pkts)
+ nWritten, _ := writer.writePackets(rt, pkts)
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
@@ -2530,7 +2555,7 @@ func TestWriteStats(t *testing.T) {
}
}
-func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
})
@@ -2644,8 +2669,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != header.IPv4ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
}
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr)
}
checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
@@ -2687,8 +2712,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != header.IPv4ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
}
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr)
}
checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
@@ -2736,8 +2761,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != arp.ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber)
}
- if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
+ if got := p.Route.RemoteLinkAddress(); got != header.EthernetBroadcastAddress {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, header.EthernetBroadcastAddress)
}
rep := header.ARP(p.Pkt.NetworkHeader().View())
if got := rep.Op(); got != header.ARPRequest {
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index 0ac24a6fb..afa45aefe 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -8,6 +8,7 @@ go_library(
"dhcpv6configurationfromndpra_string.go",
"icmp.go",
"ipv6.go",
+ "mld.go",
"ndp.go",
],
visibility = ["//visibility:public"],
@@ -19,6 +20,7 @@ go_library(
"//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
+ "//pkg/tcpip/network/ip",
"//pkg/tcpip/stack",
],
)
@@ -49,3 +51,19 @@ go_test(
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
+
+go_test(
+ name = "ipv6_x_test",
+ size = "small",
+ srcs = ["mld_test.go"],
+ deps = [
+ ":ipv6",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index beb8f562e..6ee162713 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -126,8 +126,8 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) {
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
stats := e.protocol.stack.Stats().ICMP
- sent := stats.V6PacketsSent
- received := stats.V6PacketsReceived
+ sent := stats.V6.PacketsSent
+ received := stats.V6.PacketsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
// full explanation.
@@ -163,7 +163,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
}
// TODO(b/112892170): Meaningfully handle all ICMP types.
- switch h.Type() {
+ switch icmpType := h.Type(); icmpType {
case header.ICMPv6PacketTooBig:
received.PacketTooBig.Increment()
hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize)
@@ -358,7 +358,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize))
packet.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(packet.NDPPayload())
+ na := header.NDPNeighborAdvert(packet.MessageBody())
// As per RFC 4861 section 7.2.4:
//
@@ -644,8 +644,39 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
return
}
+ case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone:
+ switch icmpType {
+ case header.ICMPv6MulticastListenerQuery:
+ received.MulticastListenerQuery.Increment()
+ case header.ICMPv6MulticastListenerReport:
+ received.MulticastListenerReport.Increment()
+ case header.ICMPv6MulticastListenerDone:
+ received.MulticastListenerDone.Increment()
+ default:
+ panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType))
+ }
+
+ if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+
+ switch icmpType {
+ case header.ICMPv6MulticastListenerQuery:
+ e.mu.Lock()
+ e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.ToView()))
+ e.mu.Unlock()
+ case header.ICMPv6MulticastListenerReport:
+ e.mu.Lock()
+ e.mu.mld.handleMulticastListenerReport(header.MLD(payload.ToView()))
+ e.mu.Unlock()
+ case header.ICMPv6MulticastListenerDone:
+ default:
+ panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType))
+ }
+
default:
- received.Invalid.Increment()
+ received.Unrecognized.Increment()
}
}
@@ -681,12 +712,12 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize))
packet.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(packet.NDPPayload())
+ ns := header.NDPNeighborSolicit(packet.MessageBody())
ns.SetTargetAddress(targetAddr)
ns.Options().Serialize(optsSerializer)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- stat := p.stack.Stats().ICMP.V6PacketsSent
+ stat := p.stack.Stats().ICMP.V6.PacketsSent
if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
@@ -796,7 +827,8 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
allowResponseToMulticast = reason.respondToMulticast
}
- if (!allowResponseToMulticast && header.IsV6MulticastAddress(origIPHdrDst)) || origIPHdrSrc == header.IPv6Any {
+ isOrigDstMulticast := header.IsV6MulticastAddress(origIPHdrDst)
+ if (!allowResponseToMulticast && isOrigDstMulticast) || origIPHdrSrc == header.IPv6Any {
return nil
}
@@ -812,8 +844,13 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
// If we are operating as a router, do not use the packet's destination
// address as the response's source address as we should not own the
// destination address of a packet we are forwarding.
+ //
+ // If the packet was originally destined to a multicast address, then do not
+ // use the packet's destination address as the source for the response ICMP
+ // packet as "multicast addresses must not be used as source addresses in IPv6
+ // packets", as per RFC 4291 section 2.7.
localAddr := origIPHdrDst
- if _, ok := reason.(*icmpReasonHopLimitExceeded); ok {
+ if _, ok := reason.(*icmpReasonHopLimitExceeded); ok || isOrigDstMulticast {
localAddr = ""
}
// Even if we were able to receive a packet from some remote, we may not have
@@ -827,7 +864,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
defer route.Release()
stats := p.stack.Stats().ICMP
- sent := stats.V6PacketsSent
+ sent := stats.V6.PacketsSent
if !p.stack.AllowICMPMessage() {
sent.RateLimited.Increment()
return nil
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 9bc02d851..02b18e9a5 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -150,9 +150,9 @@ func (*testInterface) Promiscuous() bool {
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
r := stack.Route{
- NetProto: protocol,
- RemoteLinkAddress: remoteLinkAddr,
+ NetProto: protocol,
}
+ r.ResolveWith(remoteLinkAddr)
return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
}
@@ -271,6 +271,22 @@ func TestICMPCounts(t *testing.T) {
typ: header.ICMPv6RedirectMsg,
size: header.ICMPv6MinimumSize,
},
+ {
+ typ: header.ICMPv6MulticastListenerQuery,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerReport,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerDone,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: 255, /* Unrecognized */
+ size: 50,
+ },
}
handleIPv6Payload := func(icmp header.ICMPv6) {
@@ -280,11 +296,11 @@ func TestICMPCounts(t *testing.T) {
})
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
ep.HandlePacket(pkt)
}
@@ -301,7 +317,7 @@ func TestICMPCounts(t *testing.T) {
// Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
- icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
+ icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived
visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
if got, want := s.Value(), uint64(1); got != want {
t.Errorf("got %s = %d, want = %d", name, got, want)
@@ -413,6 +429,22 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
typ: header.ICMPv6RedirectMsg,
size: header.ICMPv6MinimumSize,
},
+ {
+ typ: header.ICMPv6MulticastListenerQuery,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerReport,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerDone,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: 255, /* Unrecognized */
+ size: 50,
+ },
}
handleIPv6Payload := func(icmp header.ICMPv6) {
@@ -422,11 +454,11 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
})
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
ep.HandlePacket(pkt)
}
@@ -443,7 +475,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
// Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
- icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
+ icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived
visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
if got, want := s.Value(), uint64(1); got != want {
t.Errorf("got %s = %d, want = %d", name, got, want)
@@ -568,8 +600,8 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.
return
}
- if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress {
- t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr)
+ if got := pi.Route.RemoteLinkAddress(); len(args.remoteLinkAddr) != 0 && got != args.remoteLinkAddr {
+ t.Errorf("got remote link address = %s, want = %s", got, args.remoteLinkAddr)
}
// Pull the full payload since network header. Needed for header.IPv6 to
@@ -821,11 +853,11 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
}
ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
@@ -833,7 +865,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
e.InjectInbound(ProtocolNumber, pkt)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
routerOnly := stats.RouterOnlyPacketsDroppedByHost
typStat := typ.statCounter(stats)
@@ -898,11 +930,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
errorICMPBody := func(view buffer.View) {
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
- PayloadLength: simpleBodySize,
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: lladdr0,
- DstAddr: lladdr1,
+ PayloadLength: simpleBodySize,
+ TransportProtocol: 10,
+ HopLimit: 20,
+ SrcAddr: lladdr0,
+ DstAddr: lladdr1,
})
simpleBody(view[header.IPv6MinimumSize:])
}
@@ -1016,11 +1048,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(icmpSize),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(icmpSize),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1028,7 +1060,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
e.InjectInbound(ProtocolNumber, pkt)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
typStat := typ.statCounter(stats)
@@ -1076,11 +1108,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
errorICMPBody := func(view buffer.View) {
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
- PayloadLength: simpleBodySize,
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: lladdr0,
- DstAddr: lladdr1,
+ PayloadLength: simpleBodySize,
+ TransportProtocol: 10,
+ HopLimit: 20,
+ SrcAddr: lladdr0,
+ DstAddr: lladdr1,
})
simpleBody(view[header.IPv6MinimumSize:])
}
@@ -1195,11 +1227,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(size + payloadSize),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(size + payloadSize),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
@@ -1207,7 +1239,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
e.InjectInbound(ProtocolNumber, pkt)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
typStat := typ.statCounter(stats)
@@ -1349,8 +1381,8 @@ func TestLinkAddressRequest(t *testing.T) {
if !ok {
t.Fatal("expected to send a link address request")
}
- if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
+ if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr)
}
if pkt.Route.RemoteAddress != test.expectedRemoteAddr {
t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr)
@@ -1413,11 +1445,11 @@ func TestPacketQueing(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: DefaultTTL,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1431,8 +1463,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
}
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
@@ -1455,11 +1487,11 @@ func TestPacketQueing(t *testing.T) {
pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: DefaultTTL,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1473,8 +1505,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
}
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
@@ -1524,8 +1556,8 @@ func TestPacketQueing(t *testing.T) {
t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber)
}
snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address)
- if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
@@ -1543,7 +1575,7 @@ func TestPacketQueing(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
pkt := header.ICMPv6(hdr.Prepend(naSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address)
@@ -1554,11 +1586,11 @@ func TestPacketQueing(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: header.NDPHopLimit,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1592,7 +1624,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
return icmp
},
@@ -1612,7 +1644,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
ns.Options().Serialize(header.NDPOptionsSerializer{
header.NDPSourceLinkLayerAddressOption(linkAddr1),
@@ -1629,7 +1661,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
return icmp
},
@@ -1645,7 +1677,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
ns.Options().Serialize(header.NDPOptionsSerializer{
header.NDPSourceLinkLayerAddressOption(linkAddr1),
@@ -1662,7 +1694,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1683,7 +1715,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1702,7 +1734,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(false)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1722,7 +1754,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(false)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1796,11 +1828,11 @@ func TestCallsToNeighborCache(t *testing.T) {
})
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: test.source,
- DstAddr: test.destination,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: test.source,
+ DstAddr: test.destination,
})
ep.HandlePacket(pkt)
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 7a00f6314..a49b5ac77 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -19,6 +19,7 @@ import (
"encoding/binary"
"fmt"
"hash/fnv"
+ "math"
"sort"
"sync/atomic"
"time"
@@ -34,7 +35,9 @@ import (
)
const (
+ // ReassembleTimeout controls how long a fragment will be held.
// As per RFC 8200 section 4.5:
+ //
// If insufficient fragments are received to complete reassembly of a packet
// within 60 seconds of the reception of the first-arriving fragment of that
// packet, reassembly of that packet must be abandoned.
@@ -83,6 +86,7 @@ type endpoint struct {
addressableEndpointState stack.AddressableEndpointState
ndp ndpState
+ mld mldState
}
}
@@ -118,6 +122,45 @@ type OpaqueInterfaceIdentifierOptions struct {
SecretKey []byte
}
+// onAddressAssignedLocked handles an address being assigned.
+//
+// Precondition: e.mu must be exclusively locked.
+func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) {
+ // As per RFC 2710 section 3,
+ //
+ // All MLD messages described in this document are sent with a link-local
+ // IPv6 Source Address, ...
+ //
+ // If we just completed DAD for a link-local address, then attempt to send any
+ // queued MLD reports. Note, we may have sent reports already for some of the
+ // groups before we had a valid link-local address to use as the source for
+ // the MLD messages, but that was only so that MLD snooping switches are aware
+ // of our membership to groups - routers would not have handled those reports.
+ //
+ // As per RFC 3590 section 4,
+ //
+ // MLD Report and Done messages are sent with a link-local address as
+ // the IPv6 source address, if a valid address is available on the
+ // interface. If a valid link-local address is not available (e.g., one
+ // has not been configured), the message is sent with the unspecified
+ // address (::) as the IPv6 source address.
+ //
+ // Once a valid link-local address is available, a node SHOULD generate
+ // new MLD Report messages for all multicast addresses joined on the
+ // interface.
+ //
+ // Routers receiving an MLD Report or Done message with the unspecified
+ // address as the IPv6 source address MUST silently discard the packet
+ // without taking any action on the packets contents.
+ //
+ // Snooping switches MUST manage multicast forwarding state based on MLD
+ // Report and Done messages sent with the unspecified address as the
+ // IPv6 source address.
+ if header.IsV6LinkLocalAddress(addr) {
+ e.mu.mld.sendQueuedReports()
+ }
+}
+
// InvalidateDefaultRouter implements stack.NDPEndpoint.
func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
e.mu.Lock()
@@ -224,6 +267,12 @@ func (e *endpoint) Enable() *tcpip.Error {
return nil
}
+ // Groups may have been joined when the endpoint was disabled, or the
+ // endpoint may have left groups from the perspective of MLD when the
+ // endpoint was disabled. Either way, we need to let routers know to
+ // send us multicast traffic.
+ e.mu.mld.initializeAll()
+
// Join the IPv6 All-Nodes Multicast group if the stack is configured to
// use IPv6. This is required to ensure that this node properly receives
// and responds to the various NDP messages that are destined to the
@@ -241,8 +290,10 @@ func (e *endpoint) Enable() *tcpip.Error {
// (NDP NS) messages may be sent to the All-Nodes multicast group if the
// source address of the NDP NS is the unspecified address, as per RFC 4861
// section 7.2.4.
- if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil {
- return err
+ if err := e.joinGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv6 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv6AllNodesMulticastAddress, err))
}
// Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
@@ -251,7 +302,7 @@ func (e *endpoint) Enable() *tcpip.Error {
// Addresses may have aleady completed DAD but in the time since the endpoint
// was last enabled, other devices may have acquired the same addresses.
var err *tcpip.Error
- e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
addr := addressEndpoint.AddressWithPrefix().Address
if !header.IsV6UnicastAddress(addr) {
return true
@@ -273,7 +324,7 @@ func (e *endpoint) Enable() *tcpip.Error {
}
// Do not auto-generate an IPv6 link-local address for loopback devices.
- if e.protocol.autoGenIPv6LinkLocal && !e.nic.IsLoopback() {
+ if e.protocol.options.AutoGenLinkLocal && !e.nic.IsLoopback() {
// The valid and preferred lifetime is infinite for the auto-generated
// link-local address.
e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime)
@@ -322,7 +373,7 @@ func (e *endpoint) Disable() {
}
func (e *endpoint) disableLocked() {
- if !e.setEnabled(false) {
+ if !e.Enabled() {
return
}
@@ -331,9 +382,17 @@ func (e *endpoint) disableLocked() {
e.stopDADForPermanentAddressesLocked()
// The endpoint may have already left the multicast group.
- if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ if err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err))
}
+
+ // Leave groups from the perspective of MLD so that routers know that
+ // we are no longer interested in the group.
+ e.mu.mld.softLeaveAll()
+
+ if !e.setEnabled(false) {
+ panic("should have only done work to disable the endpoint if it was enabled")
+ }
}
// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses.
@@ -341,7 +400,7 @@ func (e *endpoint) disableLocked() {
// Precondition: e.mu must be write locked.
func (e *endpoint) stopDADForPermanentAddressesLocked() {
// Stop DAD for all the tentative unicast addresses.
- e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
if addressEndpoint.GetKind() != stack.PermanentTentative {
return true
}
@@ -373,19 +432,27 @@ func (e *endpoint) MTU() uint32 {
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
+ // TODO(gvisor.dev/issues/5035): The maximum header length returned here does
+ // not open the possibility for the caller to know about size required for
+ // extension headers.
return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
-func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
- length := uint16(pkt.Size())
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) {
+ extHdrsLen := extensionHeaders.Length()
+ length := pkt.Size() + extensionHeaders.Length()
+ if length > math.MaxUint16 {
+ panic(fmt.Sprintf("IPv6 payload too large: %d, must be <= %d", length, math.MaxUint16))
+ }
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen))
ip.Encode(&header.IPv6Fields{
- PayloadLength: length,
- NextHeader: uint8(params.Protocol),
- HopLimit: params.TTL,
- TrafficClass: params.TOS,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ PayloadLength: uint16(length),
+ TransportProtocol: params.Protocol,
+ HopLimit: params.TTL,
+ TrafficClass: params.TOS,
+ SrcAddr: srcAddr,
+ DstAddr: dstAddr,
+ ExtensionHeaders: extensionHeaders,
})
pkt.NetworkProtocolNumber = ProtocolNumber
}
@@ -440,7 +507,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- e.addIPHeader(r, pkt, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */)
// iptables filtering. All packets that reach here are locally
// generated.
@@ -529,7 +596,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
linkMTU := e.nic.MTU()
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
- e.addIPHeader(r, pb, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */)
networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size()))
if err != nil {
@@ -737,8 +804,11 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
return
}
- addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
- if addressEndpoint == nil {
+ // The destination address should be an address we own or a group we joined
+ // for us to receive the packet. Otherwise, attempt to forward the packet.
+ if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil {
+ addressEndpoint.DecRef()
+ } else if !e.IsInGroup(dstAddr) {
if !e.protocol.Forwarding() {
stats.IP.InvalidDestinationAddressesReceived.Increment()
return
@@ -747,7 +817,6 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
_ = e.forwardPacket(pkt)
return
}
- addressEndpoint.DecRef()
// vv consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
@@ -1090,9 +1159,16 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
//
// Which when taken together indicate that an unknown protocol should
// be treated as an unrecognized next header value.
+ // The location of the Next Header field is in a different place in
+ // the initial IPv6 header than it is in the extension headers so
+ // treat it specially.
+ prevHdrIDOffset := uint32(header.IPv6NextHeaderOffset)
+ if previousHeaderStart != 0 {
+ prevHdrIDOffset = previousHeaderStart
+ }
_ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6UnknownHeader,
- pointer: it.ParseOffset(),
+ pointer: prevHdrIDOffset,
}, pkt)
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
@@ -1100,12 +1176,11 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
}
default:
- _ = e.protocol.returnError(&icmpReasonParameterProblem{
- code: header.ICMPv6UnknownHeader,
- pointer: it.ParseOffset(),
- }, pkt)
- stats.UnknownProtocolRcvdPackets.Increment()
- return
+ // Since the iterator returns IPv6RawPayloadHeader for unknown Extension
+ // Header IDs this should never happen unless we missed a supported type
+ // here.
+ panic(fmt.Sprintf("unrecognized type from it.Next() = %T", extHdr))
+
}
}
}
@@ -1153,11 +1228,6 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
return addressEndpoint, nil
}
- snmc := header.SolicitedNodeAddr(addr.Address)
- if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil {
- return nil, err
- }
-
addressEndpoint.SetKind(stack.PermanentTentative)
if e.Enabled() {
@@ -1166,6 +1236,13 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
}
}
+ snmc := header.SolicitedNodeAddr(addr.Address)
+ if err := e.joinGroupLocked(snmc); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv6 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err))
+ }
+
return addressEndpoint, nil
}
@@ -1211,7 +1288,8 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn
}
snmc := header.SolicitedNodeAddr(addr.Address)
- if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
+ // The endpoint may have already left the multicast group.
+ if err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
return err
}
@@ -1234,7 +1312,7 @@ func (e *endpoint) hasPermanentAddressRLocked(addr tcpip.Address) bool {
//
// Precondition: e.mu must be read or write locked.
func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpoint {
- return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr)
+ return e.mu.addressableEndpointState.GetAddress(localAddr)
}
// MainAddress implements stack.AddressableEndpoint.
@@ -1266,6 +1344,26 @@ func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allow
return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired)
}
+// getLinkLocalAddressRLocked returns a link-local address from the primary list
+// of addresses, if one is available.
+//
+// See stack.PrimaryEndpointBehavior for more details about the primary list.
+//
+// Precondition: e.mu must be read locked.
+func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address {
+ var linkLocalAddr tcpip.Address
+ e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
+ if addressEndpoint.IsAssigned(false /* allowExpired */) {
+ if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalAddress(addr) {
+ linkLocalAddr = addr
+ return false
+ }
+ }
+ return true
+ })
+ return linkLocalAddr
+}
+
// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress
// but with locking requirements.
//
@@ -1285,10 +1383,10 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
// Create a candidate set of available addresses we can potentially use as a
// source address.
var cs []addrCandidate
- e.mu.addressableEndpointState.ReadOnly().ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) {
+ e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
// If r is not valid for outgoing connections, it is not a valid endpoint.
if !addressEndpoint.IsAssigned(allowExpired) {
- return
+ return true
}
addr := addressEndpoint.AddressWithPrefix().Address
@@ -1304,6 +1402,8 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
addressEndpoint: addressEndpoint,
scope: scope,
})
+
+ return true
})
remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
@@ -1376,28 +1476,43 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
}
// JoinGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.joinGroupLocked(addr)
+}
+
+// joinGroupLocked is like JoinGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
if !header.IsV6MulticastAddress(addr) {
- return false, tcpip.ErrBadAddress
+ return tcpip.ErrBadAddress
}
- e.mu.Lock()
- defer e.mu.Unlock()
- return e.mu.addressableEndpointState.JoinGroup(addr)
+ e.mu.mld.joinGroup(addr)
+ return nil
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- return e.mu.addressableEndpointState.LeaveGroup(addr)
+ return e.leaveGroupLocked(addr)
+}
+
+// leaveGroupLocked is like LeaveGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+ return e.mu.mld.leaveGroup(addr)
}
// IsInGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.mu.addressableEndpointState.IsInGroup(addr)
+ return e.mu.mld.isInGroup(addr)
}
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
@@ -1405,7 +1520,8 @@ var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
type protocol struct {
- stack *stack.Stack
+ stack *stack.Stack
+ options Options
mu struct {
sync.RWMutex
@@ -1429,26 +1545,6 @@ type protocol struct {
forwarding uint32
fragmentation *fragmentation.Fragmentation
-
- // ndpDisp is the NDP event dispatcher that is used to send the netstack
- // integrator NDP related events.
- ndpDisp NDPDispatcher
-
- // ndpConfigs is the default NDP configurations used by an IPv6 endpoint.
- ndpConfigs NDPConfigurations
-
- // opaqueIIDOpts hold the options for generating opaque interface identifiers
- // (IIDs) as outlined by RFC 7217.
- opaqueIIDOpts OpaqueInterfaceIdentifierOptions
-
- // tempIIDSeed is used to seed the initial temporary interface identifier
- // history value used to generate IIDs for temporary SLAAC addresses.
- tempIIDSeed []byte
-
- // autoGenIPv6LinkLocal determines whether or not the stack attempts to
- // auto-generate an IPv6 link-local address for newly enabled non-loopback
- // NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
- autoGenIPv6LinkLocal bool
}
// Number returns the ipv6 protocol number.
@@ -1481,16 +1577,11 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
dispatcher: dispatcher,
protocol: p,
}
+ e.mu.Lock()
e.mu.addressableEndpointState.Init(e)
- e.mu.ndp = ndpState{
- ep: e,
- configs: p.ndpConfigs,
- dad: make(map[tcpip.Address]dadState),
- defaultRouters: make(map[tcpip.Address]defaultRouterState),
- onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
- slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
- }
- e.mu.ndp.initializeTempAddrState()
+ e.mu.ndp.init(e)
+ e.mu.mld.init(e)
+ e.mu.Unlock()
p.mu.Lock()
defer p.mu.Unlock()
@@ -1613,17 +1704,17 @@ type Options struct {
// NDPConfigs is the default NDP configurations used by interfaces.
NDPConfigs NDPConfigurations
- // AutoGenIPv6LinkLocal determines whether or not the stack attempts to
- // auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // AutoGenLinkLocal determines whether or not the stack attempts to
+ // auto-generate a link-local address for newly enabled non-loopback
// NICs.
//
// Note, setting this to true does not mean that a link-local address is
// assigned right away, or at all. If Duplicate Address Detection is enabled,
// an address is only assigned if it successfully resolves. If it fails, no
- // further attempts are made to auto-generate an IPv6 link-local adddress.
+ // further attempts are made to auto-generate a link-local adddress.
//
// The generated link-local address follows RFC 4291 Appendix A guidelines.
- AutoGenIPv6LinkLocal bool
+ AutoGenLinkLocal bool
// NDPDisp is the NDP event dispatcher that an integrator can provide to
// receive NDP related events.
@@ -1647,6 +1738,9 @@ type Options struct {
// seed that is too small would reduce randomness and increase predictability,
// defeating the purpose of temporary SLAAC addresses.
TempIIDSeed []byte
+
+ // MLD holds options for MLD.
+ MLD MLDOptions
}
// NewProtocolWithOptions returns an IPv6 network protocol.
@@ -1658,15 +1752,11 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
return func(s *stack.Stack) stack.NetworkProtocol {
p := &protocol{
- stack: s,
+ stack: s,
+ options: opts,
+
ids: ids,
hashIV: hashIV,
-
- ndpDisp: opts.NDPDisp,
- ndpConfigs: opts.NDPConfigs,
- opaqueIIDOpts: opts.OpaqueIIDOpts,
- tempIIDSeed: opts.TempIIDSeed,
- autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
}
p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
p.mu.eps = make(map[*endpoint]struct{})
@@ -1712,24 +1802,25 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders hea
fragPkt.NetworkProtocolNumber = ProtocolNumber
originalIPHeadersLength := len(originalIPHeaders)
- fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize
+
+ s := header.IPv6ExtHdrSerializer{&header.IPv6SerializableFragmentExtHdr{
+ FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit),
+ M: more,
+ Identification: id,
+ }}
+
+ fragmentIPHeadersLength := originalIPHeadersLength + s.Length()
fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength))
- fragPkt.NetworkProtocolNumber = ProtocolNumber
// Copy the IPv6 header and any extension headers already populated.
if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength {
panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength))
}
- fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader)
- fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize))
- fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:])
- fragmentHeader.Encode(&header.IPv6FragmentFields{
- M: more,
- FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit),
- Identification: id,
- NextHeader: uint8(transportProto),
- })
+ nextHeader, _ := s.Serialize(transportProto, fragmentIPHeaders[originalIPHeadersLength:])
+
+ fragmentIPHeaders.SetNextHeader(nextHeader)
+ fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize))
return fragPkt, more
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index a671d4bac..5f07d3af8 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -51,6 +51,7 @@ const (
fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier)
destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier)
noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier)
+ unknownHdrID = uint8(header.IPv6UnknownExtHdrIdentifier)
extraHeaderReserve = 50
)
@@ -68,18 +69,18 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: src,
- DstAddr: dst,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
}))
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
if got := stats.NeighborAdvert.Value(); got != want {
t.Fatalf("got NeighborAdvert = %d, want = %d", got, want)
@@ -126,11 +127,11 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 255,
- SrcAddr: src,
- DstAddr: dst,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -573,6 +574,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
expectICMP: false,
},
{
+ name: "unknown next header (first)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0, 63, 4, 1, 2, 3, 4,
+ }, unknownHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6NextHeaderOffset,
+ },
+ {
+ name: "unknown next header (not first)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ unknownHdrID, 0,
+ 63, 4, 1, 2, 3, 4,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6FixedHeaderSize,
+ },
+ {
name: "destination with unknown option skippable action",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
@@ -755,11 +783,6 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
pointer: header.IPv6FixedHeaderSize,
},
{
- name: "No next header",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
- shouldAccept: false,
- },
- {
name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
@@ -873,7 +896,13 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
Length: uint16(udpLength),
})
copy(u.Payload(), udpPayload)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+
+ dstAddr := tcpip.Address(addr2)
+ if test.multicast {
+ dstAddr = header.IPv6AllNodesMulticastAddress
+ }
+
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, dstAddr, uint16(udpLength))
sum = header.Checksum(udpPayload, sum)
u.SetChecksum(^u.CalculateChecksum(sum))
@@ -884,16 +913,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Serialize IPv6 fixed header.
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- dstAddr := tcpip.Address(addr2)
- if test.multicast {
- dstAddr = header.IPv6AllNodesMulticastAddress
- }
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
- NextHeader: ipv6NextHdr,
- HopLimit: 255,
- SrcAddr: addr1,
- DstAddr: dstAddr,
+ // We're lying about transport protocol here to be able to generate
+ // raw extension headers from the test definitions.
+ TransportProtocol: tcpip.TransportProtocolNumber(ipv6NextHdr),
+ HopLimit: 255,
+ SrcAddr: addr1,
+ DstAddr: dstAddr,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -982,9 +1009,10 @@ func TestReceiveIPv6Fragments(t *testing.T) {
udpPayload2Length = 128
// Used to test cases where the fragment blocks are not a multiple of
// the fragment block size of 8 (RFC 8200 section 4.5).
- udpPayload3Length = 127
- udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize
- fragmentExtHdrLen = 8
+ udpPayload3Length = 127
+ udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize
+ udpMaximumSizeMinus15 = header.UDPMaximumSize - 15
+ fragmentExtHdrLen = 8
// Note, not all routing extension headers will be 8 bytes but this test
// uses 8 byte routing extension headers for most sub tests.
routingExtHdrLen = 8
@@ -1328,14 +1356,14 @@ func TestReceiveIPv6Fragments(t *testing.T) {
dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+65520,
+ fragmentExtHdrLen+udpMaximumSizeMinus15,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload4Addr1ToAddr2[:65520],
+ ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15],
},
),
},
@@ -1344,14 +1372,17 @@ func TestReceiveIPv6Fragments(t *testing.T) {
dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-65520,
+ fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15,
[]buffer.View{
// Fragment extension header.
//
- // Fragment offset = 8190, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 255, 240, 0, 0, 0, 1}),
+ // Fragment offset = udpMaximumSizeMinus15/8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
+ udpMaximumSizeMinus15 >> 8,
+ udpMaximumSizeMinus15 & 0xff,
+ 0, 0, 0, 1}),
- ipv6Payload4Addr1ToAddr2[65520:],
+ ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:],
},
),
},
@@ -1359,6 +1390,47 @@ func TestReceiveIPv6Fragments(t *testing.T) {
expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
},
{
+ name: "Two fragments with MF flag reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+udpMaximumSizeMinus15,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = udpMaximumSizeMinus15/8, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
+ udpMaximumSizeMinus15 >> 8,
+ (udpMaximumSizeMinus15 & 0xff) + 1,
+ 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
name: "Two fragments with per-fragment routing header with zero segments left",
fragments: []fragmentData{
{
@@ -1877,10 +1949,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(f.data.Size()),
- NextHeader: f.nextHdr,
- HopLimit: 255,
- SrcAddr: f.srcAddr,
- DstAddr: f.dstAddr,
+ // We're lying about transport protocol here so that we can generate
+ // raw extension headers for the tests.
+ TransportProtocol: tcpip.TransportProtocolNumber(f.nextHdr),
+ HopLimit: 255,
+ SrcAddr: f.srcAddr,
+ DstAddr: f.dstAddr,
})
vv := hdr.View().ToVectorisedView()
@@ -1925,7 +1999,7 @@ func TestInvalidIPv6Fragments(t *testing.T) {
type fragmentData struct {
ipv6Fields header.IPv6Fields
- ipv6FragmentFields header.IPv6FragmentFields
+ ipv6FragmentFields header.IPv6SerializableFragmentExtHdr
payload []byte
}
@@ -1944,14 +2018,13 @@ func TestInvalidIPv6Fragments(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 9,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 9,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0 >> 3,
M: true,
Identification: ident,
@@ -1971,14 +2044,13 @@ func TestInvalidIPv6Fragments(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3,
M: false,
Identification: ident,
@@ -2019,10 +2091,9 @@ func TestInvalidIPv6Fragments(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
- ip.Encode(&f.ipv6Fields)
-
- fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
- fragHDR.Encode(&f.ipv6FragmentFields)
+ encodeArgs := f.ipv6Fields
+ encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields)
+ ip.Encode(&encodeArgs)
vv := hdr.View().ToVectorisedView()
vv.AppendView(f.payload)
@@ -2084,7 +2155,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
type fragmentData struct {
ipv6Fields header.IPv6Fields
- ipv6FragmentFields header.IPv6FragmentFields
+ ipv6FragmentFields header.IPv6SerializableFragmentExtHdr
payload []byte
}
@@ -2098,14 +2169,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2120,14 +2190,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2136,14 +2205,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
},
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2158,14 +2226,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 8,
M: false,
Identification: ident,
@@ -2180,14 +2247,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2196,14 +2262,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
},
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 8,
M: false,
Identification: ident,
@@ -2218,14 +2283,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 8,
M: false,
Identification: ident,
@@ -2234,14 +2298,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
},
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2280,10 +2343,11 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
- ip.Encode(&f.ipv6Fields)
+ encodeArgs := f.ipv6Fields
+ encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields)
+ ip.Encode(&encodeArgs)
fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
- fragHDR.Encode(&f.ipv6FragmentFields)
vv := hdr.View().ToVectorisedView()
vv.AppendView(f.payload)
@@ -2439,7 +2503,7 @@ func TestWriteStats(t *testing.T) {
test.setup(t, rt.Stack())
- nWritten, _ := writer.writePackets(&rt, pkts)
+ nWritten, _ := writer.writePackets(rt, pkts)
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
@@ -2456,7 +2520,7 @@ func TestWriteStats(t *testing.T) {
}
}
-func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
@@ -2924,11 +2988,11 @@ func TestForwarding(t *testing.T) {
icmp.SetChecksum(header.ICMPv6Checksum(icmp, remoteIPv6Addr1, remoteIPv6Addr2, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: test.TTL,
- SrcAddr: remoteIPv6Addr1,
- DstAddr: remoteIPv6Addr2,
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: test.TTL,
+ SrcAddr: remoteIPv6Addr1,
+ DstAddr: remoteIPv6Addr2,
})
requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
new file mode 100644
index 000000000..e8d1e7a79
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -0,0 +1,262 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // UnsolicitedReportIntervalMax is the maximum delay between sending
+ // unsolicited MLD reports.
+ //
+ // Obtained from RFC 2710 Section 7.10.
+ UnsolicitedReportIntervalMax = 10 * time.Second
+)
+
+// MLDOptions holds options for MLD.
+type MLDOptions struct {
+ // Enabled indicates whether MLD will be performed.
+ //
+ // When enabled, MLD may transmit MLD report and done messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // MLD packets.
+ //
+ // This field is ignored and is always assumed to be false for interfaces
+ // without neighbouring nodes (e.g. loopback).
+ Enabled bool
+}
+
+var _ ip.MulticastGroupProtocol = (*mldState)(nil)
+
+// mldState is the per-interface MLD state.
+//
+// mldState.init MUST be called to initialize the MLD state.
+type mldState struct {
+ // The IPv6 endpoint this mldState is for.
+ ep *endpoint
+
+ genericMulticastProtocol ip.GenericMulticastProtocolState
+}
+
+// Enabled implements ip.MulticastGroupProtocol.
+func (mld *mldState) Enabled() bool {
+ // No need to perform MLD on loopback interfaces since they don't have
+ // neighbouring nodes.
+ return mld.ep.protocol.options.MLD.Enabled && !mld.ep.nic.IsLoopback() && mld.ep.Enabled()
+}
+
+// SendReport implements ip.MulticastGroupProtocol.
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+ return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport)
+}
+
+// SendLeave implements ip.MulticastGroupProtocol.
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+ _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
+ return err
+}
+
+// init sets up an mldState struct, and is required to be called before using
+// a new mldState.
+//
+// Must only be called once for the lifetime of mld.
+func (mld *mldState) init(ep *endpoint) {
+ mld.ep = ep
+ mld.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{
+ Rand: ep.protocol.stack.Rand(),
+ Clock: ep.protocol.stack.Clock(),
+ Protocol: mld,
+ MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
+ AllNodesAddress: header.IPv6AllNodesMulticastAddress,
+ })
+}
+
+// handleMulticastListenerQuery handles a query message.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) {
+ mld.genericMulticastProtocol.HandleQueryLocked(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay())
+}
+
+// handleMulticastListenerReport handles a report message.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) {
+ mld.genericMulticastProtocol.HandleReportLocked(mldHdr.MulticastAddress())
+}
+
+// joinGroup handles joining a new group and sending and scheduling the required
+// messages.
+//
+// If the group is already joined, returns tcpip.ErrDuplicateAddress.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) joinGroup(groupAddress tcpip.Address) {
+ mld.genericMulticastProtocol.JoinGroupLocked(groupAddress)
+}
+
+// isInGroup returns true if the specified group has been joined locally.
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool {
+ return mld.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress)
+}
+
+// leaveGroup handles removing the group from the membership map, cancels any
+// delay timers associated with that group, and sends the Done message, if
+// required.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
+ // LeaveGroup returns false only if the group was not joined.
+ if mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress) {
+ return nil
+ }
+
+ return tcpip.ErrBadLocalAddress
+}
+
+// softLeaveAll leaves all groups from the perspective of MLD, but remains
+// joined locally.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) softLeaveAll() {
+ mld.genericMulticastProtocol.MakeAllNonMemberLocked()
+}
+
+// initializeAll attemps to initialize the MLD state for each group that has
+// been joined locally.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) initializeAll() {
+ mld.genericMulticastProtocol.InitializeGroupsLocked()
+}
+
+// sendQueuedReports attempts to send any reports that are queued for sending.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) sendQueuedReports() {
+ mld.genericMulticastProtocol.SendQueuedReportsLocked()
+}
+
+// writePacket assembles and sends an MLD packet.
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) {
+ sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
+ var mldStat *tcpip.StatCounter
+ switch mldType {
+ case header.ICMPv6MulticastListenerReport:
+ mldStat = sentStats.MulticastListenerReport
+ case header.ICMPv6MulticastListenerDone:
+ mldStat = sentStats.MulticastListenerDone
+ default:
+ panic(fmt.Sprintf("unrecognized mld type = %d", mldType))
+ }
+
+ icmp := header.ICMPv6(buffer.NewView(header.ICMPv6HeaderSize + header.MLDMinimumSize))
+ icmp.SetType(mldType)
+ header.MLD(icmp.MessageBody()).SetMulticastAddress(groupAddress)
+ // As per RFC 2710 section 3,
+ //
+ // All MLD messages described in this document are sent with a link-local
+ // IPv6 Source Address, an IPv6 Hop Limit of 1, and an IPv6 Router Alert
+ // option in a Hop-by-Hop Options header.
+ //
+ // However, this would cause problems with Duplicate Address Detection with
+ // the first address as MLD snooping switches may not send multicast traffic
+ // that DAD depends on to the node performing DAD without the MLD report, as
+ // documented in RFC 4816:
+ //
+ // Note that when a node joins a multicast address, it typically sends a
+ // Multicast Listener Discovery (MLD) report message [RFC2710] [RFC3810]
+ // for the multicast address. In the case of Duplicate Address
+ // Detection, the MLD report message is required in order to inform MLD-
+ // snooping switches, rather than routers, to forward multicast packets.
+ // In the above description, the delay for joining the multicast address
+ // thus means delaying transmission of the corresponding MLD report
+ // message. Since the MLD specifications do not request a random delay
+ // to avoid race conditions, just delaying Neighbor Solicitation would
+ // cause congestion by the MLD report messages. The congestion would
+ // then prevent the MLD-snooping switches from working correctly and, as
+ // a result, prevent Duplicate Address Detection from working. The
+ // requirement to include the delay for the MLD report in this case
+ // avoids this scenario. [RFC3590] also talks about some interaction
+ // issues between Duplicate Address Detection and MLD, and specifies
+ // which source address should be used for the MLD report in this case.
+ //
+ // As per RFC 3590 section 4, we should still send out MLD reports with an
+ // unspecified source address if we do not have an assigned link-local
+ // address to use as the source address to ensure DAD works as expected on
+ // networks with MLD snooping switches:
+ //
+ // MLD Report and Done messages are sent with a link-local address as
+ // the IPv6 source address, if a valid address is available on the
+ // interface. If a valid link-local address is not available (e.g., one
+ // has not been configured), the message is sent with the unspecified
+ // address (::) as the IPv6 source address.
+ //
+ // Once a valid link-local address is available, a node SHOULD generate
+ // new MLD Report messages for all multicast addresses joined on the
+ // interface.
+ //
+ // Routers receiving an MLD Report or Done message with the unspecified
+ // address as the IPv6 source address MUST silently discard the packet
+ // without taking any action on the packets contents.
+ //
+ // Snooping switches MUST manage multicast forwarding state based on MLD
+ // Report and Done messages sent with the unspecified address as the
+ // IPv6 source address.
+ localAddress := mld.ep.getLinkLocalAddressRLocked()
+ if len(localAddress) == 0 {
+ localAddress = header.IPv6Any
+ }
+
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{}))
+
+ extensionHeaders := header.IPv6ExtHdrSerializer{
+ header.IPv6SerializableHopByHopExtHdr{
+ &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD},
+ },
+ }
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()) + extensionHeaders.Length(),
+ Data: buffer.View(icmp).ToVectorisedView(),
+ })
+
+ mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.MLDHopLimit,
+ }, extensionHeaders)
+ if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
+ sentStats.Dropped.Increment()
+ return false, err
+ }
+ mldStat.Increment()
+ return localAddress != header.IPv6Any, nil
+}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
new file mode 100644
index 000000000..e2778b656
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -0,0 +1,297 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+)
+
+var (
+ linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr)
+ globalAddrSNMC = header.SolicitedNodeAddr(globalAddr)
+)
+
+func validateMLDPacket(t *testing.T, p buffer.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) {
+ t.Helper()
+
+ checker.IPv6WithExtHdr(t, p,
+ checker.IPv6ExtHdr(
+ checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)),
+ ),
+ checker.SrcAddr(localAddress),
+ checker.DstAddr(remoteAddress),
+ // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
+ checker.TTL(1),
+ checker.MLD(mldType, header.MLDMinimumSize,
+ checker.MLDMaxRespDelay(0),
+ checker.MLDMulticastAddress(groupAddress),
+ ),
+ )
+}
+
+func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ MLD: ipv6.MLDOptions{
+ Enabled: true,
+ },
+ })},
+ })
+ e := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ // The stack will join an address's solicited node multicast address when
+ // an address is added. An MLD report message should be sent for the
+ // solicited-node group.
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
+ }
+
+ // The stack will leave an address's solicited node multicast address when
+ // an address is removed. An MLD done message should be sent for the
+ // solicited-node group.
+ if err := s.RemoveAddress(nicID, linkLocalAddr); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, linkLocalAddr, err)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a done message to be sent")
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC)
+ }
+}
+
+func TestSendQueuedMLDReports(t *testing.T) {
+ const (
+ nicID = 1
+ maxReports = 2
+ )
+
+ tests := []struct {
+ name string
+ dadTransmits uint8
+ retransmitTimer time.Duration
+ }{
+ {
+ name: "DAD Disabled",
+ dadTransmits: 0,
+ retransmitTimer: 0,
+ },
+ {
+ name: "DAD Enabled",
+ dadTransmits: 1,
+ retransmitTimer: time.Second,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ DupAddrDetectTransmits: test.dadTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ },
+ MLD: ipv6.MLDOptions{
+ Enabled: true,
+ },
+ })},
+ Clock: clock,
+ })
+
+ // Allow space for an extra packet so we can observe packets that were
+ // unexpectedly sent.
+ e := channel.New(maxReports+int(test.dadTransmits)+1 /* extra */, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ resolveDAD := func(addr, snmc tcpip.Address) {
+ clock.Advance(dadResolutionTime)
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected DAD packet")
+ } else {
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(snmc),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(addr),
+ checker.NDPNSOptions(nil),
+ ))
+ }
+ }
+
+ var reportCounter uint64
+ reportStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ var doneCounter uint64
+ doneStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+ if got := doneStat.Value(); got != doneCounter {
+ t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter)
+ }
+
+ // Joining a group without an assigned address should send an MLD report
+ // with the unspecified address.
+ if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err)
+ }
+ reportCounter++
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Errorf("expected MLD report for %s", globalMulticastAddr)
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Errorf("got unexpected packet = %#v", p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Adding a global address should not send reports for the already joined
+ // group since we should only send queued reports when a link-local
+ // addres sis assigned.
+ //
+ // Note, we will still expect to send a report for the global address's
+ // solicited node address from the unspecified address as per RFC 3590
+ // section 4.
+ if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err)
+ }
+ reportCounter++
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Errorf("expected MLD report for %s", globalAddrSNMC)
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalAddrSNMC, header.ICMPv6MulticastListenerReport, globalAddrSNMC)
+ }
+ if dadResolutionTime != 0 {
+ // Reports should not be sent when the address resolves.
+ resolveDAD(globalAddr, globalAddrSNMC)
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ }
+ // Leave the group since we don't care about the global address's
+ // solicited node multicast group membership.
+ if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalAddrSNMC); err != nil {
+ t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalAddrSNMC, err)
+ }
+ if got := doneStat.Value(); got != doneCounter {
+ t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Errorf("got unexpected packet = %#v", p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Adding a link-local address should send a report for its solicited node
+ // address and globalMulticastAddr.
+ if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err)
+ }
+ if dadResolutionTime != 0 {
+ reportCounter++
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Errorf("expected MLD report for %s", linkLocalAddrSNMC)
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
+ }
+ resolveDAD(linkLocalAddr, linkLocalAddrSNMC)
+ }
+
+ // We expect two batches of reports to be sent (1 batch when the
+ // link-local address is assigned, and another after the maximum
+ // unsolicited report interval.
+ for i := 0; i < 2; i++ {
+ // We expect reports to be sent (one for globalMulticastAddr and another
+ // for linkLocalAddrSNMC).
+ reportCounter += maxReports
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+
+ addrs := map[tcpip.Address]bool{
+ globalMulticastAddr: false,
+ linkLocalAddrSNMC: false,
+ }
+ for _ = range addrs {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected MLD report for %s and %s; addrs = %#v", globalMulticastAddr, linkLocalAddrSNMC, addrs)
+ }
+
+ addr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())).DestinationAddress()
+ if seen, ok := addrs[addr]; !ok {
+ t.Fatalf("got unexpected packet destined to %s", addr)
+ } else if seen {
+ t.Fatalf("got another packet destined to %s", addr)
+ }
+
+ addrs[addr] = true
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, addr, header.ICMPv6MulticastListenerReport, addr)
+
+ clock.Advance(ipv6.UnsolicitedReportIntervalMax)
+ }
+ }
+
+ // Should not send any more reports.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Errorf("got unexpected packet = %#v", p)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 40da011f8..d515eb622 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -20,6 +20,7 @@ import (
"math/rand"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -459,6 +460,9 @@ func (c *NDPConfigurations) validate() {
// ndpState is the per-interface NDP state.
type ndpState struct {
+ // Do not allow overwriting this state.
+ _ sync.NoCopy
+
// The IPv6 endpoint this ndpState is for.
ep *endpoint
@@ -471,17 +475,8 @@ type ndpState struct {
// The default routers discovered through Router Advertisements.
defaultRouters map[tcpip.Address]defaultRouterState
- rtrSolicit struct {
- // The timer used to send the next router solicitation message.
- timer tcpip.Timer
-
- // Used to let the Router Solicitation timer know that it has been stopped.
- //
- // Must only be read from or written to while protected by the lock of
- // the IPv6 endpoint this ndpState is associated with. MUST be set when the
- // timer is set.
- done *bool
- }
+ // The job used to send the next router solicitation message.
+ rtrSolicitJob *tcpip.Job
// The on-link prefixes discovered through Router Advertisements' Prefix
// Information option.
@@ -507,7 +502,7 @@ type ndpState struct {
// to the DAD goroutine that DAD should stop.
type dadState struct {
// The DAD timer to send the next NS message, or resolve the address.
- timer tcpip.Timer
+ job *tcpip.Job
// Used to let the DAD timer know that it has been stopped.
//
@@ -648,96 +643,73 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// Consider DAD to have resolved even if no DAD messages were actually
// transmitted.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil)
}
+ ndp.ep.onAddressAssignedLocked(addr)
return nil
}
- var done bool
- var timer tcpip.Timer
- // We initially start a timer to fire immediately because some of the DAD work
- // cannot be done while holding the IPv6 endpoint's lock. This is effectively
- // the same as starting a goroutine but we use a timer that fires immediately
- // so we can reset it for the next DAD iteration.
- timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() {
- ndp.ep.mu.Lock()
- defer ndp.ep.mu.Unlock()
-
- if done {
- // If we reach this point, it means that the DAD timer fired after
- // another goroutine already obtained the IPv6 endpoint lock and stopped
- // DAD before this function obtained the NIC lock. Simply return here and
- // do nothing further.
- return
- }
+ state := dadState{
+ job: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ state, ok := ndp.dad[addr]
+ if !ok {
+ panic(fmt.Sprintf("ndpdad: DAD timer fired but missing state for %s on NIC(%d)", addr, ndp.ep.nic.ID()))
+ }
- if addressEndpoint.GetKind() != stack.PermanentTentative {
- // The endpoint should still be marked as tentative since we are still
- // performing DAD on it.
- panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
- }
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ // The endpoint should still be marked as tentative since we are still
+ // performing DAD on it.
+ panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
+ }
- dadDone := remaining == 0
-
- var err *tcpip.Error
- if !dadDone {
- // Use the unspecified address as the source address when performing DAD.
- addressEndpoint := ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
-
- // Do not hold the lock when sending packets which may be a long running
- // task or may block link address resolution. We know this is safe
- // because immediately after obtaining the lock again, we check if DAD
- // has been stopped before doing any work with the IPv6 endpoint. Note,
- // DAD would be stopped if the IPv6 endpoint was disabled or closed, or if
- // the address was removed.
- ndp.ep.mu.Unlock()
- err = ndp.sendDADPacket(addr, addressEndpoint)
- ndp.ep.mu.Lock()
- addressEndpoint.DecRef()
- }
+ dadDone := remaining == 0
- if done {
- // If we reach this point, it means that DAD was stopped after we released
- // the IPv6 endpoint's read lock and before we obtained the write lock.
- return
- }
+ var err *tcpip.Error
+ if !dadDone {
+ err = ndp.sendDADPacket(addr, addressEndpoint)
+ }
- if dadDone {
- // DAD has resolved.
- addressEndpoint.SetKind(stack.Permanent)
- } else if err == nil {
- // DAD is not done and we had no errors when sending the last NDP NS,
- // schedule the next DAD timer.
- remaining--
- timer.Reset(ndp.configs.RetransmitTimer)
- return
- }
+ if dadDone {
+ // DAD has resolved.
+ addressEndpoint.SetKind(stack.Permanent)
+ } else if err == nil {
+ // DAD is not done and we had no errors when sending the last NDP NS,
+ // schedule the next DAD timer.
+ remaining--
+ state.job.Schedule(ndp.configs.RetransmitTimer)
+ return
+ }
- // At this point we know that either DAD is done or we hit an error sending
- // the last NDP NS. Either way, clean up addr's DAD state and let the
- // integrator know DAD has completed.
- delete(ndp.dad, addr)
+ // At this point we know that either DAD is done or we hit an error
+ // sending the last NDP NS. Either way, clean up addr's DAD state and let
+ // the integrator know DAD has completed.
+ delete(ndp.dad, addr)
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
- }
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
+ }
- // If DAD resolved for a stable SLAAC address, attempt generation of a
- // temporary SLAAC address.
- if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
- // Reset the generation attempts counter as we are starting the generation
- // of a new address for the SLAAC prefix.
- ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
- }
- })
+ if dadDone {
+ if addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
+ // Reset the generation attempts counter as we are starting the
+ // generation of a new address for the SLAAC prefix.
+ ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
+ }
- ndp.dad[addr] = dadState{
- timer: timer,
- done: &done,
+ ndp.ep.onAddressAssignedLocked(addr)
+ }
+ }),
}
+ // We initially start a timer to fire immediately because some of the DAD work
+ // cannot be done while holding the IPv6 endpoint's lock. This is effectively
+ // the same as starting a goroutine but we use a timer that fires immediately
+ // so we can reset it for the next DAD iteration.
+ state.job.Schedule(0)
+ ndp.dad[addr] = state
+
return nil
}
@@ -745,55 +717,31 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// addr.
//
// addr must be a tentative IPv6 address on ndp's IPv6 endpoint.
-//
-// The IPv6 endpoint that ndp belongs to MUST NOT be locked.
func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- return err
- }
- defer r.Release()
-
- // Route should resolve immediately since snmc is a multicast address so a
- // remote link address can be calculated without a resolution process.
- if c, err := r.Resolve(nil); err != nil {
- // Do not consider the NIC being unknown or disabled as a fatal error.
- // Since this method is required to be called when the IPv6 endpoint is not
- // locked, the NIC could have been disabled or removed by another goroutine.
- if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState {
- return err
- }
-
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err))
- } else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID()))
- }
-
- icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
- icmpData.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmpData.NDPPayload())
+ icmp := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
+ icmp.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(addr)
- icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, snmc, buffer.VectorisedView{}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: buffer.View(icmpData).ToVectorisedView(),
+ ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()),
+ Data: buffer.View(icmp).ToVectorisedView(),
})
- sent := r.Stats().ICMP.V6PacketsSent
- if err := r.WritePacket(nil,
- stack.NetworkHeaderParams{
- Protocol: header.ICMPv6ProtocolNumber,
- TTL: header.NDPHopLimit,
- }, pkt,
- ); err != nil {
+ sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
+ ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ }, nil /* extensionHeaders */)
+
+ if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil {
sent.Dropped.Increment()
return err
}
sent.NeighborSolicit.Increment()
-
return nil
}
@@ -812,18 +760,11 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
return
}
- if dad.timer != nil {
- dad.timer.Stop()
- dad.timer = nil
-
- *dad.done = true
- dad.done = nil
- }
-
+ dad.job.Cancel()
delete(ndp.dad, addr)
// Let the integrator know DAD did not resolve.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil)
}
}
@@ -846,7 +787,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we
// only inform the dispatcher on configuration changes. We do nothing else
// with the information.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
var configuration DHCPv6ConfigurationFromNDPRA
switch {
case ra.ManagedAddrConfFlag():
@@ -903,20 +844,20 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() {
switch opt := opt.(type) {
case header.NDPRecursiveDNSServer:
- if ndp.ep.protocol.ndpDisp == nil {
+ if ndp.ep.protocol.options.NDPDisp == nil {
continue
}
addrs, _ := opt.Addresses()
- ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime())
+ ndp.ep.protocol.options.NDPDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime())
case header.NDPDNSSearchList:
- if ndp.ep.protocol.ndpDisp == nil {
+ if ndp.ep.protocol.options.NDPDisp == nil {
continue
}
domainNames, _ := opt.DomainNames()
- ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime())
+ ndp.ep.protocol.options.NDPDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime())
case header.NDPPrefixInformation:
prefix := opt.Subnet()
@@ -964,7 +905,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
delete(ndp.defaultRouters, ip)
// Let the integrator know a discovered default router is invalidated.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip)
}
}
@@ -976,7 +917,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
- ndpDisp := ndp.ep.protocol.ndpDisp
+ ndpDisp := ndp.ep.protocol.options.NDPDisp
if ndpDisp == nil {
return
}
@@ -1006,7 +947,7 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) {
- ndpDisp := ndp.ep.protocol.ndpDisp
+ ndpDisp := ndp.ep.protocol.options.NDPDisp
if ndpDisp == nil {
return
}
@@ -1047,7 +988,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
delete(ndp.onLinkPrefixes, prefix)
// Let the integrator know a discovered on-link prefix is invalidated.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix)
}
}
@@ -1225,7 +1166,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, configType stack.AddressConfigType, deprecated bool) stack.AddressEndpoint {
// Inform the integrator that we have a new SLAAC address.
- ndpDisp := ndp.ep.protocol.ndpDisp
+ ndpDisp := ndp.ep.protocol.options.NDPDisp
if ndpDisp == nil {
return nil
}
@@ -1272,7 +1213,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
}
dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures
- if oIID := ndp.ep.protocol.opaqueIIDOpts; oIID.NICNameFromID != nil {
+ if oIID := ndp.ep.protocol.options.OpaqueIIDOpts; oIID.NICNameFromID != nil {
addrBytes = header.AppendOpaqueInterfaceIdentifier(
addrBytes[:header.IIDOffsetInIPv6Address],
prefix,
@@ -1676,7 +1617,7 @@ func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint
}
addressEndpoint.SetDeprecated(true)
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix())
}
}
@@ -1701,7 +1642,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefi
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) {
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
@@ -1761,7 +1702,7 @@ func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLA
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) {
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
@@ -1859,7 +1800,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) startSolicitingRouters() {
- if ndp.rtrSolicit.timer != nil {
+ if ndp.rtrSolicitJob != nil {
// We are already soliciting routers.
return
}
@@ -1876,56 +1817,14 @@ func (ndp *ndpState) startSolicitingRouters() {
delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
}
- var done bool
- ndp.rtrSolicit.done = &done
- ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() {
- ndp.ep.mu.Lock()
- if done {
- // If we reach this point, it means that the RS timer fired after another
- // goroutine already obtained the IPv6 endpoint lock and stopped
- // solicitations. Simply return here and do nothing further.
- ndp.ep.mu.Unlock()
- return
- }
-
+ ndp.rtrSolicitJob = ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
// As per RFC 4861 section 4.1, the source of the RS is an address assigned
// to the sending interface, or the unspecified address if no address is
// assigned to the sending interface.
- addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false)
- if addressEndpoint == nil {
- // Incase this ends up creating a new temporary address, we need to hold
- // onto the endpoint until a route is obtained. If we decrement the
- // reference count before obtaing a route, the address's resources would
- // be released and attempting to obtain a route after would fail. Once a
- // route is obtainted, it is safe to decrement the reference count since
- // obtaining a route increments the address's reference count.
- addressEndpoint = ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
- }
- ndp.ep.mu.Unlock()
-
- localAddr := addressEndpoint.AddressWithPrefix().Address
- r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */)
- addressEndpoint.DecRef()
- if err != nil {
- return
- }
- defer r.Release()
-
- // Route should resolve immediately since
- // header.IPv6AllRoutersMulticastAddress is a multicast address so a
- // remote link address can be calculated without a resolution process.
- if c, err := r.Resolve(nil); err != nil {
- // Do not consider the NIC being unknown or disabled as a fatal error.
- // Since this method is required to be called when the IPv6 endpoint is
- // not locked, the IPv6 endpoint could have been disabled or removed by
- // another goroutine.
- if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState {
- return
- }
-
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err))
- } else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID()))
+ localAddr := header.IPv6Any
+ if addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil {
+ localAddr = addressEndpoint.AddressWithPrefix().Address
+ addressEndpoint.DecRef()
}
// As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
@@ -1936,30 +1835,31 @@ func (ndp *ndpState) startSolicitingRouters() {
// TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
// LinkEndpoint.LinkAddress) before reaching this point.
var optsSerializer header.NDPOptionsSerializer
- if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) {
+ linkAddress := ndp.ep.nic.LinkAddress()
+ if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(linkAddress) {
optsSerializer = header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress),
+ header.NDPSourceLinkLayerAddressOption(linkAddress),
}
}
payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length())
icmpData := header.ICMPv6(buffer.NewView(payloadSize))
icmpData.SetType(header.ICMPv6RouterSolicit)
- rs := header.NDPRouterSolicit(icmpData.NDPPayload())
+ rs := header.NDPRouterSolicit(icmpData.MessageBody())
rs.Options().Serialize(optsSerializer)
- icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, localAddr, header.IPv6AllRoutersMulticastAddress, buffer.VectorisedView{}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()),
Data: buffer.View(icmpData).ToVectorisedView(),
})
- sent := r.Stats().ICMP.V6PacketsSent
- if err := r.WritePacket(nil,
- stack.NetworkHeaderParams{
- Protocol: header.ICMPv6ProtocolNumber,
- TTL: header.NDPHopLimit,
- }, pkt,
- ); err != nil {
+ sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
+ ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ }, nil /* extensionHeaders */)
+
+ if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
sent.Dropped.Increment()
log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err)
// Don't send any more messages if we had an error.
@@ -1969,21 +1869,12 @@ func (ndp *ndpState) startSolicitingRouters() {
remaining--
}
- ndp.ep.mu.Lock()
- if done || remaining == 0 {
- ndp.rtrSolicit.timer = nil
- ndp.rtrSolicit.done = nil
- } else if ndp.rtrSolicit.timer != nil {
- // Note, we need to explicitly check to make sure that
- // the timer field is not nil because if it was nil but
- // we still reached this point, then we know the IPv6 endpoint
- // was requested to stop soliciting routers so we don't
- // need to send the next Router Solicitation message.
- ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval)
+ if remaining != 0 {
+ ndp.rtrSolicitJob.Schedule(ndp.configs.RtrSolicitationInterval)
}
- ndp.ep.mu.Unlock()
})
+ ndp.rtrSolicitJob.Schedule(delay)
}
// stopSolicitingRouters stops soliciting routers. If routers are not currently
@@ -1991,22 +1882,28 @@ func (ndp *ndpState) startSolicitingRouters() {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) stopSolicitingRouters() {
- if ndp.rtrSolicit.timer == nil {
+ if ndp.rtrSolicitJob == nil {
// Nothing to do.
return
}
- *ndp.rtrSolicit.done = true
- ndp.rtrSolicit.timer.Stop()
- ndp.rtrSolicit.timer = nil
- ndp.rtrSolicit.done = nil
+ ndp.rtrSolicitJob.Cancel()
+ ndp.rtrSolicitJob = nil
}
-// initializeTempAddrState initializes state related to temporary SLAAC
-// addresses.
-func (ndp *ndpState) initializeTempAddrState() {
- header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID())
+func (ndp *ndpState) init(ep *endpoint) {
+ if ndp.dad != nil {
+ panic("attempted to initialize NDP state twice")
+ }
+
+ ndp.ep = ep
+ ndp.configs = ep.protocol.options.NDPConfigs
+ ndp.dad = make(map[tcpip.Address]dadState)
+ ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState)
+ ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState)
+ ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState)
+ header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID())
if MaxDesyncFactor != 0 {
ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 37e8b1083..05a0d95b2 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -205,7 +205,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(lladdr0)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -213,14 +213,14 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
@@ -311,7 +311,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(lladdr0)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -319,23 +319,23 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
- })
+ }))
neighbors, err := s.Neighbors(nicID)
if err != nil {
@@ -591,7 +591,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(nicAddr)
opts := ns.Options()
opts.Serialize(test.nsOpts)
@@ -599,14 +599,14 @@ func TestNeighorSolicitationResponse(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: test.nsSrc,
- DstAddr: test.nsDst,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: test.nsSrc,
+ DstAddr: test.nsDst,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
@@ -650,8 +650,8 @@ func TestNeighorSolicitationResponse(t *testing.T) {
if p.Route.RemoteAddress != respNSDst {
t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst)
}
- if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(respNSDst); got != want {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
@@ -672,7 +672,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(test.nsSrc)
@@ -681,11 +681,11 @@ func TestNeighorSolicitationResponse(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: test.nsSrc,
- DstAddr: nicAddr,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: test.nsSrc,
+ DstAddr: nicAddr,
})
e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -706,8 +706,8 @@ func TestNeighorSolicitationResponse(t *testing.T) {
if p.Route.RemoteAddress != test.naDst {
t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst)
}
- if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
+ if got := p.Route.RemoteLinkAddress(); got != test.naDstLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, test.naDstLinkAddr)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
@@ -777,7 +777,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns := header.NDPNeighborAdvert(pkt.MessageBody())
ns.SetTargetAddress(lladdr1)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -785,14 +785,14 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
@@ -890,7 +890,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns := header.NDPNeighborAdvert(pkt.MessageBody())
ns.SetTargetAddress(lladdr1)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -898,23 +898,23 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
- })
+ }))
neighbors, err := s.Neighbors(nicID)
if err != nil {
@@ -979,29 +979,25 @@ func TestNDPValidation(t *testing.T) {
}
handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) {
- nextHdr := uint8(header.ICMPv6ProtocolNumber)
- var extensions buffer.View
+ var extHdrs header.IPv6ExtHdrSerializer
if atomicFragment {
- extensions = buffer.NewView(header.IPv6FragmentExtHdrLength)
- extensions[0] = nextHdr
- nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
+ extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{})
}
+ extHdrsLen := extHdrs.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions),
+ ReserveHeaderBytes: header.IPv6MinimumSize + extHdrsLen,
Data: payload.ToVectorisedView(),
})
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions)))
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(payload) + len(extensions)),
- NextHeader: nextHdr,
- HopLimit: hopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(len(payload) + extHdrsLen),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: hopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ ExtensionHeaders: extHdrs,
})
- if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
- t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
- }
ep.HandlePacket(pkt)
}
@@ -1122,7 +1118,7 @@ func TestNDPValidation(t *testing.T) {
s.SetForwarding(ProtocolNumber, true)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
routerOnly := stats.RouterOnlyPacketsDroppedByHost
typStat := typ.statCounter(stats)
@@ -1346,19 +1342,19 @@ func TestRouterAdvertValidation(t *testing.T) {
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
pkt.SetType(header.ICMPv6RouterAdvert)
pkt.SetCode(test.code)
- copy(pkt.NDPPayload(), test.ndpPayload)
+ copy(pkt.MessageBody(), test.ndpPayload)
payloadLength := hdr.UsedLength()
pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: test.hopLimit,
- SrcAddr: test.src,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: test.hopLimit,
+ SrcAddr: test.src,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
rxRA := stats.RouterAdvert
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
new file mode 100644
index 000000000..05d98a0a5
--- /dev/null
+++ b/pkg/tcpip/network/multicast_group_test.go
@@ -0,0 +1,1261 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip_test
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+
+ ipv4Addr = tcpip.Address("\x0a\x00\x00\x01")
+ ipv6Addr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+
+ ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03")
+ ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04")
+ ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05")
+ ipv6MulticastAddr1 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ ipv6MulticastAddr2 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04")
+ ipv6MulticastAddr3 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05")
+
+ igmpMembershipQuery = uint8(header.IGMPMembershipQuery)
+ igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport)
+ igmpv2MembershipReport = uint8(header.IGMPv2MembershipReport)
+ igmpLeaveGroup = uint8(header.IGMPLeaveGroup)
+ mldQuery = uint8(header.ICMPv6MulticastListenerQuery)
+ mldReport = uint8(header.ICMPv6MulticastListenerReport)
+ mldDone = uint8(header.ICMPv6MulticastListenerDone)
+
+ maxUnsolicitedReports = 2
+)
+
+var (
+ // unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the
+ // NIC will wait before sending an unsolicited report after joining a
+ // multicast group, in deciseconds.
+ unsolicitedIGMPReportIntervalMaxTenthSec = func() uint8 {
+ const decisecond = time.Second / 10
+ if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 {
+ panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax))
+ }
+ return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond)
+ }()
+
+ ipv6AddrSNMC = header.SolicitedNodeAddr(ipv6Addr)
+)
+
+// validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet
+// sent to the provided address with the passed fields set.
+func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, mldType uint8, maxRespTime byte, groupAddress tcpip.Address) {
+ t.Helper()
+
+ payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ checker.IPv6WithExtHdr(t, payload,
+ checker.IPv6ExtHdr(
+ checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)),
+ ),
+ checker.SrcAddr(ipv6Addr),
+ checker.DstAddr(remoteAddress),
+ // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
+ checker.TTL(1),
+ checker.MLD(header.ICMPv6Type(mldType), header.MLDMinimumSize,
+ checker.MLDMaxRespDelay(time.Duration(maxRespTime)*time.Millisecond),
+ checker.MLDMulticastAddress(groupAddress),
+ ),
+ )
+}
+
+// validateIGMPPacket checks that a passed PacketInfo is an IPv4 IGMP packet
+// sent to the provided address with the passed fields set.
+func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType uint8, maxRespTime byte, groupAddress tcpip.Address) {
+ t.Helper()
+
+ payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ checker.IPv4(t, payload,
+ checker.SrcAddr(ipv4Addr),
+ checker.DstAddr(remoteAddress),
+ // TTL for an IGMP message must be 1 as per RFC 2236 section 2.
+ checker.TTL(1),
+ checker.IPv4RouterAlert(),
+ checker.IGMP(
+ checker.IGMPType(header.IGMPType(igmpType)),
+ checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)),
+ checker.IGMPGroupAddress(groupAddress),
+ ),
+ )
+}
+
+func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
+ t.Helper()
+
+ e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr)
+ s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e)
+ return e, s, clock
+}
+
+func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) {
+ t.Helper()
+
+ igmpEnabled := v4 && mgpEnabled
+ mldEnabled := !v4 && mgpEnabled
+
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocolWithOptions(ipv4.Options{
+ IGMP: ipv4.IGMPOptions{
+ Enabled: igmpEnabled,
+ },
+ }),
+ ipv6.NewProtocolWithOptions(ipv6.Options{
+ MLD: ipv6.MLDOptions{
+ Enabled: mldEnabled,
+ },
+ }),
+ },
+ Clock: clock,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4Addr, err)
+ }
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6Addr, err)
+ }
+
+ return s, clock
+}
+
+// checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join
+// when it is created with an IPv6 address.
+//
+// To not interfere with tests, checkInitialIPv6Groups will leave the added
+// address's solicited node multicast group so that the tests can all assume
+// the NIC has not joined any IPv6 groups.
+func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) {
+ t.Helper()
+
+ stats := s.Stats().ICMP.V6.PacketsSent
+
+ reportCounter++
+ if got := stats.MulticastListenerReport.Value(); got != reportCounter {
+ t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC)
+ }
+
+ // Leave the group to not affect the tests. This is fine since we are not
+ // testing DAD or the solicited node address specifically.
+ if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil {
+ t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err)
+ }
+ leaveCounter++
+ if got := stats.MulticastListenerDone.Value(); got != leaveCounter {
+ t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+
+ return reportCounter, leaveCounter
+}
+
+// createAndInjectIGMPPacket creates and injects an IGMP packet with the
+// specified fields.
+//
+// Note, the router alert option is not included in this packet.
+//
+// TODO(b/162198658): set the router alert option.
+func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime byte, groupAddress tcpip.Address) {
+ buf := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize)
+
+ ip := header.IPv4(buf)
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(len(buf)),
+ TTL: header.IGMPTTL,
+ Protocol: uint8(header.IGMPProtocolNumber),
+ SrcAddr: header.IPv4Any,
+ DstAddr: header.IPv4AllSystems,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ igmp := header.IGMP(buf[header.IPv4MinimumSize:])
+ igmp.SetType(header.IGMPType(igmpType))
+ igmp.SetMaxRespTime(maxRespTime)
+ igmp.SetGroupAddress(groupAddress)
+ igmp.SetChecksum(header.IGMPCalculateChecksum(igmp))
+
+ e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+}
+
+// createAndInjectMLDPacket creates and injects an MLD packet with the
+// specified fields.
+//
+// Note, the router alert option is not included in this packet.
+//
+// TODO(b/162198658): set the router alert option.
+func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay byte, groupAddress tcpip.Address) {
+ icmpSize := header.ICMPv6HeaderSize + header.MLDMinimumSize
+ buf := buffer.NewView(header.IPv6MinimumSize + icmpSize)
+
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(icmpSize),
+ HopLimit: header.MLDHopLimit,
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ SrcAddr: header.IPv4Any,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
+
+ icmp := header.ICMPv6(buf[header.IPv6MinimumSize:])
+ icmp.SetType(header.ICMPv6Type(mldType))
+ mld := header.MLD(icmp.MessageBody())
+ mld.SetMaximumResponseDelay(uint16(maxRespDelay))
+ mld.SetMulticastAddress(groupAddress)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+
+ e.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+}
+
+// TestMGPDisabled tests that the multicast group protocol is not enabled by
+// default.
+func TestMGPDisabled(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
+ rxQuery func(*channel.Endpoint)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ rxQuery: func(e *channel.Endpoint) {
+ createAndInjectIGMPPacket(e, igmpMembershipQuery, unsolicitedIGMPReportIntervalMaxTenthSec, header.IPv4Any)
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
+ },
+ rxQuery: func(e *channel.Endpoint) {
+ createAndInjectMLDPacket(e, mldQuery, 0, header.IPv6Any)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */)
+
+ // This NIC may join multicast groups when it is enabled but since MGP is
+ // disabled, no reports should be sent.
+ sentReportStat := test.sentReportStat(s)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p.Pkt)
+ }
+
+ // Test joining a specific group explicitly and verify that no reports are
+ // sent.
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %#v", p.Pkt)
+ }
+
+ // Inject a general query message. This should only trigger a report to be
+ // sent if the MGP was enabled.
+ test.rxQuery(e)
+ if got := test.receivedQueryStat(s).Value(); got != 1 {
+ t.Fatalf("got receivedQueryStat(_).Value() = %d, want = 1", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt)
+ }
+ })
+ }
+}
+
+func TestMGPReceiveCounters(t *testing.T) {
+ tests := []struct {
+ name string
+ headerType uint8
+ maxRespTime byte
+ groupAddress tcpip.Address
+ statCounter func(*stack.Stack) *tcpip.StatCounter
+ rxMGPkt func(*channel.Endpoint, byte, byte, tcpip.Address)
+ }{
+ {
+ name: "IGMP Membership Query",
+ headerType: igmpMembershipQuery,
+ maxRespTime: unsolicitedIGMPReportIntervalMaxTenthSec,
+ groupAddress: header.IPv4Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "IGMPv1 Membership Report",
+ headerType: igmpv1MembershipReport,
+ maxRespTime: 0,
+ groupAddress: header.IPv4AllSystems,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.V1MembershipReport
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "IGMPv2 Membership Report",
+ headerType: igmpv2MembershipReport,
+ maxRespTime: 0,
+ groupAddress: header.IPv4AllSystems,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.V2MembershipReport
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "IGMP Leave Group",
+ headerType: igmpLeaveGroup,
+ maxRespTime: 0,
+ groupAddress: header.IPv4AllRoutersGroup,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.LeaveGroup
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "MLD Query",
+ headerType: mldQuery,
+ maxRespTime: 0,
+ groupAddress: header.IPv6Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
+ },
+ rxMGPkt: createAndInjectMLDPacket,
+ },
+ {
+ name: "MLD Report",
+ headerType: mldReport,
+ maxRespTime: 0,
+ groupAddress: header.IPv6Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerReport
+ },
+ rxMGPkt: createAndInjectMLDPacket,
+ },
+ {
+ name: "MLD Done",
+ headerType: mldDone,
+ maxRespTime: 0,
+ groupAddress: header.IPv6Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerDone
+ },
+ rxMGPkt: createAndInjectMLDPacket,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, _ := createStack(t, len(test.groupAddress) == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */)
+
+ test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress)
+ if got := test.statCounter(s).Value(); got != 1 {
+ t.Fatalf("got %s received = %d, want = 1", test.name, got)
+ }
+ })
+ }
+}
+
+// TestMGPJoinGroup tests that when explicitly joining a multicast group, the
+// stack schedules and sends correct Membership Reports.
+func TestMGPJoinGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ maxUnsolicitedResponseDelay time.Duration
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
+ validateReport func(*testing.T, channel.PacketInfo)
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+ },
+ checkInitialGroups: checkInitialIPv6Groups,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, _ = test.checkInitialGroups(t, e, s, clock)
+ }
+
+ // Test joining a specific address explicitly and verify a Report is sent
+ // immediately.
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ reportCounter++
+ sentReportStat := test.sentReportStat(s)
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Verify the second report is sent by the maximum unsolicited response
+ // interval.
+ p, ok := e.Read()
+ if ok {
+ t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt)
+ }
+ clock.Advance(test.maxUnsolicitedResponseDelay)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
+
+// TestMGPLeaveGroup tests that when leaving a previously joined multicast
+// group the stack sends a leave/done message.
+func TestMGPLeaveGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
+ validateReport func(*testing.T, channel.PacketInfo)
+ validateLeave func(*testing.T, channel.PacketInfo)
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.LeaveGroup
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1)
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1)
+ },
+ checkInitialGroups: checkInitialIPv6Groups,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ var leaveCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
+ }
+
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ reportCounter++
+ if got := test.sentReportStat(s).Value(); got != reportCounter {
+ t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leaving the group should trigger an leave/done message to be sent.
+ if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
+ }
+ leaveCounter++
+ if got := test.sentLeaveStat(s).Value(); got != leaveCounter {
+ t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a leave message to be sent")
+ } else {
+ test.validateLeave(t, p)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
+
+// TestMGPQueryMessages tests that a report is sent in response to query
+// messages.
+func TestMGPQueryMessages(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ maxUnsolicitedResponseDelay time.Duration
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
+ rxQuery func(*channel.Endpoint, uint8, tcpip.Address)
+ validateReport func(*testing.T, channel.PacketInfo)
+ maxRespTimeToDuration func(uint8) time.Duration
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
+ createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ maxRespTimeToDuration: header.DecisecondToDuration,
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
+ },
+ rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
+ createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+ },
+ maxRespTimeToDuration: func(d uint8) time.Duration {
+ return time.Duration(d) * time.Millisecond
+ },
+ checkInitialGroups: checkInitialIPv6Groups,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ subTests := []struct {
+ name string
+ multicastAddr tcpip.Address
+ expectReport bool
+ }{
+ {
+ name: "Unspecified",
+ multicastAddr: tcpip.Address(strings.Repeat("\x00", len(test.multicastAddr))),
+ expectReport: true,
+ },
+ {
+ name: "Specified",
+ multicastAddr: test.multicastAddr,
+ expectReport: true,
+ },
+ {
+ name: "Specified other address",
+ multicastAddr: func() tcpip.Address {
+ addrBytes := []byte(test.multicastAddr)
+ addrBytes[len(addrBytes)-1]++
+ return tcpip.Address(addrBytes)
+ }(),
+ expectReport: false,
+ },
+ }
+
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, _ = test.checkInitialGroups(t, e, s, clock)
+ }
+
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ sentReportStat := test.sentReportStat(s)
+ for i := 0; i < maxUnsolicitedReports; i++ {
+ sentReportStat := test.sentReportStat(s)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("(i=%d) got sentReportStat.Value() = %d, want = %d", i, got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatalf("expected %d-th report message to be sent", i)
+ } else {
+ test.validateReport(t, p)
+ }
+ clock.Advance(test.maxUnsolicitedResponseDelay)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Should not send any more packets until a query.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+
+ // Receive a query message which should trigger a report to be sent at
+ // some time before the maximum response time if the report is
+ // targeted at the host.
+ const maxRespTime = 100
+ test.rxQuery(e, maxRespTime, subTest.multicastAddr)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p.Pkt)
+ }
+
+ if subTest.expectReport {
+ clock.Advance(test.maxRespTimeToDuration(maxRespTime))
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+ })
+ }
+}
+
+// TestMGPQueryMessages tests that no further reports or leave/done messages
+// are sent after receiving a report.
+func TestMGPReportMessages(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
+ rxReport func(*channel.Endpoint)
+ validateReport func(*testing.T, channel.PacketInfo)
+ maxRespTimeToDuration func(uint8) time.Duration
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.LeaveGroup
+ },
+ rxReport: func(e *channel.Endpoint) {
+ createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ maxRespTimeToDuration: header.DecisecondToDuration,
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+ },
+ rxReport: func(e *channel.Endpoint) {
+ createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+ },
+ maxRespTimeToDuration: func(d uint8) time.Duration {
+ return time.Duration(d) * time.Millisecond
+ },
+ checkInitialGroups: checkInitialIPv6Groups,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ var leaveCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
+ }
+
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ sentReportStat := test.sentReportStat(s)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Receiving a report for a group we joined should cancel any further
+ // reports.
+ test.rxReport(e)
+ clock.Advance(time.Hour)
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Errorf("sent unexpected packet = %#v", p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leaving a group after getting a report should not send a leave/done
+ // message.
+ if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
+ }
+ clock.Advance(time.Hour)
+ if got := test.sentLeaveStat(s).Value(); got != leaveCounter {
+ t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
+
+func TestMGPWithNICLifecycle(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddrs []tcpip.Address
+ finalMulticastAddr tcpip.Address
+ maxUnsolicitedResponseDelay time.Duration
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
+ validateReport func(*testing.T, channel.PacketInfo, tcpip.Address)
+ validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address)
+ getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddrs: []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2},
+ finalMulticastAddr: ipv4MulticastAddr3,
+ maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.LeaveGroup
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, addr)
+ },
+ getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
+ t.Helper()
+
+ ipv4 := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ if got := tcpip.TransportProtocolNumber(ipv4.Protocol()); got != header.IGMPProtocolNumber {
+ t.Fatalf("got ipv4.Protocol() = %d, want = %d", got, header.IGMPProtocolNumber)
+ }
+ addr := header.IGMP(ipv4.Payload()).GroupAddress()
+ s, ok := seen[addr]
+ if !ok {
+ t.Fatalf("unexpectedly got a packet for group %s", addr)
+ }
+ if s {
+ t.Fatalf("already saw packet for group %s", addr)
+ }
+ seen[addr] = true
+ return addr
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddrs: []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2},
+ finalMulticastAddr: ipv6MulticastAddr3,
+ maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateMLDPacket(t, p, addr, mldReport, 0, addr)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr)
+ },
+ getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
+ t.Helper()
+
+ ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
+
+ ipv6HeaderIter := header.MakeIPv6PayloadIterator(
+ header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()),
+ buffer.View(ipv6.Payload()).ToVectorisedView(),
+ )
+
+ var transport header.IPv6RawPayloadHeader
+ for {
+ h, done, err := ipv6HeaderIter.Next()
+ if err != nil {
+ t.Fatalf("ipv6HeaderIter.Next(): %s", err)
+ }
+ if done {
+ t.Fatalf("ipv6HeaderIter.Next() = (%T, %t, _), want = (_, false, _)", h, done)
+ }
+ if t, ok := h.(header.IPv6RawPayloadHeader); ok {
+ transport = t
+ break
+ }
+ }
+
+ if got := tcpip.TransportProtocolNumber(transport.Identifier); got != header.ICMPv6ProtocolNumber {
+ t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber)
+ }
+ icmpv6 := header.ICMPv6(transport.Buf.ToView())
+ if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone {
+ t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone)
+ }
+ addr := header.MLD(icmpv6.MessageBody()).MulticastAddress()
+ s, ok := seen[addr]
+ if !ok {
+ t.Fatalf("unexpectedly got a packet for group %s", addr)
+ }
+ if s {
+ t.Fatalf("already saw packet for group %s", addr)
+ }
+ seen[addr] = true
+ return addr
+ },
+ checkInitialGroups: checkInitialIPv6Groups,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ var leaveCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
+ }
+
+ sentReportStat := test.sentReportStat(s)
+ for _, a := range test.multicastAddrs {
+ if err := s.JoinGroup(test.protoNum, nicID, a); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err)
+ }
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatalf("expected a report message to be sent for %s", a)
+ } else {
+ test.validateReport(t, p, a)
+ }
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leave messages should be sent for the joined groups when the NIC is
+ // disabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ sentLeaveStat := test.sentLeaveStat(s)
+ leaveCounter += uint64(len(test.multicastAddrs))
+ if got := sentLeaveStat.Value(); got != leaveCounter {
+ t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
+ }
+ {
+ seen := make(map[tcpip.Address]bool)
+ for _, a := range test.multicastAddrs {
+ seen[a] = false
+ }
+
+ for i, _ := range test.multicastAddrs {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected (%d-th) leave message to be sent", i)
+ }
+
+ test.validateLeave(t, p, test.getAndCheckGroupAddress(t, seen, p))
+ }
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Reports should be sent for the joined groups when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("EnableNIC(%d): %s", nicID, err)
+ }
+ reportCounter += uint64(len(test.multicastAddrs))
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ {
+ seen := make(map[tcpip.Address]bool)
+ for _, a := range test.multicastAddrs {
+ seen[a] = false
+ }
+
+ for i, _ := range test.multicastAddrs {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected (%d-th) report message to be sent", i)
+ }
+
+ test.validateReport(t, p, test.getAndCheckGroupAddress(t, seen, p))
+ }
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Joining/leaving a group while disabled should not send any messages.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ leaveCounter += uint64(len(test.multicastAddrs))
+ if got := sentLeaveStat.Value(); got != leaveCounter {
+ t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
+ }
+ for i, _ := range test.multicastAddrs {
+ if _, ok := e.Read(); !ok {
+ t.Fatalf("expected (%d-th) leave message to be sent", i)
+ }
+ }
+ for _, a := range test.multicastAddrs {
+ if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil {
+ t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err)
+ }
+ if got := sentLeaveStat.Value(); got != leaveCounter {
+ t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p.Pkt)
+ }
+ }
+ if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err)
+ }
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p.Pkt)
+ }
+
+ // A report should only be sent for the group we last joined after
+ // enabling the NIC since the original groups were all left.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("EnableNIC(%d): %s", nicID, err)
+ }
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p, test.finalMulticastAddr)
+ }
+
+ clock.Advance(test.maxUnsolicitedResponseDelay)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p, test.finalMulticastAddr)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
+
+// TestMGPDisabledOnLoopback tests that the multicast group protocol is not
+// performed on loopback interfaces since they have no neighbours.
+func TestMGPDisabledOnLoopback(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New())
+
+ sentReportStat := test.sentReportStat(s)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+
+ // Test joining a specific group explicitly and verify that no reports are
+ // sent.
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go
index 7cc52985e..5c3363759 100644
--- a/pkg/tcpip/network/testutil/testutil.go
+++ b/pkg/tcpip/network/testutil/testutil.go
@@ -85,21 +85,6 @@ func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts st
return n, nil
}
-// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
-func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- if ep.allowPackets == 0 {
- return ep.err
- }
- ep.allowPackets--
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- })
- ep.WrittenPackets = append(ep.WrittenPackets, pkt)
-
- return nil
-}
-
// Attach implements LinkEndpoint.Attach.
func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {}
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index 2a6c7c7c0..b60a5fd76 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -15,31 +15,350 @@
package tcpip
import (
+ "sync/atomic"
+
"gvisor.dev/gvisor/pkg/sync"
)
-// SocketOptions contains all the variables which store values for socket
-// level options.
+// SocketOptionsHandler holds methods that help define endpoint specific
+// behavior for socket level socket options. These must be implemented by
+// endpoints to get notified when socket level options are set.
+type SocketOptionsHandler interface {
+ // OnReuseAddressSet is invoked when SO_REUSEADDR is set for an endpoint.
+ OnReuseAddressSet(v bool)
+
+ // OnReusePortSet is invoked when SO_REUSEPORT is set for an endpoint.
+ OnReusePortSet(v bool)
+
+ // OnKeepAliveSet is invoked when SO_KEEPALIVE is set for an endpoint.
+ OnKeepAliveSet(v bool)
+
+ // OnDelayOptionSet is invoked when TCP_NODELAY is set for an endpoint.
+ // Note that v will be the inverse of TCP_NODELAY option.
+ OnDelayOptionSet(v bool)
+
+ // OnCorkOptionSet is invoked when TCP_CORK is set for an endpoint.
+ OnCorkOptionSet(v bool)
+
+ // LastError is invoked when SO_ERROR is read for an endpoint.
+ LastError() *Error
+}
+
+// DefaultSocketOptionsHandler is an embeddable type that implements no-op
+// implementations for SocketOptionsHandler methods.
+type DefaultSocketOptionsHandler struct{}
+
+var _ SocketOptionsHandler = (*DefaultSocketOptionsHandler)(nil)
+
+// OnReuseAddressSet implements SocketOptionsHandler.OnReuseAddressSet.
+func (*DefaultSocketOptionsHandler) OnReuseAddressSet(bool) {}
+
+// OnReusePortSet implements SocketOptionsHandler.OnReusePortSet.
+func (*DefaultSocketOptionsHandler) OnReusePortSet(bool) {}
+
+// OnKeepAliveSet implements SocketOptionsHandler.OnKeepAliveSet.
+func (*DefaultSocketOptionsHandler) OnKeepAliveSet(bool) {}
+
+// OnDelayOptionSet implements SocketOptionsHandler.OnDelayOptionSet.
+func (*DefaultSocketOptionsHandler) OnDelayOptionSet(bool) {}
+
+// OnCorkOptionSet implements SocketOptionsHandler.OnCorkOptionSet.
+func (*DefaultSocketOptionsHandler) OnCorkOptionSet(bool) {}
+
+// LastError implements SocketOptionsHandler.LastError.
+func (*DefaultSocketOptionsHandler) LastError() *Error {
+ return nil
+}
+
+// SocketOptions contains all the variables which store values for SOL_SOCKET,
+// SOL_IP, SOL_IPV6 and SOL_TCP level options.
//
// +stateify savable
type SocketOptions struct {
- // mu protects fields below.
- mu sync.Mutex `state:"nosave"`
- broadcastEnabled bool
+ handler SocketOptionsHandler
+
+ // These fields are accessed and modified using atomic operations.
+
+ // broadcastEnabled determines whether datagram sockets are allowed to
+ // send packets to a broadcast address.
+ broadcastEnabled uint32
+
+ // passCredEnabled determines whether SCM_CREDENTIALS socket control
+ // messages are enabled.
+ passCredEnabled uint32
+
+ // noChecksumEnabled determines whether UDP checksum is disabled while
+ // transmitting for this socket.
+ noChecksumEnabled uint32
+
+ // reuseAddressEnabled determines whether Bind() should allow reuse of
+ // local address.
+ reuseAddressEnabled uint32
+
+ // reusePortEnabled determines whether to permit multiple sockets to be
+ // bound to an identical socket address.
+ reusePortEnabled uint32
+
+ // keepAliveEnabled determines whether TCP keepalive is enabled for this
+ // socket.
+ keepAliveEnabled uint32
+
+ // multicastLoopEnabled determines whether multicast packets sent over a
+ // non-loopback interface will be looped back. Analogous to inet->mc_loop.
+ multicastLoopEnabled uint32
+
+ // receiveTOSEnabled is used to specify if the TOS ancillary message is
+ // passed with incoming packets.
+ receiveTOSEnabled uint32
+
+ // receiveTClassEnabled is used to specify if the IPV6_TCLASS ancillary
+ // message is passed with incoming packets.
+ receiveTClassEnabled uint32
+
+ // receivePacketInfoEnabled is used to specify if more inforamtion is
+ // provided with incoming packets such as interface index and address.
+ receivePacketInfoEnabled uint32
+
+ // hdrIncludeEnabled is used to indicate for a raw endpoint that all packets
+ // being written have an IP header and the endpoint should not attach an IP
+ // header.
+ hdrIncludedEnabled uint32
+
+ // v6OnlyEnabled is used to determine whether an IPv6 socket is to be
+ // restricted to sending and receiving IPv6 packets only.
+ v6OnlyEnabled uint32
+
+ // quickAckEnabled is used to represent the value of TCP_QUICKACK option.
+ // It currently does not have any effect on the TCP endpoint.
+ quickAckEnabled uint32
+
+ // delayOptionEnabled is used to specify if data should be sent out immediately
+ // by the transport protocol. For TCP, it determines if the Nagle algorithm
+ // is on or off.
+ delayOptionEnabled uint32
+
+ // corkOptionEnabled is used to specify if data should be held until segments
+ // are full by the TCP transport protocol.
+ corkOptionEnabled uint32
+
+ // receiveOriginalDstAddress is used to specify if the original destination of
+ // the incoming packet should be returned as an ancillary message.
+ receiveOriginalDstAddress uint32
+
+ // mu protects the access to the below fields.
+ mu sync.Mutex `state:"nosave"`
+
+ // linger determines the amount of time the socket should linger before
+ // close. We currently implement this option for TCP socket only.
+ linger LingerOption
+}
+
+// InitHandler initializes the handler. This must be called before using the
+// socket options utility.
+func (so *SocketOptions) InitHandler(handler SocketOptionsHandler) {
+ so.handler = handler
+}
+
+func storeAtomicBool(addr *uint32, v bool) {
+ var val uint32
+ if v {
+ val = 1
+ }
+ atomic.StoreUint32(addr, val)
}
// GetBroadcast gets value for SO_BROADCAST option.
func (so *SocketOptions) GetBroadcast() bool {
- so.mu.Lock()
- defer so.mu.Unlock()
-
- return so.broadcastEnabled
+ return atomic.LoadUint32(&so.broadcastEnabled) != 0
}
// SetBroadcast sets value for SO_BROADCAST option.
func (so *SocketOptions) SetBroadcast(v bool) {
+ storeAtomicBool(&so.broadcastEnabled, v)
+}
+
+// GetPassCred gets value for SO_PASSCRED option.
+func (so *SocketOptions) GetPassCred() bool {
+ return atomic.LoadUint32(&so.passCredEnabled) != 0
+}
+
+// SetPassCred sets value for SO_PASSCRED option.
+func (so *SocketOptions) SetPassCred(v bool) {
+ storeAtomicBool(&so.passCredEnabled, v)
+}
+
+// GetNoChecksum gets value for SO_NO_CHECK option.
+func (so *SocketOptions) GetNoChecksum() bool {
+ return atomic.LoadUint32(&so.noChecksumEnabled) != 0
+}
+
+// SetNoChecksum sets value for SO_NO_CHECK option.
+func (so *SocketOptions) SetNoChecksum(v bool) {
+ storeAtomicBool(&so.noChecksumEnabled, v)
+}
+
+// GetReuseAddress gets value for SO_REUSEADDR option.
+func (so *SocketOptions) GetReuseAddress() bool {
+ return atomic.LoadUint32(&so.reuseAddressEnabled) != 0
+}
+
+// SetReuseAddress sets value for SO_REUSEADDR option.
+func (so *SocketOptions) SetReuseAddress(v bool) {
+ storeAtomicBool(&so.reuseAddressEnabled, v)
+ so.handler.OnReuseAddressSet(v)
+}
+
+// GetReusePort gets value for SO_REUSEPORT option.
+func (so *SocketOptions) GetReusePort() bool {
+ return atomic.LoadUint32(&so.reusePortEnabled) != 0
+}
+
+// SetReusePort sets value for SO_REUSEPORT option.
+func (so *SocketOptions) SetReusePort(v bool) {
+ storeAtomicBool(&so.reusePortEnabled, v)
+ so.handler.OnReusePortSet(v)
+}
+
+// GetKeepAlive gets value for SO_KEEPALIVE option.
+func (so *SocketOptions) GetKeepAlive() bool {
+ return atomic.LoadUint32(&so.keepAliveEnabled) != 0
+}
+
+// SetKeepAlive sets value for SO_KEEPALIVE option.
+func (so *SocketOptions) SetKeepAlive(v bool) {
+ storeAtomicBool(&so.keepAliveEnabled, v)
+ so.handler.OnKeepAliveSet(v)
+}
+
+// GetMulticastLoop gets value for IP_MULTICAST_LOOP option.
+func (so *SocketOptions) GetMulticastLoop() bool {
+ return atomic.LoadUint32(&so.multicastLoopEnabled) != 0
+}
+
+// SetMulticastLoop sets value for IP_MULTICAST_LOOP option.
+func (so *SocketOptions) SetMulticastLoop(v bool) {
+ storeAtomicBool(&so.multicastLoopEnabled, v)
+}
+
+// GetReceiveTOS gets value for IP_RECVTOS option.
+func (so *SocketOptions) GetReceiveTOS() bool {
+ return atomic.LoadUint32(&so.receiveTOSEnabled) != 0
+}
+
+// SetReceiveTOS sets value for IP_RECVTOS option.
+func (so *SocketOptions) SetReceiveTOS(v bool) {
+ storeAtomicBool(&so.receiveTOSEnabled, v)
+}
+
+// GetReceiveTClass gets value for IPV6_RECVTCLASS option.
+func (so *SocketOptions) GetReceiveTClass() bool {
+ return atomic.LoadUint32(&so.receiveTClassEnabled) != 0
+}
+
+// SetReceiveTClass sets value for IPV6_RECVTCLASS option.
+func (so *SocketOptions) SetReceiveTClass(v bool) {
+ storeAtomicBool(&so.receiveTClassEnabled, v)
+}
+
+// GetReceivePacketInfo gets value for IP_PKTINFO option.
+func (so *SocketOptions) GetReceivePacketInfo() bool {
+ return atomic.LoadUint32(&so.receivePacketInfoEnabled) != 0
+}
+
+// SetReceivePacketInfo sets value for IP_PKTINFO option.
+func (so *SocketOptions) SetReceivePacketInfo(v bool) {
+ storeAtomicBool(&so.receivePacketInfoEnabled, v)
+}
+
+// GetHeaderIncluded gets value for IP_HDRINCL option.
+func (so *SocketOptions) GetHeaderIncluded() bool {
+ return atomic.LoadUint32(&so.hdrIncludedEnabled) != 0
+}
+
+// SetHeaderIncluded sets value for IP_HDRINCL option.
+func (so *SocketOptions) SetHeaderIncluded(v bool) {
+ storeAtomicBool(&so.hdrIncludedEnabled, v)
+}
+
+// GetV6Only gets value for IPV6_V6ONLY option.
+func (so *SocketOptions) GetV6Only() bool {
+ return atomic.LoadUint32(&so.v6OnlyEnabled) != 0
+}
+
+// SetV6Only sets value for IPV6_V6ONLY option.
+//
+// Preconditions: the backing TCP or UDP endpoint must be in initial state.
+func (so *SocketOptions) SetV6Only(v bool) {
+ storeAtomicBool(&so.v6OnlyEnabled, v)
+}
+
+// GetQuickAck gets value for TCP_QUICKACK option.
+func (so *SocketOptions) GetQuickAck() bool {
+ return atomic.LoadUint32(&so.quickAckEnabled) != 0
+}
+
+// SetQuickAck sets value for TCP_QUICKACK option.
+func (so *SocketOptions) SetQuickAck(v bool) {
+ storeAtomicBool(&so.quickAckEnabled, v)
+}
+
+// GetDelayOption gets inverted value for TCP_NODELAY option.
+func (so *SocketOptions) GetDelayOption() bool {
+ return atomic.LoadUint32(&so.delayOptionEnabled) != 0
+}
+
+// SetDelayOption sets inverted value for TCP_NODELAY option.
+func (so *SocketOptions) SetDelayOption(v bool) {
+ storeAtomicBool(&so.delayOptionEnabled, v)
+ so.handler.OnDelayOptionSet(v)
+}
+
+// GetCorkOption gets value for TCP_CORK option.
+func (so *SocketOptions) GetCorkOption() bool {
+ return atomic.LoadUint32(&so.corkOptionEnabled) != 0
+}
+
+// SetCorkOption sets value for TCP_CORK option.
+func (so *SocketOptions) SetCorkOption(v bool) {
+ storeAtomicBool(&so.corkOptionEnabled, v)
+ so.handler.OnCorkOptionSet(v)
+}
+
+// GetReceiveOriginalDstAddress gets value for IP(V6)_RECVORIGDSTADDR option.
+func (so *SocketOptions) GetReceiveOriginalDstAddress() bool {
+ return atomic.LoadUint32(&so.receiveOriginalDstAddress) != 0
+}
+
+// SetReceiveOriginalDstAddress sets value for IP(V6)_RECVORIGDSTADDR option.
+func (so *SocketOptions) SetReceiveOriginalDstAddress(v bool) {
+ storeAtomicBool(&so.receiveOriginalDstAddress, v)
+}
+
+// GetLastError gets value for SO_ERROR option.
+func (so *SocketOptions) GetLastError() *Error {
+ return so.handler.LastError()
+}
+
+// GetOutOfBandInline gets value for SO_OOBINLINE option.
+func (*SocketOptions) GetOutOfBandInline() bool {
+ return true
+}
+
+// SetOutOfBandInline sets value for SO_OOBINLINE option. We currently do not
+// support disabling this option.
+func (*SocketOptions) SetOutOfBandInline(bool) {}
+
+// GetLinger gets value for SO_LINGER option.
+func (so *SocketOptions) GetLinger() LingerOption {
so.mu.Lock()
- defer so.mu.Unlock()
+ linger := so.linger
+ so.mu.Unlock()
+ return linger
+}
- so.broadcastEnabled = v
+// SetLinger sets value for SO_LINGER option.
+func (so *SocketOptions) SetLinger(linger LingerOption) {
+ so.mu.Lock()
+ so.linger = linger
+ so.mu.Unlock()
}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index d09ebe7fa..9cc6074da 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test", "most_shards")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -112,7 +112,7 @@ go_test(
"transport_demuxer_test.go",
"transport_test.go",
],
- shard_count = 20,
+ shard_count = most_shards,
deps = [
":stack",
"//pkg/rand",
@@ -120,6 +120,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
@@ -131,7 +132,6 @@ go_test(
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
- "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index 9478f3fb7..cd423bf71 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil)
var _ AddressableEndpoint = (*AddressableEndpointState)(nil)
// AddressableEndpointState is an implementation of an AddressableEndpoint.
@@ -37,10 +36,6 @@ type AddressableEndpointState struct {
endpoints map[tcpip.Address]*addressState
primary []*addressState
-
- // groups holds the mapping between group addresses and the number of times
- // they have been joined.
- groups map[tcpip.Address]uint32
}
}
@@ -53,65 +48,33 @@ func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) {
a.mu.Lock()
defer a.mu.Unlock()
a.mu.endpoints = make(map[tcpip.Address]*addressState)
- a.mu.groups = make(map[tcpip.Address]uint32)
-}
-
-// ReadOnlyAddressableEndpointState provides read-only access to an
-// AddressableEndpointState.
-type ReadOnlyAddressableEndpointState struct {
- inner *AddressableEndpointState
}
-// AddrOrMatching returns an endpoint for the passed address that is consisdered
-// bound to the wrapped AddressableEndpointState.
+// GetAddress returns the AddressEndpoint for the passed address.
//
-// If addr is an exact match with an existing address, that address is returned.
-// Otherwise, f is called with each address and the address that f returns true
-// for is returned.
-//
-// Returns nil of no address matches.
-func (m ReadOnlyAddressableEndpointState) AddrOrMatching(addr tcpip.Address, spoofingOrPrimiscuous bool, f func(AddressEndpoint) bool) AddressEndpoint {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
-
- if ep, ok := m.inner.mu.endpoints[addr]; ok {
- if ep.IsAssigned(spoofingOrPrimiscuous) && ep.IncRef() {
- return ep
- }
- }
-
- for _, ep := range m.inner.mu.endpoints {
- if ep.IsAssigned(spoofingOrPrimiscuous) && f(ep) && ep.IncRef() {
- return ep
- }
- }
-
- return nil
-}
-
-// Lookup returns the AddressEndpoint for the passed address.
+// GetAddress does not increment the address's reference count or check if the
+// address is considered bound to the endpoint.
//
-// Returns nil if the passed address is not associated with the
-// AddressableEndpointState.
-func (m ReadOnlyAddressableEndpointState) Lookup(addr tcpip.Address) AddressEndpoint {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
+// Returns nil if the passed address is not associated with the endpoint.
+func (a *AddressableEndpointState) GetAddress(addr tcpip.Address) AddressEndpoint {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
- ep, ok := m.inner.mu.endpoints[addr]
+ ep, ok := a.mu.endpoints[addr]
if !ok {
return nil
}
return ep
}
-// ForEach calls f for each address pair.
+// ForEachEndpoint calls f for each address.
//
-// If f returns false, f is no longer be called.
-func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
+// Once f returns false, f will no longer be called.
+func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool) {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
- for _, ep := range m.inner.mu.endpoints {
+ for _, ep := range a.mu.endpoints {
if !f(ep) {
return
}
@@ -120,18 +83,16 @@ func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool)
// ForEachPrimaryEndpoint calls f for each primary address.
//
-// If f returns false, f is no longer be called.
-func (m ReadOnlyAddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
- for _, ep := range m.inner.mu.primary {
- f(ep)
- }
-}
+// Once f returns false, f will no longer be called.
+func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint) bool) {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
-// ReadOnly returns a readonly reference to a.
-func (a *AddressableEndpointState) ReadOnly() ReadOnlyAddressableEndpointState {
- return ReadOnlyAddressableEndpointState{inner: a}
+ for _, ep := range a.mu.primary {
+ if !f(ep) {
+ return
+ }
+ }
}
func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) {
@@ -335,11 +296,6 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
a.mu.Lock()
defer a.mu.Unlock()
-
- if _, ok := a.mu.groups[addr]; ok {
- panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr))
- }
-
return a.removePermanentAddressLocked(addr)
}
@@ -471,8 +427,19 @@ func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*ad
return deprecatedEndpoint
}
-// AcquireAssignedAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
+// AcquireAssignedAddressOrMatching returns an address endpoint that is
+// considered assigned to the addressable endpoint.
+//
+// If the address is an exact match with an existing address, that address is
+// returned. Otherwise, if f is provided, f is called with each address and
+// the address that f returns true for is returned.
+//
+// If there is no matching address, a temporary address will be returned if
+// allowTemp is true.
+//
+// Regardless how the address was obtained, it will be acquired before it is
+// returned.
+func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
a.mu.Lock()
defer a.mu.Unlock()
@@ -488,6 +455,14 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres
return addrState
}
+ if f != nil {
+ for _, addrState := range a.mu.endpoints {
+ if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() {
+ return addrState
+ }
+ }
+ }
+
if !allowTemp {
return nil
}
@@ -520,6 +495,11 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres
return ep
}
+// AcquireAssignedAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
+ return a.AcquireAssignedAddressOrMatching(localAddr, nil, allowTemp, tempPEB)
+}
+
// AcquireOutgoingPrimaryAddress implements AddressableEndpoint.
func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint {
a.mu.RLock()
@@ -588,72 +568,11 @@ func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefi
return addrs
}
-// JoinGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) {
- a.mu.Lock()
- defer a.mu.Unlock()
-
- joins, ok := a.mu.groups[group]
- if !ok {
- ep, err := a.addAndAcquireAddressLocked(group.WithPrefix(), NeverPrimaryEndpoint, AddressConfigStatic, false /* deprecated */, true /* permanent */)
- if err != nil {
- return false, err
- }
- // We have no need for the address endpoint.
- a.decAddressRefLocked(ep)
- }
-
- a.mu.groups[group] = joins + 1
- return !ok, nil
-}
-
-// LeaveGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) {
- a.mu.Lock()
- defer a.mu.Unlock()
-
- joins, ok := a.mu.groups[group]
- if !ok {
- return false, tcpip.ErrBadLocalAddress
- }
-
- if joins == 1 {
- a.removeGroupAddressLocked(group)
- delete(a.mu.groups, group)
- return true, nil
- }
-
- a.mu.groups[group] = joins - 1
- return false, nil
-}
-
-// IsInGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool {
- a.mu.RLock()
- defer a.mu.RUnlock()
- _, ok := a.mu.groups[group]
- return ok
-}
-
-func (a *AddressableEndpointState) removeGroupAddressLocked(group tcpip.Address) {
- if err := a.removePermanentAddressLocked(group); err != nil {
- // removePermanentEndpointLocked would only return an error if group is
- // not bound to the addressable endpoint, but we know it MUST be assigned
- // since we have group in our map of groups.
- panic(fmt.Sprintf("error removing group address = %s: %s", group, err))
- }
-}
-
// Cleanup forcefully leaves all groups and removes all permanent addresses.
func (a *AddressableEndpointState) Cleanup() {
a.mu.Lock()
defer a.mu.Unlock()
- for group := range a.mu.groups {
- a.removeGroupAddressLocked(group)
- }
- a.mu.groups = make(map[tcpip.Address]uint32)
-
for _, ep := range a.mu.endpoints {
// removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is
// not a permanent address.
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
index 26787d0a3..140f146f6 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -53,25 +53,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
ep.DecRef()
}
- group := tcpip.Address("\x02")
- if added, err := s.JoinGroup(group); err != nil {
- t.Fatalf("s.JoinGroup(%s): %s", group, err)
- } else if !added {
- t.Fatalf("got s.JoinGroup(%s) = false, want = true", group)
- }
- if !s.IsInGroup(group) {
- t.Fatalf("got s.IsInGroup(%s) = false, want = true", group)
- }
-
s.Cleanup()
- {
- ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
- if ep != nil {
- ep.DecRef()
- t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
- }
- }
- if s.IsInGroup(group) {
- t.Fatalf("got s.IsInGroup(%s) = true, want = false", group)
+ if ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint); ep != nil {
+ ep.DecRef()
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
}
}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 6dc9e7859..5ec9b3411 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -309,7 +309,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
p := fwdTestPacketInfo{
- RemoteLinkAddress: r.RemoteLinkAddress,
+ RemoteLinkAddress: r.RemoteLinkAddress(),
LocalLinkAddress: r.LocalLinkAddress,
Pkt: pkt,
}
@@ -333,20 +333,6 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer
return n, nil
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- p := fwdTestPacketInfo{
- Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}),
- }
-
- select {
- case e.C <- p:
- default:
- }
-
- return nil
-}
-
// Wait implements stack.LinkEndpoint.Wait.
func (*fwdTestLinkEndpoint) Wait() {}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 73a01c2dd..03d7b4e0d 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -352,7 +353,7 @@ func TestDADDisabled(t *testing.T) {
}
// We should not have sent any NDP NS messages.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != 0 {
t.Fatalf("got NeighborSolicit = %d, want = 0", got)
}
}
@@ -465,14 +466,18 @@ func TestDADResolve(t *testing.T) {
if err != tcpip.ErrNoRoute {
t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
{
r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
if err != tcpip.ErrNoRoute {
t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
if t.Failed() {
@@ -510,7 +515,9 @@ func TestDADResolve(t *testing.T) {
} else if r.LocalAddress != addr1 {
t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
if t.Failed() {
@@ -518,7 +525,7 @@ func TestDADResolve(t *testing.T) {
}
// Should not have sent any more NS messages.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits)
}
@@ -533,8 +540,8 @@ func TestDADResolve(t *testing.T) {
// Make sure the right remote link address is used.
snmc := header.SolicitedNodeAddr(addr1)
- if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want {
+ t.Errorf("got remote link address = %s, want = %s", got, want)
}
// Check NDP NS packet.
@@ -563,18 +570,18 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(tgt)
snmc := header.SolicitedNodeAddr(tgt)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: 255,
- SrcAddr: header.IPv6Any,
- DstAddr: snmc,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: 255,
+ SrcAddr: header.IPv6Any,
+ DstAddr: snmc,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
}
@@ -605,7 +612,7 @@ func TestDADFail(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
pkt := header.ICMPv6(hdr.Prepend(naSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(tgt)
@@ -616,11 +623,11 @@ func TestDADFail(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: 255,
- SrcAddr: tgt,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: 255,
+ SrcAddr: tgt,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
},
@@ -666,7 +673,7 @@ func TestDADFail(t *testing.T) {
// Receive a packet to simulate an address conflict.
test.rxPkt(e, addr1)
- stat := test.getStat(s.Stats().ICMP.V6PacketsReceived)
+ stat := test.getStat(s.Stats().ICMP.V6.PacketsReceived)
if got := stat.Value(); got != 1 {
t.Fatalf("got stat = %d, want = 1", got)
}
@@ -803,7 +810,7 @@ func TestDADStop(t *testing.T) {
}
// Should not have sent more than 1 NS message.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got > 1 {
t.Errorf("got NeighborSolicit = %d, want <= 1", got)
}
})
@@ -982,7 +989,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
pkt.SetType(header.ICMPv6RouterAdvert)
pkt.SetCode(0)
- raPayload := pkt.NDPPayload()
+ raPayload := pkt.MessageBody()
ra := header.NDPRouterAdvert(raPayload)
// Populate the Router Lifetime.
binary.BigEndian.PutUint16(raPayload[2:], rl)
@@ -1004,11 +1011,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
payloadLength := hdr.UsedLength()
iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
iph.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: header.NDPHopLimit,
- SrcAddr: ip,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: ip,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
return stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -2162,8 +2169,8 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
NDPConfigs: ipv6.NDPConfigurations{
AutoGenTempGlobalAddresses: true,
},
- NDPDisp: &ndpDisp,
- AutoGenIPv6LinkLocal: true,
+ NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: true,
})},
})
@@ -2843,9 +2850,7 @@ func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Connect(addr); err != nil {
t.Fatalf("ep.Connect(%+v): %s", addr, err)
}
@@ -2879,9 +2884,7 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Bind(addr); err != nil {
t.Fatalf("ep.Bind(%+v): %s", addr, err)
}
@@ -3250,9 +3253,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute)
@@ -4044,9 +4045,9 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
ndpConfigs.AutoGenAddressConflictRetries = maxRetries
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: func(_ tcpip.NICID, nicName string) string {
return nicName
@@ -4179,9 +4180,9 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: addrType.ndpConfigs,
- NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: addrType.ndpConfigs,
+ NDPDisp: &ndpDisp,
})},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -4708,7 +4709,7 @@ func TestCleanupNDPState(t *testing.T) {
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: true,
+ AutoGenLinkLocal: true,
NDPConfigs: ipv6.NDPConfigurations{
HandleRAs: true,
DiscoverDefaultRouters: true,
@@ -5174,113 +5175,99 @@ func TestRouterSolicitation(t *testing.T) {
},
}
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
- e := channelLinkWithHeaderLength{
- Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
- headerLength: test.linkHeaderLen,
+ clock.Advance(timeout)
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("expected router solicitation packet")
}
- e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- waitForPkt := func(timeout time.Duration) {
- t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- p, ok := e.ReadContext(ctx)
- if !ok {
- t.Fatal("timed out waiting for packet")
- return
- }
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
- // Make sure the right remote link address is used.
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
- }
+ // Make sure the right remote link address is used.
+ if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); got != want {
+ t.Errorf("got remote link address = %s, want = %s", got, want)
+ }
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(test.expectedSrcAddr),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
- )
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(test.expectedSrcAddr),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
+ )
- if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
- t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
- }
- }
- waitForNothing := func(timeout time.Duration) {
- t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet")
- }
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- MaxRtrSolicitations: test.maxRtrSolicit,
- RtrSolicitationInterval: test.rtrSolicitInt,
- MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
- },
- })},
- })
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
}
+ }
+ waitForNothing := func(timeout time.Duration) {
+ t.Helper()
- if addr := test.nicAddr; addr != "" {
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
- }
+ clock.Advance(timeout)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("unexpectedly got a packet = %#v", p)
}
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ MaxRtrSolicitations: test.maxRtrSolicit,
+ RtrSolicitationInterval: test.rtrSolicitInt,
+ MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
+ },
+ })},
+ Clock: clock,
+ })
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- // Make sure each RS is sent at the right time.
- remaining := test.maxRtrSolicit
- if remaining > 0 {
- waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout)
- remaining--
+ if addr := test.nicAddr; addr != "" {
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
}
+ }
- for ; remaining > 0; remaining-- {
- if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
- waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout)
- waitForPkt(defaultAsyncPositiveEventTimeout)
- } else {
- waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout)
- }
- }
+ // Make sure each RS is sent at the right time.
+ remaining := test.maxRtrSolicit
+ if remaining > 0 {
+ waitForPkt(test.effectiveMaxRtrSolicitDelay)
+ remaining--
+ }
- // Make sure no more RS.
- if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
- waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout)
+ for ; remaining > 0; remaining-- {
+ if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
+ waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond)
+ waitForPkt(time.Nanosecond)
} else {
- waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout)
+ waitForPkt(test.effectiveRtrSolicitInt)
}
+ }
- // Make sure the counter got properly
- // incremented.
- if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
- t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
- }
- })
- }
- })
+ // Make sure no more RS.
+ if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
+ waitForNothing(test.effectiveRtrSolicitInt)
+ } else {
+ waitForNothing(test.effectiveMaxRtrSolicitDelay)
+ }
+
+ if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
+ t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
+ }
+ })
+ }
}
func TestStopStartSolicitingRouters(t *testing.T) {
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 177bf5516..317f6871d 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -24,9 +24,16 @@ import (
const neighborCacheSize = 512 // max entries per interface
+// NeighborStats holds metrics for the neighbor table.
+type NeighborStats struct {
+ // FailedEntryLookups counts the number of lookups performed on an entry in
+ // Failed state.
+ FailedEntryLookups *tcpip.StatCounter
+}
+
// neighborCache maps IP addresses to link addresses. It uses the Least
// Recently Used (LRU) eviction strategy to implement a bounded cache for
-// dynmically acquired entries. It contains the state machine and configuration
+// dynamically acquired entries. It contains the state machine and configuration
// for running Neighbor Unreachability Detection (NUD).
//
// There are two types of entries in the neighbor cache:
@@ -175,14 +182,15 @@ func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) {
// entries returns all entries in the neighbor cache.
func (n *neighborCache) entries() []NeighborEntry {
- entries := make([]NeighborEntry, 0, len(n.cache))
n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ entries := make([]NeighborEntry, 0, len(n.cache))
for _, entry := range n.cache {
entry.mu.RLock()
entries = append(entries, entry.neigh)
entry.mu.RUnlock()
}
- n.mu.RUnlock()
return entries
}
@@ -226,6 +234,8 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd
}
// removeEntryLocked removes the specified entry from the neighbor cache.
+//
+// Prerequisite: n.mu and entry.mu MUST be locked.
func (n *neighborCache) removeEntryLocked(entry *neighborEntry) {
if entry.neigh.State != Static {
n.dynamic.lru.Remove(entry)
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index ed33418f3..732a299f7 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -80,17 +80,20 @@ func entryDiffOptsWithSort() []cmp.Option {
func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache {
config.resetInvalidFields()
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
- return &neighborCache{
+ neigh := &neighborCache{
nic: &NIC{
stack: &Stack{
clock: clock,
nudDisp: nudDisp,
},
- id: 1,
+ id: 1,
+ stats: makeNICStats(),
},
state: NewNUDState(config, rng),
cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
}
+ neigh.nic.neigh = neigh
+ return neigh
}
// testEntryStore contains a set of IP to NeighborEntry mappings.
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 493e48031..32399b4f5 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -258,7 +258,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
case Failed:
e.notifyWakersLocked()
- e.job = e.nic.stack.newJob(&e.mu, func() {
+ e.job = e.nic.stack.newJob(&doubleLock{first: &e.nic.neigh.mu, second: &e.mu}, func() {
e.nic.neigh.removeEntryLocked(e)
})
e.job.Schedule(config.UnreachableTime)
@@ -347,9 +347,10 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
e.setStateLocked(Delay)
e.dispatchChangeEventLocked()
- case Incomplete, Reachable, Delay, Probe, Static, Failed:
+ case Incomplete, Reachable, Delay, Probe, Static:
// Do nothing
-
+ case Failed:
+ e.nic.stats.Neighbor.FailedEntryLookups.Increment()
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
}
@@ -511,3 +512,23 @@ func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
}
}
+
+// doubleLock combines two locks into one while maintaining lock ordering.
+//
+// TODO(gvisor.dev/issue/4796): Remove this once subsequent traffic to a Failed
+// neighbor is allowed.
+type doubleLock struct {
+ first, second sync.Locker
+}
+
+// Lock locks both locks in order: first then second.
+func (l *doubleLock) Lock() {
+ l.first.Lock()
+ l.second.Lock()
+}
+
+// Unlock unlocks both locks in reverse order: second then first.
+func (l *doubleLock) Unlock() {
+ l.second.Unlock()
+ l.first.Unlock()
+}
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index c2b763325..c497d3932 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -89,7 +89,7 @@ func eventDiffOptsWithSort() []cmp.Option {
// | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
// | Stale | Stale | Override confirmation | Update LinkAddr | Changed |
// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed |
-// | Stale | Delay | Packet sent | | Changed |
+// | Stale | Delay | Packet queued | | Changed |
// | Delay | Reachable | Upper-layer confirmation | | Changed |
// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
// | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
@@ -101,6 +101,7 @@ func eventDiffOptsWithSort() []cmp.Option {
// | Probe | Stale | Probe or confirmation w/ different address | | Changed |
// | Probe | Probe | Retransmit timer expired | Send probe | Changed |
// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed |
+// | Failed | Failed | Packet queued | | |
// | Failed | | Unreachability timer expired | Delete entry | |
type testEntryEventType uint8
@@ -228,6 +229,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
clock: clock,
nudDisp: &disp,
},
+ stats: makeNICStats(),
}
nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{
header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil),
@@ -3433,6 +3435,146 @@ func TestEntryProbeToFailed(t *testing.T) {
nudDisp.mu.Unlock()
}
+func TestEntryFailedToFailed(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ c.MaxUnicastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ // Verify the cache contains the entry.
+ if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok {
+ t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1)
+ }
+
+ // TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in
+ // their expected state.
+ e.mu.Lock()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes)
+ clock.Advance(waitFor)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ failedLookups := e.nic.stats.Neighbor.FailedEntryLookups
+ if got := failedLookups.Value(); got != 0 {
+ t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 0", got)
+ }
+
+ e.mu.Lock()
+ // Verify queuing a packet to the entry immediately fails.
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ state := e.neigh.State
+ e.mu.Unlock()
+ if state != Failed {
+ t.Errorf("got e.neigh.State = %q, want = %q", state, Failed)
+ }
+
+ if got := failedLookups.Value(); got != 1 {
+ t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 1", got)
+ }
+}
+
func TestEntryFailedGetsDeleted(t *testing.T) {
c := DefaultNUDConfigurations()
c.MaxMulticastProbes = 3
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 3e6ceff28..5d037a27e 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -54,18 +54,20 @@ type NIC struct {
sync.RWMutex
spoofing bool
promiscuous bool
- // packetEPs is protected by mu, but the contained PacketEndpoint
- // values are not.
- packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
+ // packetEPs is protected by mu, but the contained packetEndpointList are
+ // not.
+ packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList
}
}
-// NICStats includes transmitted and received stats.
+// NICStats hold statistics for a NIC.
type NICStats struct {
Tx DirectionStats
Rx DirectionStats
DisabledRx DirectionStats
+
+ Neighbor NeighborStats
}
func makeNICStats() NICStats {
@@ -80,6 +82,39 @@ type DirectionStats struct {
Bytes *tcpip.StatCounter
}
+type packetEndpointList struct {
+ mu sync.RWMutex
+
+ // eps is protected by mu, but the contained PacketEndpoint values are not.
+ eps []PacketEndpoint
+}
+
+func (p *packetEndpointList) add(ep PacketEndpoint) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.eps = append(p.eps, ep)
+}
+
+func (p *packetEndpointList) remove(ep PacketEndpoint) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ for i, epOther := range p.eps {
+ if epOther == ep {
+ p.eps = append(p.eps[:i], p.eps[i+1:]...)
+ break
+ }
+ }
+}
+
+// forEach calls fn with each endpoints in p while holding the read lock on p.
+func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ for _, ep := range p.eps {
+ fn(ep)
+ }
+}
+
// newNIC returns a new NIC using the default NDP configurations from stack.
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC {
// TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
@@ -100,7 +135,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
}
- nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint)
+ nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
// Check for Neighbor Unreachability Detection support.
var nud NUDHandler
@@ -123,11 +158,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// Register supported packet and network endpoint protocols.
for _, netProto := range header.Ethertypes {
- nic.mu.packetEPs[netProto] = []PacketEndpoint{}
+ nic.mu.packetEPs[netProto] = new(packetEndpointList)
}
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
- nic.mu.packetEPs[netNum] = nil
+ nic.mu.packetEPs[netNum] = new(packetEndpointList)
nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
}
@@ -170,7 +205,7 @@ func (n *NIC) disable() {
//
// n MUST be locked.
func (n *NIC) disableLocked() {
- if !n.setEnabled(false) {
+ if !n.Enabled() {
return
}
@@ -182,6 +217,10 @@ func (n *NIC) disableLocked() {
for _, ep := range n.networkEndpoints {
ep.Disable()
}
+
+ if !n.setEnabled(false) {
+ panic("should have only done work to disable the NIC if it was enabled")
+ }
}
// enable enables n.
@@ -265,7 +304,7 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
if ch, err := r.Resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
r := r.Clone()
- n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt)
+ n.stack.linkResQueue.enqueue(ch, r, protocol, pkt)
return nil
}
return err
@@ -277,9 +316,9 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
// WritePacketToRemote implements NetworkInterface.
func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
r := Route{
- NetProto: protocol,
- RemoteLinkAddress: remoteLinkAddr,
+ NetProto: protocol,
}
+ r.ResolveWith(remoteLinkAddr)
return n.writePacket(&r, gso, protocol, pkt)
}
@@ -561,8 +600,7 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
return tcpip.ErrNotSupported
}
- _, err := gep.JoinGroup(addr)
- return err
+ return gep.JoinGroup(addr)
}
// leaveGroup decrements the count for the given multicast address, and when it
@@ -578,11 +616,7 @@ func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres
return tcpip.ErrNotSupported
}
- if _, err := gep.LeaveGroup(addr); err != nil {
- return err
- }
-
- return nil
+ return gep.LeaveGroup(addr)
}
// isInGroup returns true if n has joined the multicast group addr.
@@ -637,15 +671,23 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0
// Are any packet type sockets listening for this network protocol?
- packetEPs := n.mu.packetEPs[protocol]
- // Add any other packet type sockets that may be listening for all protocols.
- packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
+ protoEPs := n.mu.packetEPs[protocol]
+ // Other packet type sockets that are listening for all protocols.
+ anyEPs := n.mu.packetEPs[header.EthernetProtocolAll]
n.mu.RUnlock()
- for _, ep := range packetEPs {
+
+ // Deliver to interested packet endpoints without holding NIC lock.
+ deliverPacketEPs := func(ep PacketEndpoint) {
p := pkt.Clone()
p.PktType = tcpip.PacketHost
ep.HandlePacket(n.id, local, protocol, p)
}
+ if protoEPs != nil {
+ protoEPs.forEach(deliverPacketEPs)
+ }
+ if anyEPs != nil {
+ anyEPs.forEach(deliverPacketEPs)
+ }
// Parse headers.
netProto := n.stack.NetworkProtocolInstance(protocol)
@@ -686,16 +728,17 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
// We do not deliver to protocol specific packet endpoints as on Linux
// only ETH_P_ALL endpoints get outbound packets.
// Add any other packet sockets that maybe listening for all protocols.
- packetEPs := n.mu.packetEPs[header.EthernetProtocolAll]
+ eps := n.mu.packetEPs[header.EthernetProtocolAll]
n.mu.RUnlock()
- for _, ep := range packetEPs {
+
+ eps.forEach(func(ep PacketEndpoint) {
p := pkt.Clone()
p.PktType = tcpip.PacketOutgoing
// Add the link layer header as outgoing packets are intercepted
// before the link layer header is created.
n.LinkEndpoint.AddHeader(local, remote, protocol, p)
ep.HandlePacket(n.id, local, protocol, p)
- }
+ })
}
// DeliverTransportPacket delivers the packets to the appropriate transport
@@ -848,7 +891,7 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa
if !ok {
return tcpip.ErrNotSupported
}
- n.mu.packetEPs[netProto] = append(eps, ep)
+ eps.add(ep)
return nil
}
@@ -861,13 +904,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
if !ok {
return
}
-
- for i, epOther := range eps {
- if epOther == ep {
- n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...)
- return
- }
- }
+ eps.remove(ep)
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 2cb13c6fa..b334e27c4 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -259,15 +259,6 @@ const (
PacketLoop
)
-// NetOptions is an interface that allows us to pass network protocol specific
-// options through the Stack layer code.
-type NetOptions interface {
- // SizeWithPadding returns the amount of memory that must be allocated to
- // hold the options given that the value must be rounded up to the next
- // multiple of 4 bytes.
- SizeWithPadding() int
-}
-
// NetworkHeaderParams are the header parameters given as input by the
// transport endpoint to the network.
type NetworkHeaderParams struct {
@@ -279,10 +270,6 @@ type NetworkHeaderParams struct {
// TOS refers to TypeOfService or TrafficClass field of the IP-header.
TOS uint8
-
- // Options is a set of options to add to a network header (or nil).
- // It will be protocol specific opaque information from higher layers.
- Options NetOptions
}
// GroupAddressableEndpoint is an endpoint that supports group addressing.
@@ -291,14 +278,10 @@ type NetworkHeaderParams struct {
// endpoints may associate themselves with the same identifier (group address).
type GroupAddressableEndpoint interface {
// JoinGroup joins the specified group.
- //
- // Returns true if the group was newly joined.
- JoinGroup(group tcpip.Address) (bool, *tcpip.Error)
+ JoinGroup(group tcpip.Address) *tcpip.Error
// LeaveGroup attempts to leave the specified group.
- //
- // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group.
- LeaveGroup(group tcpip.Address) (bool, *tcpip.Error)
+ LeaveGroup(group tcpip.Address) *tcpip.Error
// IsInGroup returns true if the endpoint is a member of the specified group.
IsInGroup(group tcpip.Address) bool
@@ -739,10 +722,6 @@ type LinkEndpoint interface {
// endpoint.
Capabilities() LinkEndpointCapabilities
- // WriteRawPacket writes a packet directly to the link. The packet
- // should already have an ethernet header. It takes ownership of vv.
- WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error
-
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
//
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 53cb6694f..de5fe6ffe 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -18,19 +18,22 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// Route represents a route through the networking stack to a given destination.
+//
+// It is safe to call Route's methods from multiple goroutines.
+//
+// The exported fields are immutable.
+//
+// TODO(gvisor.dev/issue/4902): Unexpose immutable fields.
type Route struct {
// RemoteAddress is the final destination of the route.
RemoteAddress tcpip.Address
- // RemoteLinkAddress is the link-layer (MAC) address of the
- // final destination of the route.
- RemoteLinkAddress tcpip.LinkAddress
-
// LocalAddress is the local address where the route starts.
LocalAddress tcpip.Address
@@ -52,8 +55,16 @@ type Route struct {
// address's assigned status without the NIC.
localAddressNIC *NIC
- // localAddressEndpoint is the local address this route is associated with.
- localAddressEndpoint AssignableAddressEndpoint
+ mu struct {
+ sync.RWMutex
+
+ // localAddressEndpoint is the local address this route is associated with.
+ localAddressEndpoint AssignableAddressEndpoint
+
+ // remoteLinkAddress is the link-layer (MAC) address of the next hop in the
+ // route.
+ remoteLinkAddress tcpip.LinkAddress
+ }
// outgoingNIC is the interface this route uses to write packets.
outgoingNIC *NIC
@@ -71,22 +82,24 @@ type Route struct {
// ownership of the provided local address.
//
// Returns an empty route if validation fails.
-func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route {
- addrWithPrefix := addressEndpoint.AddressWithPrefix()
+func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route {
+ if len(localAddr) == 0 {
+ localAddr = addressEndpoint.AddressWithPrefix().Address
+ }
- if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(addrWithPrefix.Address) {
+ if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) {
addressEndpoint.DecRef()
- return Route{}
+ return nil
}
// If no remote address is provided, use the local address.
if len(remoteAddr) == 0 {
- remoteAddr = addrWithPrefix.Address
+ remoteAddr = localAddr
}
r := makeRoute(
netProto,
- addrWithPrefix.Address,
+ localAddr,
remoteAddr,
outgoingNIC,
localAddressNIC,
@@ -99,8 +112,8 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp
// broadcast it.
if len(gateway) > 0 {
r.NextHop = gateway
- } else if subnet := addrWithPrefix.Subnet(); subnet.IsBroadcast(remoteAddr) {
- r.RemoteLinkAddress = header.EthernetBroadcastAddress
+ } else if subnet := addressEndpoint.Subnet(); subnet.IsBroadcast(remoteAddr) {
+ r.ResolveWith(header.EthernetBroadcastAddress)
}
return r
@@ -108,11 +121,15 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp
// makeRoute initializes a new route. It takes ownership of the provided
// AssignableAddressEndpoint.
-func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route {
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route {
if localAddressNIC.stack != outgoingNIC.stack {
panic(fmt.Sprintf("cannot create a route with NICs from different stacks"))
}
+ if len(localAddr) == 0 {
+ localAddr = localAddressEndpoint.AddressWithPrefix().Address
+ }
+
loop := PacketOut
// TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
@@ -133,18 +150,21 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
}
-func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route {
- r := Route{
- NetProto: netProto,
- LocalAddress: localAddr,
- LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
- RemoteAddress: remoteAddr,
- localAddressNIC: localAddressNIC,
- localAddressEndpoint: localAddressEndpoint,
- outgoingNIC: outgoingNIC,
- Loop: loop,
+func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route {
+ r := &Route{
+ NetProto: netProto,
+ LocalAddress: localAddr,
+ LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
+ RemoteAddress: remoteAddr,
+ localAddressNIC: localAddressNIC,
+ outgoingNIC: outgoingNIC,
+ Loop: loop,
}
+ r.mu.Lock()
+ r.mu.localAddressEndpoint = localAddressEndpoint
+ r.mu.Unlock()
+
if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
@@ -159,7 +179,7 @@ func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr
// provided AssignableAddressEndpoint.
//
// A local route is a route to a destination that is local to the stack.
-func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route {
+func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) *Route {
loop := PacketLoop
// TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
// link endpoint level. We can remove this check once loopback interfaces
@@ -170,6 +190,14 @@ func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr
return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
}
+// RemoteLinkAddress returns the link-layer (MAC) address of the next hop in
+// the route.
+func (r *Route) RemoteLinkAddress() tcpip.LinkAddress {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.mu.remoteLinkAddress
+}
+
// NICID returns the id of the NIC from which this route originates.
func (r *Route) NICID() tcpip.NICID {
return r.outgoingNIC.ID()
@@ -231,7 +259,9 @@ func (r *Route) GSOMaxSize() uint32 {
// ResolveWith immediately resolves a route with the specified remote link
// address.
func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
- r.RemoteLinkAddress = addr
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.mu.remoteLinkAddress = addr
}
// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
@@ -244,7 +274,10 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
//
// The NIC r uses must not be locked.
func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
- if !r.IsResolutionRequired() {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if !r.isResolutionRequiredRLocked() {
// Nothing to do if there is no cache (which does the resolution on cache miss) or
// link address is already known.
return nil, nil
@@ -254,7 +287,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if nextAddr == "" {
// Local link address is already known.
if r.RemoteAddress == r.LocalAddress {
- r.RemoteLinkAddress = r.LocalLinkAddress
+ r.mu.remoteLinkAddress = r.LocalLinkAddress
return nil, nil
}
nextAddr = r.RemoteAddress
@@ -272,7 +305,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if err != nil {
return ch, err
}
- r.RemoteLinkAddress = entry.LinkAddr
+ r.mu.remoteLinkAddress = entry.LinkAddr
return nil, nil
}
@@ -280,7 +313,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if err != nil {
return ch, err
}
- r.RemoteLinkAddress = linkAddr
+ r.mu.remoteLinkAddress = linkAddr
return nil, nil
}
@@ -309,7 +342,13 @@ func (r *Route) local() bool {
//
// The NICs the route is associated with must not be locked.
func (r *Route) IsResolutionRequired() bool {
- if !r.isValidForOutgoing() || r.RemoteLinkAddress != "" || r.local() {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.isResolutionRequiredRLocked()
+}
+
+func (r *Route) isResolutionRequiredRLocked() bool {
+ if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() {
return false
}
@@ -317,11 +356,18 @@ func (r *Route) IsResolutionRequired() bool {
}
func (r *Route) isValidForOutgoing() bool {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.isValidForOutgoingRLocked()
+}
+
+func (r *Route) isValidForOutgoingRLocked() bool {
if !r.outgoingNIC.Enabled() {
return false
}
- if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) {
+ localAddressEndpoint := r.mu.localAddressEndpoint
+ if localAddressEndpoint == nil || !r.localAddressNIC.isValidForOutgoing(localAddressEndpoint) {
return false
}
@@ -375,37 +421,44 @@ func (r *Route) MTU() uint32 {
// Release frees all resources associated with the route.
func (r *Route) Release() {
- if r.localAddressEndpoint != nil {
- r.localAddressEndpoint.DecRef()
- r.localAddressEndpoint = nil
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.mu.localAddressEndpoint != nil {
+ r.mu.localAddressEndpoint.DecRef()
+ r.mu.localAddressEndpoint = nil
}
}
// Clone clones the route.
-func (r *Route) Clone() Route {
- if r.localAddressEndpoint != nil {
- if !r.localAddressEndpoint.IncRef() {
- panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress))
- }
+func (r *Route) Clone() *Route {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+
+ newRoute := &Route{
+ RemoteAddress: r.RemoteAddress,
+ LocalAddress: r.LocalAddress,
+ LocalLinkAddress: r.LocalLinkAddress,
+ NextHop: r.NextHop,
+ NetProto: r.NetProto,
+ Loop: r.Loop,
+ localAddressNIC: r.localAddressNIC,
+ outgoingNIC: r.outgoingNIC,
+ linkCache: r.linkCache,
+ linkRes: r.linkRes,
}
- return *r
-}
-// MakeLoopedRoute duplicates the given route with special handling for routes
-// used for sending multicast or broadcast packets. In those cases the
-// multicast/broadcast address is the remote address when sending out, but for
-// incoming (looped) packets it becomes the local address. Similarly, the local
-// interface address that was the local address going out becomes the remote
-// address coming in. This is different to unicast routes where local and
-// remote addresses remain the same as they identify location (local vs remote)
-// not direction (source vs destination).
-func (r *Route) MakeLoopedRoute() Route {
- l := r.Clone()
- if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) {
- l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress
- l.RemoteLinkAddress = l.LocalLinkAddress
+ newRoute.mu.Lock()
+ defer newRoute.mu.Unlock()
+ newRoute.mu.localAddressEndpoint = r.mu.localAddressEndpoint
+ if newRoute.mu.localAddressEndpoint != nil {
+ if !newRoute.mu.localAddressEndpoint.IncRef() {
+ panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", newRoute.LocalAddress))
+ }
}
- return l
+ newRoute.mu.remoteLinkAddress = r.mu.remoteLinkAddress
+
+ return newRoute
}
// Stack returns the instance of the Stack that owns this route.
@@ -418,7 +471,14 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
return true
}
- subnet := r.localAddressEndpoint.Subnet()
+ r.mu.RLock()
+ localAddressEndpoint := r.mu.localAddressEndpoint
+ r.mu.RUnlock()
+ if localAddressEndpoint == nil {
+ return false
+ }
+
+ subnet := localAddressEndpoint.Subnet()
return subnet.IsBroadcast(addr)
}
@@ -428,27 +488,3 @@ func (r *Route) IsOutboundBroadcast() bool {
// Only IPv4 has a notion of broadcast.
return r.isV4Broadcast(r.RemoteAddress)
}
-
-// isInboundBroadcast returns true if the route is for an inbound broadcast
-// packet.
-func (r *Route) isInboundBroadcast() bool {
- // Only IPv4 has a notion of broadcast.
- return r.isV4Broadcast(r.LocalAddress)
-}
-
-// ReverseRoute returns new route with given source and destination address.
-func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route {
- return Route{
- NetProto: r.NetProto,
- LocalAddress: dst,
- LocalLinkAddress: r.RemoteLinkAddress,
- RemoteAddress: src,
- RemoteLinkAddress: r.LocalLinkAddress,
- Loop: r.Loop,
- localAddressNIC: r.localAddressNIC,
- localAddressEndpoint: r.localAddressEndpoint,
- outgoingNIC: r.outgoingNIC,
- linkCache: r.linkCache,
- linkRes: r.linkRes,
- }
-}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index e0025e0a9..026d330c4 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -171,6 +171,9 @@ type TCPSenderState struct {
// Outstanding is the number of packets in flight.
Outstanding int
+ // SackedOut is the number of packets which have been selectively acked.
+ SackedOut int
+
// SndWnd is the send window size in bytes.
SndWnd seqnum.Size
@@ -1118,6 +1121,16 @@ func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber,
return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
}
+// AddAddressWithPrefix is the same as AddAddress, but allows you to specify
+// the address prefix.
+func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) *tcpip.Error {
+ ap := tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: addr,
+ }
+ return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint)
+}
+
// AddProtocolAddress adds a new network-layer protocol address to the
// specified NIC.
func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error {
@@ -1208,10 +1221,10 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP
// from the specified NIC.
//
// Precondition: s.mu must be read locked.
-func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route {
localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint)
if localAddressEndpoint == nil {
- return Route{}, false
+ return nil
}
var outgoingNIC *NIC
@@ -1235,12 +1248,12 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
// route.
if outgoingNIC == nil {
localAddressEndpoint.DecRef()
- return Route{}, false
+ return nil
}
r := makeLocalRoute(
netProto,
- localAddressEndpoint.AddressWithPrefix().Address,
+ localAddr,
remoteAddr,
outgoingNIC,
localAddressNIC,
@@ -1249,10 +1262,10 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
if r.IsOutboundBroadcast() {
r.Release()
- return Route{}, false
+ return nil
}
- return r, true
+ return r
}
// findLocalRouteRLocked returns a local route.
@@ -1261,26 +1274,26 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
// is, a local route is a route where packets never have to leave the stack.
//
// Precondition: s.mu must be read locked.
-func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route {
if len(localAddr) == 0 {
localAddr = remoteAddr
}
if localAddressNICID == 0 {
for _, localAddressNIC := range s.nics {
- if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok {
- return r, true
+ if r := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); r != nil {
+ return r
}
}
- return Route{}, false
+ return nil
}
if localAddressNIC, ok := s.nics[localAddressNICID]; ok {
return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto)
}
- return Route{}, false
+ return nil
}
// FindRoute creates a route to the given destination address, leaving through
@@ -1294,7 +1307,7 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr,
// If no local address is provided, the stack will select a local address. If no
// remote address is provided, the stack wil use a remote address equal to the
// local address.
-func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) {
+func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1305,7 +1318,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback)
if s.handleLocal && !isMulticast && !isLocalBroadcast {
- if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok {
+ if r := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); r != nil {
return r, nil
}
}
@@ -1317,7 +1330,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
return makeRoute(
netProto,
- addressEndpoint.AddressWithPrefix().Address,
+ localAddr,
remoteAddr,
nic, /* outboundNIC */
nic, /* localAddressNIC*/
@@ -1329,9 +1342,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
if isLoopback {
- return Route{}, tcpip.ErrBadLocalAddress
+ return nil, tcpip.ErrBadLocalAddress
}
- return Route{}, tcpip.ErrNetworkUnreachable
+ return nil, tcpip.ErrNetworkUnreachable
}
canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal
@@ -1354,8 +1367,8 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if needRoute {
gateway = route.Gateway
}
- r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop)
- if r == (Route{}) {
+ r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop)
+ if r == nil {
panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr))
}
return r, nil
@@ -1391,13 +1404,13 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if id != 0 {
if aNIC, ok := s.nics[id]; ok {
if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil {
return r, nil
}
}
}
- return Route{}, tcpip.ErrNoRoute
+ return nil, tcpip.ErrNoRoute
}
if id == 0 {
@@ -1409,7 +1422,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
continue
}
- if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil {
return r, nil
}
}
@@ -1417,12 +1430,12 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
if needRoute {
- return Route{}, tcpip.ErrNoRoute
+ return nil, tcpip.ErrNoRoute
}
if header.IsV6LoopbackAddress(remoteAddr) {
- return Route{}, tcpip.ErrBadLocalAddress
+ return nil, tcpip.ErrBadLocalAddress
}
- return Route{}, tcpip.ErrNetworkUnreachable
+ return nil, tcpip.ErrNetworkUnreachable
}
// CheckNetworkProtocol checks if a given network protocol is enabled in the
@@ -1810,49 +1823,20 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip
nic.unregisterPacketEndpoint(netProto, ep)
}
-// WritePacket writes data directly to the specified NIC. It adds an ethernet
-// header based on the arguments.
-func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
- s.mu.Lock()
- nic, ok := s.nics[nicID]
- s.mu.Unlock()
- if !ok {
- return tcpip.ErrUnknownDevice
- }
-
- // Add our own fake ethernet header.
- ethFields := header.EthernetFields{
- SrcAddr: nic.LinkEndpoint.LinkAddress(),
- DstAddr: dst,
- Type: netProto,
- }
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- vv := buffer.View(fakeHeader).ToVectorisedView()
- vv.Append(payload)
-
- if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil {
- return err
- }
-
- return nil
-}
-
-// WriteRawPacket writes data directly to the specified NIC without adding any
-// headers.
-func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error {
+// WritePacketToRemote writes a payload on the specified NIC using the provided
+// network protocol and remote link address.
+func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
s.mu.Lock()
nic, ok := s.nics[nicID]
s.mu.Unlock()
if !ok {
return tcpip.ErrUnknownDevice
}
-
- if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil {
- return err
- }
-
- return nil
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(nic.MaxHeaderLength()),
+ Data: payload,
+ })
+ return nic.WritePacketToRemote(remote, nil, netProto, pkt)
}
// NetworkProtocolInstance returns the protocol instance in the stack for the
@@ -1912,7 +1896,6 @@ func (s *Stack) RemoveTCPProbe() {
// JoinGroup joins the given multicast group on the given NIC.
func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
- // TODO: notify network of subscription via igmp protocol.
s.mu.RLock()
defer s.mu.RUnlock()
@@ -2159,3 +2142,43 @@ func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber {
}
return protos
}
+
+func isSubnetBroadcastOnNIC(nic *NIC, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ addressEndpoint := nic.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint)
+ if addressEndpoint == nil {
+ return false
+ }
+
+ subnet := addressEndpoint.Subnet()
+ addressEndpoint.DecRef()
+ return subnet.IsBroadcast(addr)
+}
+
+// IsSubnetBroadcast returns true if the provided address is a subnet-local
+// broadcast address on the specified NIC and protocol.
+//
+// Returns false if the NIC is unknown or if the protocol is unknown or does
+// not support addressing.
+//
+// If the NIC is not specified, the stack will check all NICs.
+func (s *Stack) IsSubnetBroadcast(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nicID != 0 {
+ nic, ok := s.nics[nicID]
+ if !ok {
+ return false
+ }
+
+ return isSubnetBroadcastOnNIC(nic, protocol, addr)
+ }
+
+ for _, nic := range s.nics {
+ if isSubnetBroadcastOnNIC(nic, protocol, addr) {
+ return true
+ }
+ }
+
+ return false
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 61db3164b..457990945 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -27,7 +27,6 @@ import (
"time"
"github.com/google/go-cmp/cmp"
- "github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -407,7 +406,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
return send(r, payload)
}
-func send(r stack.Route, payload buffer.View) *tcpip.Error {
+func send(r *stack.Route, payload buffer.View) *tcpip.Error {
return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: payload.ToVectorisedView(),
@@ -425,7 +424,7 @@ func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.En
}
}
-func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) {
+func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View) {
t.Helper()
ep.Drain()
if err := send(r, payload); err != nil {
@@ -436,7 +435,7 @@ func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.
}
}
-func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
t.Helper()
if gotErr := send(r, payload); gotErr != wantErr {
t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
@@ -1563,15 +1562,15 @@ func TestSpoofingNoAddress(t *testing.T) {
// testSendTo(t, s, remoteAddr, ep, nil)
}
-func verifyRoute(gotRoute, wantRoute stack.Route) error {
+func verifyRoute(gotRoute, wantRoute *stack.Route) error {
if gotRoute.LocalAddress != wantRoute.LocalAddress {
return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress)
}
if gotRoute.RemoteAddress != wantRoute.RemoteAddress {
return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress)
}
- if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress {
- return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress)
+ if got, want := gotRoute.RemoteLinkAddress(), wantRoute.RemoteLinkAddress(); got != want {
+ return fmt.Errorf("bad remote link address: got %s, want = %s", got, want)
}
if gotRoute.NextHop != wantRoute.NextHop {
return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop)
@@ -1603,7 +1602,7 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1657,7 +1656,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1667,7 +1666,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1683,7 +1682,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
}
}
@@ -2407,9 +2406,9 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
}
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: test.autoGen,
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: test.iidOpts,
+ AutoGenLinkLocal: test.autoGen,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: test.iidOpts,
})},
}
@@ -2502,8 +2501,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: true,
- OpaqueIIDOpts: test.opaqueIIDOpts,
+ AutoGenLinkLocal: true,
+ OpaqueIIDOpts: test.opaqueIIDOpts,
})},
}
@@ -2536,9 +2535,9 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
ndpConfigs := ipv6.DefaultNDPConfigurations()
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ndpConfigs,
- AutoGenIPv6LinkLocal: true,
- NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ AutoGenLinkLocal: true,
+ NDPDisp: &ndpDisp,
})},
}
@@ -3351,11 +3350,16 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
remNetSubnetBcast := remNetSubnet.Broadcast()
tests := []struct {
- name string
- nicAddr tcpip.ProtocolAddress
- routes []tcpip.Route
- remoteAddr tcpip.Address
- expectedRoute stack.Route
+ name string
+ nicAddr tcpip.ProtocolAddress
+ routes []tcpip.Route
+ remoteAddr tcpip.Address
+ expectedLocalAddress tcpip.Address
+ expectedRemoteAddress tcpip.Address
+ expectedRemoteLinkAddress tcpip.LinkAddress
+ expectedNextHop tcpip.Address
+ expectedNetProto tcpip.NetworkProtocolNumber
+ expectedLoop stack.PacketLooping
}{
// Broadcast to a locally attached subnet populates the broadcast MAC.
{
@@ -3370,14 +3374,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4SubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: ipv4SubnetBcast,
- RemoteLinkAddress: header.EthernetBroadcastAddress,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut | stack.PacketLoop,
- },
+ remoteAddr: ipv4SubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: ipv4SubnetBcast,
+ expectedRemoteLinkAddress: header.EthernetBroadcastAddress,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut | stack.PacketLoop,
},
// Broadcast to a locally attached /31 subnet does not populate the
// broadcast MAC.
@@ -3393,13 +3395,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4Subnet31Bcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4AddrPrefix31.Address,
- RemoteAddress: ipv4Subnet31Bcast,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv4Subnet31Bcast,
+ expectedLocalAddress: ipv4AddrPrefix31.Address,
+ expectedRemoteAddress: ipv4Subnet31Bcast,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to a locally attached /32 subnet does not populate the
// broadcast MAC.
@@ -3415,13 +3415,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4Subnet32Bcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4AddrPrefix32.Address,
- RemoteAddress: ipv4Subnet32Bcast,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv4Subnet32Bcast,
+ expectedLocalAddress: ipv4AddrPrefix32.Address,
+ expectedRemoteAddress: ipv4Subnet32Bcast,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// IPv6 has no notion of a broadcast.
{
@@ -3436,13 +3434,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv6SubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv6Addr.Address,
- RemoteAddress: ipv6SubnetBcast,
- NetProto: header.IPv6ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv6SubnetBcast,
+ expectedLocalAddress: ipv6Addr.Address,
+ expectedRemoteAddress: ipv6SubnetBcast,
+ expectedNetProto: header.IPv6ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to a remote subnet in the route table is send to the next-hop
// gateway.
@@ -3459,14 +3455,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: remNetSubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: remNetSubnetBcast,
- NextHop: ipv4Gateway,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: remNetSubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: remNetSubnetBcast,
+ expectedNextHop: ipv4Gateway,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to an unknown subnet follows the default route. Note that this
// is essentially just routing an unknown destination IP, because w/o any
@@ -3484,14 +3478,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: remNetSubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: remNetSubnetBcast,
- NextHop: ipv4Gateway,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: remNetSubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: remNetSubnetBcast,
+ expectedNextHop: ipv4Gateway,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
}
@@ -3520,10 +3512,27 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
t.Fatalf("got unexpected address length = %d bytes", l)
}
- if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil {
+ r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */)
+ if err != nil {
t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err)
- } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" {
- t.Errorf("route mismatch (-want +got):\n%s", diff)
+ }
+ if r.LocalAddress != test.expectedLocalAddress {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.expectedLocalAddress)
+ }
+ if r.RemoteAddress != test.expectedRemoteAddress {
+ t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.expectedRemoteAddress)
+ }
+ if got := r.RemoteLinkAddress(); got != test.expectedRemoteLinkAddress {
+ t.Errorf("got r.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddress)
+ }
+ if r.NextHop != test.expectedNextHop {
+ t.Errorf("got r.NextHop = %s, want = %s", r.NextHop, test.expectedNextHop)
+ }
+ if r.NetProto != test.expectedNetProto {
+ t.Errorf("got r.NetProto = %d, want = %d", r.NetProto, test.expectedNetProto)
+ }
+ if r.Loop != test.expectedLoop {
+ t.Errorf("got r.Loop = %x, want = %x", r.Loop, test.expectedLoop)
}
})
}
@@ -4091,10 +4100,12 @@ func TestFindRouteWithForwarding(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
+ if r != nil {
+ defer r.Release()
+ }
if err != test.findRouteErr {
t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr)
}
- defer r.Release()
if test.findRouteErr != nil {
return
@@ -4152,3 +4163,63 @@ func TestFindRouteWithForwarding(t *testing.T) {
})
}
}
+
+func TestWritePacketToRemote(t *testing.T) {
+ const nicID = 1
+ const MTU = 1280
+ e := channel.New(1, MTU, linkAddr1)
+ s := stack.New(stack.Options{})
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("CreateNIC(%d) = %s", nicID, err)
+ }
+ tests := []struct {
+ name string
+ protocol tcpip.NetworkProtocolNumber
+ payload []byte
+ }{
+ {
+ name: "SuccessIPv4",
+ protocol: header.IPv4ProtocolNumber,
+ payload: []byte{1, 2, 3, 4},
+ },
+ {
+ name: "SuccessIPv6",
+ protocol: header.IPv6ProtocolNumber,
+ payload: []byte{5, 6, 7, 8},
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if err := s.WritePacketToRemote(nicID, linkAddr2, test.protocol, buffer.View(test.payload).ToVectorisedView()); err != nil {
+ t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s", err)
+ }
+
+ pkt, ok := e.Read()
+ if got, want := ok, true; got != want {
+ t.Fatalf("e.Read() = %t, want %t", got, want)
+ }
+ if got, want := pkt.Proto, test.protocol; got != want {
+ t.Fatalf("pkt.Proto = %d, want %d", got, want)
+ }
+ if got, want := pkt.Route.RemoteLinkAddress(), linkAddr2; got != want {
+ t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", got, want)
+ }
+ if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" {
+ t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+
+ t.Run("InvalidNICID", func(t *testing.T) {
+ if got, want := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()), tcpip.ErrUnknownDevice; got != want {
+ t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", got, want)
+ }
+ pkt, ok := e.Read()
+ if got, want := ok, false; got != want {
+ t.Fatalf("e.Read() = %t, %v; want %t", got, pkt, want)
+ }
+ })
+}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 41a8e5ad0..a692af20b 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -141,11 +141,11 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: testSrcAddrV6,
- DstAddr: testDstAddrV6,
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: testSrcAddrV6,
+ DstAddr: testDstAddrV6,
})
// Initialize the UDP header.
@@ -307,9 +307,7 @@ func TestBindToDeviceDistribution(t *testing.T) {
}(ep)
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil {
- t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err)
- }
+ ep.SocketOptions().SetReusePort(endpoint.reuse)
bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
if err := ep.SetSockOpt(&bindToDeviceOption); err != nil {
t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err)
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 5b9043d85..66eb562ba 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -38,14 +38,15 @@ const (
// use it.
type fakeTransportEndpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
proto *fakeTransportProtocol
peerAddr tcpip.Address
- route stack.Route
+ route *stack.Route
uniqueID uint64
// acceptQueue is non-nil iff bound.
- acceptQueue []fakeTransportEndpoint
+ acceptQueue []*fakeTransportEndpoint
// ops is used to set and get socket options.
ops tcpip.SocketOptions
@@ -64,8 +65,11 @@ func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions {
return &f.ops
}
+
func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
- return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+ ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+ ep.ops.InitHandler(ep)
+ return ep
}
func (f *fakeTransportEndpoint) Abort() {
@@ -105,8 +109,8 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
return int64(len(v)), nil, nil
}
-func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
+func (*fakeTransportEndpoint) Peek([][]byte) (int64, *tcpip.Error) {
+ return 0, nil
}
// SetSockOpt sets a socket option. Currently not supported.
@@ -114,21 +118,11 @@ func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Erro
return tcpip.ErrInvalidEndpointState
}
-// SetSockOptBool sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
-}
-
// SetSockOptInt sets a socket option. Currently not supported.
func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrUnknownProtocolOption
-}
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
@@ -189,7 +183,7 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai
if len(f.acceptQueue) == 0 {
return nil, nil, nil
}
- a := &f.acceptQueue[0]
+ a := f.acceptQueue[0]
f.acceptQueue = f.acceptQueue[1:]
return a, nil, nil
}
@@ -206,7 +200,7 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
); err != nil {
return err
}
- f.acceptQueue = []fakeTransportEndpoint{}
+ f.acceptQueue = []*fakeTransportEndpoint{}
return nil
}
@@ -232,7 +226,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
}
route.ResolveWith(pkt.SourceLinkAddress())
- f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
+ ep := &fakeTransportEndpoint{
TransportEndpointInfo: stack.TransportEndpointInfo{
ID: f.ID,
NetProto: f.NetProto,
@@ -240,7 +234,9 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
proto: f.proto,
peerAddr: route.RemoteAddress,
route: route,
- })
+ }
+ ep.ops.InitHandler(ep)
+ f.acceptQueue = append(f.acceptQueue, ep)
}
func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) {
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index f9e83dd1c..45fa62720 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -49,8 +49,9 @@ const ipv4AddressSize = 4
// Error represents an error in the netstack error space. Using a special type
// ensures that errors outside of this space are not accidentally introduced.
//
-// Note: to support save / restore, it is important that all tcpip errors have
-// distinct error messages.
+// All errors must have unique msg strings.
+//
+// +stateify savable
type Error struct {
msg string
@@ -247,6 +248,16 @@ func (a Address) WithPrefix() AddressWithPrefix {
}
}
+// Unspecified returns true if the address is unspecified.
+func (a Address) Unspecified() bool {
+ for _, b := range a {
+ if b != 0 {
+ return false
+ }
+ }
+ return true
+}
+
// AddressMask is a bitmask for an address.
type AddressMask string
@@ -481,6 +492,14 @@ type ControlMessages struct {
// PacketInfo holds interface and address data on an incoming packet.
PacketInfo IPPacketInfo
+
+ // HasOriginalDestinationAddress indicates whether OriginalDstAddress is
+ // set.
+ HasOriginalDstAddress bool
+
+ // OriginalDestinationAddress holds the original destination address
+ // and port of the incoming packet.
+ OriginalDstAddress FullAddress
}
// PacketOwner is used to get UID and GID of the packet.
@@ -535,7 +554,7 @@ type Endpoint interface {
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
- Peek([][]byte) (int64, ControlMessages, *Error)
+ Peek([][]byte) (int64, *Error)
// Connect connects the endpoint to its peer. Specifying a NIC is
// optional.
@@ -593,10 +612,6 @@ type Endpoint interface {
// SetSockOpt sets a socket option.
SetSockOpt(opt SettableSocketOption) *Error
- // SetSockOptBool sets a socket option, for simple cases where a value
- // has the bool type.
- SetSockOptBool(opt SockOptBool, v bool) *Error
-
// SetSockOptInt sets a socket option, for simple cases where a value
// has the int type.
SetSockOptInt(opt SockOptInt, v int) *Error
@@ -604,10 +619,6 @@ type Endpoint interface {
// GetSockOpt gets a socket option.
GetSockOpt(opt GettableSocketOption) *Error
- // GetSockOptBool gets a socket option for simple cases where a return
- // value has the bool type.
- GetSockOptBool(SockOptBool) (bool, *Error)
-
// GetSockOptInt gets a socket option for simple cases where a return
// value has the int type.
GetSockOptInt(SockOptInt) (int, *Error)
@@ -694,79 +705,6 @@ type WriteOptions struct {
Atomic bool
}
-// SockOptBool represents socket options which values have the bool type.
-type SockOptBool int
-
-const (
- // CorkOption is used by SetSockOptBool/GetSockOptBool to specify if
- // data should be held until segments are full by the TCP transport
- // protocol.
- CorkOption SockOptBool = iota
-
- // DelayOption is used by SetSockOptBool/GetSockOptBool to specify if
- // data should be sent out immediately by the transport protocol. For
- // TCP, it determines if the Nagle algorithm is on or off.
- DelayOption
-
- // KeepaliveEnabledOption is used by SetSockOptBool/GetSockOptBool to
- // specify whether TCP keepalive is enabled for this socket.
- KeepaliveEnabledOption
-
- // MulticastLoopOption is used by SetSockOptBool/GetSockOptBool to
- // specify whether multicast packets sent over a non-loopback interface
- // will be looped back.
- MulticastLoopOption
-
- // NoChecksumOption is used by SetSockOptBool/GetSockOptBool to specify
- // whether UDP checksum is disabled for this socket.
- NoChecksumOption
-
- // PasscredOption is used by SetSockOptBool/GetSockOptBool to specify
- // whether SCM_CREDENTIALS socket control messages are enabled.
- //
- // Only supported on Unix sockets.
- PasscredOption
-
- // QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool.
- QuickAckOption
-
- // ReceiveTClassOption is used by SetSockOptBool/GetSockOptBool to
- // specify if the IPV6_TCLASS ancillary message is passed with incoming
- // packets.
- ReceiveTClassOption
-
- // ReceiveTOSOption is used by SetSockOptBool/GetSockOptBool to specify
- // if the TOS ancillary message is passed with incoming packets.
- ReceiveTOSOption
-
- // ReceiveIPPacketInfoOption is used by SetSockOptBool/GetSockOptBool to
- // specify if more inforamtion is provided with incoming packets such as
- // interface index and address.
- ReceiveIPPacketInfoOption
-
- // ReuseAddressOption is used by SetSockOptBool/GetSockOptBool to
- // specify whether Bind() should allow reuse of local address.
- ReuseAddressOption
-
- // ReusePortOption is used by SetSockOptBool/GetSockOptBool to permit
- // multiple sockets to be bound to an identical socket address.
- ReusePortOption
-
- // V6OnlyOption is used by SetSockOptBool/GetSockOptBool to specify
- // whether an IPv6 socket is to be restricted to sending and receiving
- // IPv6 packets only.
- V6OnlyOption
-
- // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw
- // endpoint that all packets being written have an IP header and the
- // endpoint should not attach an IP header.
- IPHdrIncludedOption
-
- // AcceptConnOption is used by GetSockOptBool to indicate if the
- // socket is a listening socket.
- AcceptConnOption
-)
-
// SockOptInt represents socket options which values have the int type.
type SockOptInt int
@@ -1158,14 +1096,6 @@ type RemoveMembershipOption MembershipOption
func (*RemoveMembershipOption) isSettableSocketOption() {}
-// OutOfBandInlineOption is used by SetSockOpt/GetSockOpt to specify whether
-// TCP out-of-band data is delivered along with the normal in-band data.
-type OutOfBandInlineOption int
-
-func (*OutOfBandInlineOption) isGettableSocketOption() {}
-
-func (*OutOfBandInlineOption) isSettableSocketOption() {}
-
// SocketDetachFilterOption is used by SetSockOpt to detach a previously attached
// classic BPF filter on a given endpoint.
type SocketDetachFilterOption int
@@ -1215,10 +1145,6 @@ type LingerOption struct {
Timeout time.Duration
}
-func (*LingerOption) isGettableSocketOption() {}
-
-func (*LingerOption) isSettableSocketOption() {}
-
// IPPacketInfo is the message structure for IP_PKTINFO.
//
// +stateify savable
@@ -1389,6 +1315,18 @@ type ICMPv6PacketStats struct {
// RedirectMsg is the total number of ICMPv6 redirect message packets
// counted.
RedirectMsg *StatCounter
+
+ // MulticastListenerQuery is the total number of Multicast Listener Query
+ // messages counted.
+ MulticastListenerQuery *StatCounter
+
+ // MulticastListenerReport is the total number of Multicast Listener Report
+ // messages counted.
+ MulticastListenerReport *StatCounter
+
+ // MulticastListenerDone is the total number of Multicast Listener Done
+ // messages counted.
+ MulticastListenerDone *StatCounter
}
// ICMPv4SentPacketStats collects outbound ICMPv4-specific stats.
@@ -1430,6 +1368,10 @@ type ICMPv6SentPacketStats struct {
type ICMPv6ReceivedPacketStats struct {
ICMPv6PacketStats
+ // Unrecognized is the total number of ICMPv6 packets received that the
+ // transport layer does not know how to parse.
+ Unrecognized *StatCounter
+
// Invalid is the total number of ICMPv6 packets received that the
// transport layer could not parse.
Invalid *StatCounter
@@ -1439,25 +1381,90 @@ type ICMPv6ReceivedPacketStats struct {
RouterOnlyPacketsDroppedByHost *StatCounter
}
-// ICMPStats collects ICMP-specific stats (both v4 and v6).
-type ICMPStats struct {
+// ICMPv4Stats collects ICMPv4-specific stats.
+type ICMPv4Stats struct {
// ICMPv4SentPacketStats contains counts of sent packets by ICMPv4 packet type
// and a single count of packets which failed to write to the link
// layer.
- V4PacketsSent ICMPv4SentPacketStats
+ PacketsSent ICMPv4SentPacketStats
// ICMPv4ReceivedPacketStats contains counts of received packets by ICMPv4
// packet type and a single count of invalid packets received.
- V4PacketsReceived ICMPv4ReceivedPacketStats
+ PacketsReceived ICMPv4ReceivedPacketStats
+}
+// ICMPv6Stats collects ICMPv6-specific stats.
+type ICMPv6Stats struct {
// ICMPv6SentPacketStats contains counts of sent packets by ICMPv6 packet type
// and a single count of packets which failed to write to the link
// layer.
- V6PacketsSent ICMPv6SentPacketStats
+ PacketsSent ICMPv6SentPacketStats
// ICMPv6ReceivedPacketStats contains counts of received packets by ICMPv6
// packet type and a single count of invalid packets received.
- V6PacketsReceived ICMPv6ReceivedPacketStats
+ PacketsReceived ICMPv6ReceivedPacketStats
+}
+
+// ICMPStats collects ICMP-specific stats (both v4 and v6).
+type ICMPStats struct {
+ // V4 contains the ICMPv4-specifics stats.
+ V4 ICMPv4Stats
+
+ // V6 contains the ICMPv4-specifics stats.
+ V6 ICMPv6Stats
+}
+
+// IGMPPacketStats enumerates counts for all IGMP packet types.
+type IGMPPacketStats struct {
+ // MembershipQuery is the total number of Membership Query messages counted.
+ MembershipQuery *StatCounter
+
+ // V1MembershipReport is the total number of Version 1 Membership Report
+ // messages counted.
+ V1MembershipReport *StatCounter
+
+ // V2MembershipReport is the total number of Version 2 Membership Report
+ // messages counted.
+ V2MembershipReport *StatCounter
+
+ // LeaveGroup is the total number of Leave Group messages counted.
+ LeaveGroup *StatCounter
+}
+
+// IGMPSentPacketStats collects outbound IGMP-specific stats.
+type IGMPSentPacketStats struct {
+ IGMPPacketStats
+
+ // Dropped is the total number of IGMP packets dropped.
+ Dropped *StatCounter
+}
+
+// IGMPReceivedPacketStats collects inbound IGMP-specific stats.
+type IGMPReceivedPacketStats struct {
+ IGMPPacketStats
+
+ // Invalid is the total number of IGMP packets received that IGMP could not
+ // parse.
+ Invalid *StatCounter
+
+ // ChecksumErrors is the total number of IGMP packets dropped due to bad
+ // checksums.
+ ChecksumErrors *StatCounter
+
+ // Unrecognized is the total number of unrecognized messages counted, these
+ // are silently ignored for forward-compatibilty.
+ Unrecognized *StatCounter
+}
+
+// IGMPStats colelcts IGMP-specific stats.
+type IGMPStats struct {
+ // IGMPSentPacketStats contains counts of sent packets by IGMP packet type
+ // and a single count of invalid packets received.
+ PacketsSent IGMPSentPacketStats
+
+ // IGMPReceivedPacketStats contains counts of received packets by IGMP packet
+ // type and a single count of invalid packets received.
+ PacketsReceived IGMPReceivedPacketStats
}
// IPStats collects IP-specific stats (both v4 and v6).
@@ -1665,6 +1672,9 @@ type Stats struct {
// ICMP breaks out ICMP-specific stats (both v4 and v6).
ICMP ICMPStats
+ // IGMP breaks out IGMP-specific stats.
+ IGMP IGMPStats
+
// IP breaks out IP-specific stats (both v4 and v6).
IP IPStats
diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go
index 1c8e2bc34..c461da137 100644
--- a/pkg/tcpip/tcpip_test.go
+++ b/pkg/tcpip/tcpip_test.go
@@ -226,3 +226,47 @@ func TestAddressWithPrefixSubnet(t *testing.T) {
}
}
}
+
+func TestAddressUnspecified(t *testing.T) {
+ tests := []struct {
+ addr Address
+ unspecified bool
+ }{
+ {
+ addr: "",
+ unspecified: true,
+ },
+ {
+ addr: "\x00",
+ unspecified: true,
+ },
+ {
+ addr: "\x01",
+ unspecified: false,
+ },
+ {
+ addr: "\x00\x00",
+ unspecified: true,
+ },
+ {
+ addr: "\x01\x00",
+ unspecified: false,
+ },
+ {
+ addr: "\x00\x01",
+ unspecified: false,
+ },
+ {
+ addr: "\x01\x01",
+ unspecified: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(fmt.Sprintf("addr=%s", test.addr), func(t *testing.T) {
+ if got := test.addr.Unspecified(); got != test.unspecified {
+ t.Fatalf("got addr.Unspecified() = %t, want = %t", got, test.unspecified)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 9b0f3b675..800025fb9 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -25,6 +25,7 @@ go_test(
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 421da1add..baaa741cd 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -70,8 +71,8 @@ func TestInitialLoopbackAddresses(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPDisp: &ndpDispatcher{},
- AutoGenIPv6LinkLocal: true,
+ NDPDisp: &ndpDispatcher{},
+ AutoGenLinkLocal: true,
OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: func(nicID tcpip.NICID, nicName string) string {
t.Fatalf("should not attempt to get name for NIC with ID = %d; nicName = %s", nicID, nicName)
@@ -93,9 +94,10 @@ func TestInitialLoopbackAddresses(t *testing.T) {
}
}
-// TestLoopbackAcceptAllInSubnet tests that a loopback interface considers
-// itself bound to all addresses in the subnet of an assigned address.
-func TestLoopbackAcceptAllInSubnet(t *testing.T) {
+// TestLoopbackAcceptAllInSubnetUDP tests that a loopback interface considers
+// itself bound to all addresses in the subnet of an assigned address and UDP
+// traffic is sent/received correctly.
+func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
const (
nicID = 1
localPort = 80
@@ -107,7 +109,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
Protocol: header.IPv4ProtocolNumber,
AddressWithPrefix: ipv4Addr,
}
- ipv4Bytes := []byte(ipv4Addr.Address)
+ ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address)
ipv4Bytes[len(ipv4Bytes)-1]++
otherIPv4Address := tcpip.Address(ipv4Bytes)
@@ -129,7 +131,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
{
name: "IPv4 bind to wildcard and send to assigned address",
addAddress: ipv4ProtocolAddress,
- dstAddr: ipv4Addr.Address,
+ dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
expectRx: true,
},
{
@@ -148,7 +150,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
name: "IPv4 bind to other subnet-local address and send to assigned address",
addAddress: ipv4ProtocolAddress,
bindAddr: otherIPv4Address,
- dstAddr: ipv4Addr.Address,
+ dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
expectRx: false,
},
{
@@ -161,7 +163,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
{
name: "IPv4 bind to assigned address and send to other subnet-local address",
addAddress: ipv4ProtocolAddress,
- bindAddr: ipv4Addr.Address,
+ bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
dstAddr: otherIPv4Address,
expectRx: false,
},
@@ -236,13 +238,17 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
t.Fatalf("got sep.Write(_, _) = (%d, _, nil), want = (%d, _, nil)", n, want)
}
- if gotPayload, _, err := rep.Read(nil); test.expectRx {
+ var addr tcpip.FullAddress
+ if gotPayload, _, err := rep.Read(&addr); test.expectRx {
if err != nil {
- t.Fatalf("reep.Read(nil): %s", err)
+ t.Fatalf("reep.Read(_): %s", err)
}
if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
}
+ if addr.Addr != test.addAddress.AddressWithPrefix.Address {
+ t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.addAddress.AddressWithPrefix.Address)
+ }
} else {
if err != tcpip.ErrWouldBlock {
t.Fatalf("got rep.Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
@@ -312,3 +318,168 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState)
}
}
+
+// TestLoopbackAcceptAllInSubnetTCP tests that a loopback interface considers
+// itself bound to all addresses in the subnet of an assigned address and TCP
+// traffic is sent/received correctly.
+func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) {
+ const (
+ nicID = 1
+ localPort = 80
+ )
+
+ ipv4ProtocolAddress := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ }
+ ipv4ProtocolAddress.AddressWithPrefix.PrefixLen = 8
+ ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address)
+ ipv4Bytes[len(ipv4Bytes)-1]++
+ otherIPv4Address := tcpip.Address(ipv4Bytes)
+
+ ipv6ProtocolAddress := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: ipv6Addr,
+ }
+ ipv6Bytes := []byte(ipv6Addr.Address)
+ ipv6Bytes[len(ipv6Bytes)-1]++
+ otherIPv6Address := tcpip.Address(ipv6Bytes)
+
+ tests := []struct {
+ name string
+ addAddress tcpip.ProtocolAddress
+ bindAddr tcpip.Address
+ dstAddr tcpip.Address
+ expectAccept bool
+ }{
+ {
+ name: "IPv4 bind to wildcard and send to assigned address",
+ addAddress: ipv4ProtocolAddress,
+ dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
+ expectAccept: true,
+ },
+ {
+ name: "IPv4 bind to wildcard and send to other subnet-local address",
+ addAddress: ipv4ProtocolAddress,
+ dstAddr: otherIPv4Address,
+ expectAccept: true,
+ },
+ {
+ name: "IPv4 bind to wildcard send to other address",
+ addAddress: ipv4ProtocolAddress,
+ dstAddr: remoteIPv4Addr,
+ expectAccept: false,
+ },
+ {
+ name: "IPv4 bind to other subnet-local address and send to assigned address",
+ addAddress: ipv4ProtocolAddress,
+ bindAddr: otherIPv4Address,
+ dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
+ expectAccept: false,
+ },
+ {
+ name: "IPv4 bind and send to other subnet-local address",
+ addAddress: ipv4ProtocolAddress,
+ bindAddr: otherIPv4Address,
+ dstAddr: otherIPv4Address,
+ expectAccept: true,
+ },
+ {
+ name: "IPv4 bind to assigned address and send to other subnet-local address",
+ addAddress: ipv4ProtocolAddress,
+ bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
+ dstAddr: otherIPv4Address,
+ expectAccept: false,
+ },
+
+ {
+ name: "IPv6 bind and send to assigned address",
+ addAddress: ipv6ProtocolAddress,
+ bindAddr: ipv6Addr.Address,
+ dstAddr: ipv6Addr.Address,
+ expectAccept: true,
+ },
+ {
+ name: "IPv6 bind to wildcard and send to other subnet-local address",
+ addAddress: ipv6ProtocolAddress,
+ dstAddr: otherIPv6Address,
+ expectAccept: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, test.addAddress, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ listeningEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err)
+ }
+ defer listeningEndpoint.Close()
+
+ bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort}
+ if err := listeningEndpoint.Bind(bindAddr); err != nil {
+ t.Fatalf("listeningEndpoint.Bind(%#v): %s", bindAddr, err)
+ }
+
+ if err := listeningEndpoint.Listen(1); err != nil {
+ t.Fatalf("listeningEndpoint.Listen(1): %s", err)
+ }
+
+ connectingEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err)
+ }
+ defer connectingEndpoint.Close()
+
+ connectAddr := tcpip.FullAddress{
+ Addr: test.dstAddr,
+ Port: localPort,
+ }
+ if err := connectingEndpoint.Connect(connectAddr); err != tcpip.ErrConnectStarted {
+ t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err)
+ }
+
+ if !test.expectAccept {
+ if _, _, err := listeningEndpoint.Accept(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+ return
+ }
+
+ // Wait for the listening endpoint to be "readable". That is, wait for a
+ // new connection.
+ <-ch
+ var addr tcpip.FullAddress
+ if _, _, err := listeningEndpoint.Accept(&addr); err != nil {
+ t.Fatalf("listeningEndpoint.Accept(nil): %s", err)
+ }
+ if addr.Addr != test.addAddress.AddressWithPrefix.Address {
+ t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.addAddress.AddressWithPrefix.Address)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 9d30329f5..2e59f6a42 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -96,11 +96,11 @@ func TestPingMulticastBroadcast(t *testing.T) {
pkt.SetChecksum(header.ICMPv6Checksum(pkt, remoteIPv6Addr, dst, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: ttl,
- SrcAddr: remoteIPv6Addr,
- DstAddr: dst,
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: ttl,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: dst,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -272,11 +272,11 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLen),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: ttl,
- SrcAddr: remoteIPv6Addr,
- DstAddr: dst,
+ PayloadLength: uint16(payloadLen),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: ttl,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: dst,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -510,10 +510,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err)
- }
-
+ ep.SocketOptions().SetReuseAddress(true)
ep.SocketOptions().SetBroadcast(true)
bindAddr := tcpip.FullAddress{Port: localPort}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 440cb0352..74fe19e98 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -49,6 +49,7 @@ const (
// +stateify savable
type endpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
@@ -71,11 +72,9 @@ type endpoint struct {
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
state endpointState
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
ttl uint8
stats tcpip.TransportEndpointStats `state:"nosave"`
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
@@ -85,7 +84,7 @@ type endpoint struct {
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return &endpoint{
+ ep := &endpoint{
stack: s,
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
@@ -96,7 +95,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
sndBufSize: 32 * 1024,
state: stateInitial,
uniqueID: s.UniqueID(),
- }, nil
+ }
+ ep.ops.InitHandler(ep)
+ return ep, nil
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
@@ -129,7 +130,10 @@ func (e *endpoint) Close() {
}
e.rcvMu.Unlock()
- e.route.Release()
+ if e.route != nil {
+ e.route.Release()
+ e.route = nil
+ }
// Update the state.
e.state = stateClosed
@@ -142,6 +146,7 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
+// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
@@ -267,26 +272,8 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
}
- var route *stack.Route
- if to == nil {
- route = &e.route
-
- if route.IsResolutionRequired() {
- // Promote lock to exclusive if using a shared route,
- // given that it may need to change in Route.Resolve()
- // call below.
- e.mu.RUnlock()
- defer e.mu.RLock()
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- // Recheck state after lock was re-acquired.
- if e.state != stateConnected {
- return 0, nil, tcpip.ErrInvalidEndpointState
- }
- }
- } else {
+ route := e.route
+ if to != nil {
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicID := to.NIC
@@ -310,7 +297,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
defer r.Release()
- route = &r
+ route = r
}
if route.IsResolutionRequired() {
@@ -343,26 +330,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
+func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
+ return 0, nil
}
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
- case *tcpip.SocketDetachFilterOption:
- return nil
-
- case *tcpip.LingerOption:
- e.mu.Lock()
- e.linger = *v
- e.mu.Unlock()
- }
- return nil
-}
-
-// SetSockOptBool sets a socket option. Currently not supported.
-func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
return nil
}
@@ -378,17 +351,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
- return false, nil
-
- default:
- return false, tcpip.ErrUnknownProtocolOption
- }
-}
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
@@ -426,16 +388,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- e.mu.Lock()
- *o = e.linger
- e.mu.Unlock()
- return nil
-
- default:
- return tcpip.ErrUnknownProtocolOption
- }
+ return tcpip.ErrUnknownProtocolOption
}
func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error {
@@ -857,6 +810,7 @@ func (*endpoint) LastError() *tcpip.Error {
return nil
}
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 3bff3755a..9faab4b9e 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -60,6 +60,8 @@ type packet struct {
// +stateify savable
type endpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
+
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
@@ -83,8 +85,6 @@ type endpoint struct {
stats tcpip.TransportEndpointStats `state:"nosave"`
bound bool
boundNIC tcpip.NICID
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
// lastErrorMu protects lastError.
lastErrorMu sync.Mutex `state:"nosave"`
@@ -107,6 +107,7 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
}
+ ep.ops.InitHandler(ep)
// Override with stack defaults.
var ss stack.SendBufferSizeOption
@@ -203,8 +204,8 @@ func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-cha
}
// Peek implements tcpip.Endpoint.Peek.
-func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
+func (*endpoint) Peek([][]byte) (int64, *tcpip.Error) {
+ return 0, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
@@ -303,26 +304,15 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// used with SetSockOpt, and this function always returns
// tcpip.ErrNotSupported.
func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
+ switch opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
- case *tcpip.LingerOption:
- ep.mu.Lock()
- ep.linger = *v
- ep.mu.Unlock()
- return nil
-
default:
return tcpip.ErrUnknownProtocolOption
}
}
-// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
-func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
-}
-
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
switch opt {
@@ -378,26 +368,7 @@ func (ep *endpoint) LastError() *tcpip.Error {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- ep.mu.Lock()
- *o = ep.linger
- ep.mu.Unlock()
- return nil
-
- default:
- return tcpip.ErrNotSupported
- }
-}
-
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.AcceptConnOption:
- return false, nil
- default:
- return false, tcpip.ErrNotSupported
- }
+ return tcpip.ErrNotSupported
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -551,8 +522,10 @@ func (ep *endpoint) Stats() tcpip.EndpointStats {
return &ep.stats
}
+// SetOwner implements tcpip.Endpoint.SetOwner.
func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {}
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
return &ep.ops
}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 4ae1f92ab..87c60bdab 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -58,12 +58,13 @@ type rawPacket struct {
// +stateify savable
type endpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
+
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
associated bool
- hdrIncluded bool
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
@@ -82,10 +83,8 @@ type endpoint struct {
bound bool
// route is the route to a remote network endpoint. It is set via
// Connect(), and is valid only when conneted is true.
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
stats tcpip.TransportEndpointStats `state:"nosave"`
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
@@ -114,8 +113,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
rcvBufSizeMax: 32 * 1024,
sndBufSizeMax: 32 * 1024,
associated: associated,
- hdrIncluded: !associated,
}
+ e.ops.InitHandler(e)
+ e.ops.SetHeaderIncluded(!associated)
// Override with stack defaults.
var ss stack.SendBufferSizeOption
@@ -170,9 +170,11 @@ func (e *endpoint) Close() {
e.rcvList.Remove(e.rcvList.Front())
}
- if e.connected {
+ e.connected = false
+
+ if e.route != nil {
e.route.Release()
- e.connected = false
+ e.route = nil
}
e.closed = true
@@ -223,6 +225,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrInvalidOptionValue
}
+ if opts.To != nil {
+ // Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint.
+ if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+ }
+
n, ch, err := e.write(p, opts)
switch err {
case nil:
@@ -266,7 +275,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If this is an unassociated socket and callee provided a nonzero
// destination address, route using that address.
- if e.hdrIncluded {
+ if e.ops.GetHeaderIncluded() {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
e.mu.RUnlock()
@@ -296,7 +305,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
if e.route.IsResolutionRequired() {
- savedRoute := &e.route
+ savedRoute := e.route
// Promote lock to exclusive if using a shared route,
// given that it may need to change in finishWrite.
e.mu.RUnlock()
@@ -304,7 +313,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// Make sure that the route didn't change during the
// time we didn't hold the lock.
- if !e.connected || savedRoute != &e.route {
+ if !e.connected || savedRoute != e.route {
e.mu.Unlock()
return 0, nil, tcpip.ErrInvalidEndpointState
}
@@ -314,7 +323,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return n, ch, err
}
- n, ch, err := e.finishWrite(payloadBytes, &e.route)
+ n, ch, err := e.finishWrite(payloadBytes, e.route)
e.mu.RUnlock()
return n, ch, err
}
@@ -335,7 +344,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, err
}
- n, ch, err := e.finishWrite(payloadBytes, &route)
+ n, ch, err := e.finishWrite(payloadBytes, route)
route.Release()
e.mu.RUnlock()
return n, ch, err
@@ -356,7 +365,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
}
- if e.hdrIncluded {
+ if e.ops.GetHeaderIncluded() {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.View(payloadBytes).ToVectorisedView(),
})
@@ -382,8 +391,8 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
// Peek implements tcpip.Endpoint.Peek.
-func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
+func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
+ return 0, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect.
@@ -393,6 +402,11 @@ func (*endpoint) Disconnect() *tcpip.Error {
// Connect implements tcpip.Endpoint.Connect.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ // Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint.
+ if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize {
+ return tcpip.ErrAddressFamilyNotSupported
+ }
+
e.mu.Lock()
defer e.mu.Unlock()
@@ -516,33 +530,15 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
+ switch opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
- case *tcpip.LingerOption:
- e.mu.Lock()
- e.linger = *v
- e.mu.Unlock()
- return nil
-
default:
return tcpip.ErrUnknownProtocolOption
}
}
-// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
-func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- switch opt {
- case tcpip.IPHdrIncludedOption:
- e.mu.Lock()
- e.hdrIncluded = v
- e.mu.Unlock()
- return nil
- }
- return tcpip.ErrUnknownProtocolOption
-}
-
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
switch opt {
@@ -589,33 +585,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- e.mu.Lock()
- *o = e.linger
- e.mu.Unlock()
- return nil
-
- default:
- return tcpip.ErrUnknownProtocolOption
- }
-}
-
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
- return false, nil
-
- case tcpip.IPHdrIncludedOption:
- e.mu.Lock()
- v := e.hdrIncluded
- e.mu.Unlock()
- return v, nil
-
- default:
- return false, tcpip.ErrUnknownProtocolOption
- }
+ return tcpip.ErrUnknownProtocolOption
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -756,10 +726,12 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
// Wait implements stack.TransportEndpoint.Wait.
func (*endpoint) Wait() {}
+// LastError implements tcpip.Endpoint.LastError.
func (*endpoint) LastError() *tcpip.Error {
return nil
}
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 7d97cbdc7..4a7e1c039 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -73,7 +73,13 @@ func (e *endpoint) Resume(s *stack.Stack) {
// If the endpoint is connected, re-connect.
if e.connected {
var err *tcpip.Error
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.route.RemoteAddress, e.NetProto, false)
+ // TODO(gvisor.dev/issue/4906): Properly restore the route with the right
+ // remote address. We used to pass e.remote.RemoteAddress which was
+ // effectively the empty address but since moving e.route to hold a pointer
+ // to a route instead of the route by value, we pass the empty address
+ // directly. Obviously this was always wrong since we should provide the
+ // remote address we were connected to, to properly restore the route.
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, "", e.NetProto, false)
if err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 518449602..cf232b508 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test", "more_shards")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -45,7 +45,9 @@ go_library(
"rcv.go",
"rcv_state.go",
"reno.go",
+ "reno_recovery.go",
"sack.go",
+ "sack_recovery.go",
"sack_scoreboard.go",
"segment.go",
"segment_heap.go",
@@ -91,7 +93,7 @@ go_test(
"tcp_test.go",
"tcp_timestamp_test.go",
],
- shard_count = 10,
+ shard_count = more_shards,
deps = [
":tcp",
"//pkg/rand",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 6e5adc383..3e1041cbe 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -213,7 +213,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
route.ResolveWith(s.remoteLinkAddr)
n := newEndpoint(l.stack, netProto, queue)
- n.v6only = l.v6Only
+ n.ops.SetV6Only(l.v6Only)
n.ID = s.id
n.boundNICID = s.nicID
n.route = route
@@ -599,7 +599,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er
ack: s.sequenceNumber + 1,
rcvWnd: ctx.rcvWnd,
}
- if err := e.sendSynTCP(&route, fields, synOpts); err != nil {
+ if err := e.sendSynTCP(route, fields, synOpts); err != nil {
return err
}
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
@@ -752,7 +752,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er
// its own goroutine and is responsible for handling connection requests.
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.mu.Lock()
- v6Only := e.v6only
+ v6Only := e.ops.GetV6Only()
ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
defer func() {
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index ac6d879a7..c944dccc0 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -16,6 +16,7 @@ package tcp
import (
"encoding/binary"
+ "math"
"time"
"gvisor.dev/gvisor/pkg/rand"
@@ -133,7 +134,7 @@ func FindWndScale(wnd seqnum.Size) int {
return 0
}
- max := seqnum.Size(0xffff)
+ max := seqnum.Size(math.MaxUint16)
s := 0
for wnd > max && s < header.MaxWndScale {
s++
@@ -300,7 +301,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
if ttl == 0 {
ttl = h.ep.route.DefaultTTL()
}
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: ttl,
tos: h.ep.sendTOS,
@@ -361,7 +362,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -496,7 +497,7 @@ func (h *handshake) resolveRoute() *tcpip.Error {
h.ep.mu.Lock()
}
if n&notifyError != 0 {
- return h.ep.LastError()
+ return h.ep.lastErrorLocked()
}
}
@@ -547,7 +548,7 @@ func (h *handshake) start() *tcpip.Error {
}
h.sendSYNOpts = synOpts
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -575,7 +576,6 @@ func (h *handshake) complete() *tcpip.Error {
return err
}
defer timer.stop()
-
for h.state != handshakeCompleted {
// Unlock before blocking, and reacquire again afterwards (h.ep.mu is held
// throughout handshake processing).
@@ -597,7 +597,7 @@ func (h *handshake) complete() *tcpip.Error {
// the connection with another ACK or data (as ACKs are never
// retransmitted on their own).
if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept {
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -631,9 +631,8 @@ func (h *handshake) complete() *tcpip.Error {
h.ep.mu.Lock()
}
if n&notifyError != 0 {
- return h.ep.LastError()
+ return h.ep.lastErrorLocked()
}
-
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
return err
@@ -820,8 +819,8 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
data = data.Clone(nil)
optLen := len(tf.opts)
- if tf.rcvWnd > 0xffff {
- tf.rcvWnd = 0xffff
+ if tf.rcvWnd > math.MaxUint16 {
+ tf.rcvWnd = math.MaxUint16
}
mss := int(gso.MSS)
@@ -865,8 +864,8 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
// network endpoint and under the provided identity.
func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error {
optLen := len(tf.opts)
- if tf.rcvWnd > 0xffff {
- tf.rcvWnd = 0xffff
+ if tf.rcvWnd > math.MaxUint16 {
+ tf.rcvWnd = math.MaxUint16
}
if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() {
@@ -941,7 +940,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
- err := e.sendTCP(&e.route, tcpFields{
+ err := e.sendTCP(e.route, tcpFields{
id: e.ID,
ttl: e.ttl,
tos: e.sendTOS,
@@ -1002,7 +1001,7 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
// Only send a reset if the connection is being aborted for a reason
// other than receiving a reset.
e.setEndpointState(StateError)
- e.HardError = err
+ e.hardError = err
if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout {
// The exact sequence number to be used for the RST is the same as the
// one used by Linux. We need to handle the case of window being shrunk
@@ -1080,7 +1079,7 @@ func (e *endpoint) transitionToStateCloseLocked() {
// to any other listening endpoint. We reply with RST if we cannot find one.
func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID)
- if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
+ if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
// Dual-stack socket, try IPv4.
ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID)
}
@@ -1141,7 +1140,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
// delete the TCB, and return.
case StateCloseWait:
e.transitionToStateCloseLocked()
- e.HardError = tcpip.ErrAborted
+ e.hardError = tcpip.ErrAborted
e.notifyProtocolGoroutine(notifyTickleWorker)
return false, nil
default:
@@ -1286,7 +1285,7 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
userTimeout := e.userTimeout
e.keepalive.Lock()
- if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() {
+ if !e.SocketOptions().GetKeepAlive() || !e.keepalive.timer.checkExpiration() {
e.keepalive.Unlock()
return nil
}
@@ -1323,7 +1322,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
}
// Start the keepalive timer IFF it's enabled and there is no pending
// data to send.
- if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt {
+ if !e.SocketOptions().GetKeepAlive() || e.snd == nil || e.snd.sndUna != e.snd.sndNxt {
e.keepalive.timer.disable()
e.keepalive.Unlock()
return
@@ -1353,7 +1352,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
epilogue := func() {
// e.mu is expected to be hold upon entering this section.
-
if e.snd != nil {
e.snd.resendTimer.cleanup()
}
@@ -1383,7 +1381,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.lastErrorMu.Unlock()
e.setEndpointState(StateError)
- e.HardError = err
+ e.hardError = err
e.workerCleanup = true
// Lock released below.
@@ -1638,7 +1636,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()
}
extTW, newSyn := e.rcv.handleTimeWaitSegment(s)
if newSyn {
- info := e.EndpointInfo.TransportEndpointInfo
+ info := e.TransportEndpointInfo
newID := info.ID
newID.RemoteAddress = ""
newID.RemotePort = 0
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index a6f25896b..1d1b01a6c 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -405,14 +405,6 @@ func testV4Accept(t *testing.T, c *context.Context) {
}
}
- // Make sure we get the same error when calling the original ep and the
- // new one. This validates that v4-mapped endpoints are still able to
- // query the V6Only flag, whereas pure v4 endpoints are not.
- _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption)
- if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected {
- t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected)
- }
-
// Check the peer address.
addr, err := nep.GetRemoteAddress()
if err != nil {
@@ -530,12 +522,12 @@ func TestV6AcceptOnV6(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
var addr tcpip.FullAddress
- nep, _, err := c.EP.Accept(&addr)
+ _, _, err := c.EP.Accept(&addr)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- nep, _, err = c.EP.Accept(&addr)
+ _, _, err = c.EP.Accept(&addr)
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
@@ -548,12 +540,6 @@ func TestV6AcceptOnV6(t *testing.T) {
if addr.Addr != context.TestV6Addr {
t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr)
}
-
- // Make sure we can still query the v6 only status of the new endpoint,
- // that is, that it is in fact a v6 socket.
- if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil {
- t.Errorf("GetSockOptBool(tcpip.V6OnlyOption) failed: %s", err)
- }
}
func TestV4AcceptOnV4(t *testing.T) {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 4f4f4c65e..bb0795f78 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -310,16 +310,12 @@ type Stats struct {
func (*Stats) IsEndpointStats() {}
// EndpointInfo holds useful information about a transport endpoint which
-// can be queried by monitoring tools.
+// can be queried by monitoring tools. This exists to allow tcp-only state to
+// be exposed.
//
// +stateify savable
type EndpointInfo struct {
stack.TransportEndpointInfo
-
- // HardError is meaningful only when state is stateError. It stores the
- // error to be returned when read/write syscalls are called and the
- // endpoint is in this state. HardError is protected by endpoint mu.
- HardError *tcpip.Error `state:".(string)"`
}
// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
@@ -367,6 +363,7 @@ func (*EndpointInfo) IsEndpointInfo() {}
// +stateify savable
type endpoint struct {
EndpointInfo
+ tcpip.DefaultSocketOptionsHandler
// endpointEntry is used to queue endpoints for processing to the
// a given tcp processor goroutine.
@@ -386,6 +383,11 @@ type endpoint struct {
waiterQueue *waiter.Queue `state:"wait"`
uniqueID uint64
+ // hardError is meaningful only when state is stateError. It stores the
+ // error to be returned when read/write syscalls are called and the
+ // endpoint is in this state. hardError is protected by endpoint mu.
+ hardError *tcpip.Error `state:".(string)"`
+
// lastError represents the last error that the endpoint reported;
// access to it is protected by the following mutex.
lastErrorMu sync.Mutex `state:"nosave"`
@@ -421,7 +423,10 @@ type endpoint struct {
// mu protects all endpoint fields unless documented otherwise. mu must
// be acquired before interacting with the endpoint fields.
- mu sync.Mutex `state:"nosave"`
+ //
+ // During handshake, mu is locked by the protocol listen goroutine and
+ // released by the handshake completion goroutine.
+ mu sync.CrossGoroutineMutex `state:"nosave"`
ownedByUser uint32
// state must be read/set using the EndpointState()/setEndpointState()
@@ -436,9 +441,8 @@ type endpoint struct {
isPortReserved bool `state:"manual"`
isRegistered bool `state:"manual"`
boundNICID tcpip.NICID
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
ttl uint8
- v6only bool
isConnectNotified bool
// h stores a reference to the current handshake state if the endpoint is in
@@ -506,24 +510,9 @@ type endpoint struct {
// delay is a boolean (0 is false) and must be accessed atomically.
delay uint32
- // cork holds back segments until full.
- //
- // cork is a boolean (0 is false) and must be accessed atomically.
- cork uint32
-
// scoreboard holds TCP SACK Scoreboard information for this endpoint.
scoreboard *SACKScoreboard
- // The options below aren't implemented, but we remember the user
- // settings because applications expect to be able to set/query these
- // options.
-
- // slowAck holds the negated state of quick ack. It is stubbed out and
- // does nothing.
- //
- // slowAck is a boolean (0 is false) and must be accessed atomically.
- slowAck uint32
-
// segmentQueue is used to hand received segments to the protocol
// goroutine. Segments are queued as long as the queue is not full,
// and dropped when it is.
@@ -685,9 +674,6 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
-
// ops is used to get socket level options.
ops tcpip.SocketOptions
}
@@ -701,7 +687,7 @@ func (e *endpoint) UniqueID() uint64 {
//
// If userMSS is non-zero and is not greater than the maximum possible MSS for
// r, it will be used; otherwise, the maximum possible MSS will be used.
-func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 {
+func calculateAdvertisedMSS(userMSS uint16, r *stack.Route) uint16 {
// The maximum possible MSS is dependent on the route.
// TODO(b/143359391): Respect TCP Min and Max size.
maxMSS := uint16(r.MTU() - header.TCPMinimumSize)
@@ -850,7 +836,6 @@ func (e *endpoint) recentTimestamp() uint32 {
// +stateify savable
type keepalive struct {
sync.Mutex `state:"nosave"`
- enabled bool
idle time.Duration
interval time.Duration
count int
@@ -884,6 +869,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
windowClamp: DefaultReceiveBufferSize,
maxSynRetries: DefaultSynRetries,
}
+ e.ops.InitHandler(e)
+ e.ops.SetMulticastLoop(true)
+ e.ops.SetQuickAck(true)
var ss tcpip.TCPSendBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
@@ -907,7 +895,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
var de tcpip.TCPDelayEnabled
if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de {
- e.SetSockOptBool(tcpip.DelayOption, true)
+ e.ops.SetDelayOption(true)
}
var tcpLT tcpip.TCPLingerTimeoutOption
@@ -1049,7 +1037,8 @@ func (e *endpoint) Close() {
return
}
- if e.linger.Enabled && e.linger.Timeout == 0 {
+ linger := e.SocketOptions().GetLinger()
+ if linger.Enabled && linger.Timeout == 0 {
s := e.EndpointState()
isResetState := s == StateEstablished || s == StateCloseWait || s == StateFinWait1 || s == StateFinWait2 || s == StateSynRecv
if isResetState {
@@ -1169,7 +1158,11 @@ func (e *endpoint) cleanupLocked() {
e.boundPortFlags = ports.Flags{}
e.boundDest = tcpip.FullAddress{}
- e.route.Release()
+ if e.route != nil {
+ e.route.Release()
+ e.route = nil
+ }
+
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
@@ -1279,11 +1272,20 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvListMu.Unlock()
}
+// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-func (e *endpoint) LastError() *tcpip.Error {
+// Preconditions: e.mu must be held to call this function.
+func (e *endpoint) hardErrorLocked() *tcpip.Error {
+ err := e.hardError
+ e.hardError = nil
+ return err
+}
+
+// Preconditions: e.mu must be held to call this function.
+func (e *endpoint) lastErrorLocked() *tcpip.Error {
e.lastErrorMu.Lock()
defer e.lastErrorMu.Unlock()
err := e.lastError
@@ -1291,6 +1293,16 @@ func (e *endpoint) LastError() *tcpip.Error {
return err
}
+// LastError implements tcpip.Endpoint.LastError.
+func (e *endpoint) LastError() *tcpip.Error {
+ e.LockUser()
+ defer e.UnlockUser()
+ if err := e.hardErrorLocked(); err != nil {
+ return err
+ }
+ return e.lastErrorLocked()
+}
+
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.LockUser()
@@ -1312,9 +1324,11 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
bufUsed := e.rcvBufUsed
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
- he := e.HardError
if s == StateError {
- return buffer.View{}, tcpip.ControlMessages{}, he
+ if err := e.hardErrorLocked(); err != nil {
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
}
e.stats.ReadErrors.NotConnected.Increment()
return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected
@@ -1370,9 +1384,13 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// indicating the reason why it's not writable.
// Caller must hold e.mu and e.sndBufMu
func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
+ // The endpoint cannot be written to if it's not connected.
switch s := e.EndpointState(); {
case s == StateError:
- return 0, e.HardError
+ if err := e.hardErrorLocked(); err != nil {
+ return 0, err
+ }
+ return 0, tcpip.ErrClosedForSend
case !s.connecting() && !s.connected():
return 0, tcpip.ErrClosedForSend
case s.connecting():
@@ -1478,7 +1496,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
-func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Peek(vec [][]byte) (int64, *tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
@@ -1486,10 +1504,10 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
// but has some pending unread data.
if s := e.EndpointState(); !s.connected() && s != StateClose {
if s == StateError {
- return 0, tcpip.ControlMessages{}, e.HardError
+ return 0, e.hardErrorLocked()
}
e.stats.ReadErrors.InvalidEndpointState.Increment()
- return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
+ return 0, tcpip.ErrInvalidEndpointState
}
e.rcvListMu.Lock()
@@ -1498,9 +1516,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
if e.rcvBufUsed == 0 {
if e.rcvClosed || !e.EndpointState().connected() {
e.stats.ReadErrors.ReadClosed.Increment()
- return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
+ return 0, tcpip.ErrClosedForReceive
}
- return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ return 0, tcpip.ErrWouldBlock
}
// Make a copy of vec so we can modify the slide headers.
@@ -1515,7 +1533,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
for len(v) > 0 {
if len(vec) == 0 {
- return num, tcpip.ControlMessages{}, nil
+ return num, nil
}
if len(vec[0]) == 0 {
vec = vec[1:]
@@ -1530,7 +1548,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
}
}
- return num, tcpip.ControlMessages{}, nil
+ return num, nil
}
// selectWindowLocked returns the new window without checking for shrinking or scaling
@@ -1602,72 +1620,39 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo
return false, false
}
-// SetSockOptBool sets a socket option.
-func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- switch opt {
-
- case tcpip.CorkOption:
- e.LockUser()
- if !v {
- atomic.StoreUint32(&e.cork, 0)
-
- // Handle the corked data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.cork, 1)
- }
- e.UnlockUser()
-
- case tcpip.DelayOption:
- if v {
- atomic.StoreUint32(&e.delay, 1)
- } else {
- atomic.StoreUint32(&e.delay, 0)
-
- // Handle delayed data.
- e.sndWaker.Assert()
- }
-
- case tcpip.KeepaliveEnabledOption:
- e.keepalive.Lock()
- e.keepalive.enabled = v
- e.keepalive.Unlock()
- e.notifyProtocolGoroutine(notifyKeepaliveChanged)
-
- case tcpip.QuickAckOption:
- o := uint32(1)
- if v {
- o = 0
- }
- atomic.StoreUint32(&e.slowAck, o)
-
- case tcpip.ReuseAddressOption:
- e.LockUser()
- e.portFlags.TupleOnly = v
- e.UnlockUser()
-
- case tcpip.ReusePortOption:
- e.LockUser()
- e.portFlags.LoadBalanced = v
- e.UnlockUser()
+// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet.
+func (e *endpoint) OnReuseAddressSet(v bool) {
+ e.LockUser()
+ e.portFlags.TupleOnly = v
+ e.UnlockUser()
+}
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrInvalidEndpointState
- }
+// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet.
+func (e *endpoint) OnReusePortSet(v bool) {
+ e.LockUser()
+ e.portFlags.LoadBalanced = v
+ e.UnlockUser()
+}
- // We only allow this to be set when we're in the initial state.
- if e.EndpointState() != StateInitial {
- return tcpip.ErrInvalidEndpointState
- }
+// OnKeepAliveSet implements tcpip.SocketOptionsHandler.OnKeepAliveSet.
+func (e *endpoint) OnKeepAliveSet(v bool) {
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+}
- e.LockUser()
- e.v6only = v
- e.UnlockUser()
+// OnDelayOptionSet implements tcpip.SocketOptionsHandler.OnDelayOptionSet.
+func (e *endpoint) OnDelayOptionSet(v bool) {
+ if !v {
+ // Handle delayed data.
+ e.sndWaker.Assert()
}
+}
- return nil
+// OnCorkOptionSet implements tcpip.SocketOptionsHandler.OnCorkOptionSet.
+func (e *endpoint) OnCorkOptionSet(v bool) {
+ if !v {
+ // Handle the corked data.
+ e.sndWaker.Assert()
+ }
}
// SetSockOptInt sets a socket option.
@@ -1851,9 +1836,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
e.keepalive.Unlock()
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- case *tcpip.OutOfBandInlineOption:
- // We don't currently support disabling this option.
-
case *tcpip.TCPUserTimeoutOption:
e.LockUser()
e.userTimeout = time.Duration(*v)
@@ -1922,11 +1904,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
case *tcpip.SocketDetachFilterOption:
return nil
- case *tcpip.LingerOption:
- e.LockUser()
- e.linger = *v
- e.UnlockUser()
-
default:
return nil
}
@@ -1949,67 +1926,6 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
return e.rcvBufUsed, nil
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
-
- case tcpip.CorkOption:
- return atomic.LoadUint32(&e.cork) != 0, nil
-
- case tcpip.DelayOption:
- return atomic.LoadUint32(&e.delay) != 0, nil
-
- case tcpip.KeepaliveEnabledOption:
- e.keepalive.Lock()
- v := e.keepalive.enabled
- e.keepalive.Unlock()
-
- return v, nil
-
- case tcpip.QuickAckOption:
- v := atomic.LoadUint32(&e.slowAck) == 0
- return v, nil
-
- case tcpip.ReuseAddressOption:
- e.LockUser()
- v := e.portFlags.TupleOnly
- e.UnlockUser()
-
- return v, nil
-
- case tcpip.ReusePortOption:
- e.LockUser()
- v := e.portFlags.LoadBalanced
- e.UnlockUser()
-
- return v, nil
-
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return false, tcpip.ErrUnknownProtocolOption
- }
-
- e.LockUser()
- v := e.v6only
- e.UnlockUser()
-
- return v, nil
-
- case tcpip.MulticastLoopOption:
- return true, nil
-
- case tcpip.AcceptConnOption:
- e.LockUser()
- defer e.UnlockUser()
-
- return e.EndpointState() == StateListen, nil
-
- default:
- return false, tcpip.ErrUnknownProtocolOption
- }
-}
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
@@ -2120,10 +2036,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
*o = tcpip.TCPUserTimeoutOption(e.userTimeout)
e.UnlockUser()
- case *tcpip.OutOfBandInlineOption:
- // We don't currently support disabling this option.
- *o = 1
-
case *tcpip.CongestionControlOption:
e.LockUser()
*o = e.cc
@@ -2152,11 +2064,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
Port: port,
}
- case *tcpip.LingerOption:
- e.LockUser()
- *o = e.linger
- e.UnlockUser()
-
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -2166,7 +2073,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
// checkV4MappedLocked determines the effective network protocol and converts
// addr to its canonical form.
func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
}
@@ -2243,7 +2150,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
return tcpip.ErrAlreadyConnecting
case StateError:
- return e.HardError
+ if err := e.hardErrorLocked(); err != nil {
+ return err
+ }
+ return tcpip.ErrConnectionAborted
default:
return tcpip.ErrInvalidEndpointState
@@ -2417,7 +2327,7 @@ func (e *endpoint) startMainLoop(handshake bool) *tcpip.Error {
e.lastErrorMu.Unlock()
e.setEndpointState(StateError)
- e.HardError = err
+ e.hardError = err
// Call cleanupLocked to free up any reservations.
e.cleanupLocked()
@@ -2697,7 +2607,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// v6only set to false.
if netProto == header.IPv6ProtocolNumber {
stackHasV4 := e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber)
- alsoBindToV4 := !e.v6only && addr.Addr == "" && stackHasV4
+ alsoBindToV4 := !e.ops.GetV6Only() && addr.Addr == "" && stackHasV4
if alsoBindToV4 {
netProtos = append(netProtos, header.IPv4ProtocolNumber)
}
@@ -2782,7 +2692,7 @@ func (e *endpoint) getRemoteAddress() tcpip.FullAddress {
func (*endpoint) HandlePacket(stack.TransportEndpointID, *stack.PacketBuffer) {
// TCP HandlePacket is not required anymore as inbound packets first
- // land at the Dispatcher which then can either delivery using the
+ // land at the Dispatcher which then can either deliver using the
// worker go routine or directly do the invoke the tcp processing inline
// based on the state of the endpoint.
}
@@ -3079,6 +2989,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
Ssthresh: e.snd.sndSsthresh,
SndCAAckCount: e.snd.sndCAAckCount,
Outstanding: e.snd.outstanding,
+ SackedOut: e.snd.sackedOut,
SndWnd: e.snd.sndWnd,
SndUna: e.snd.sndUna,
SndNxt: e.snd.sndNxt,
@@ -3161,7 +3072,7 @@ func (e *endpoint) State() uint32 {
func (e *endpoint) Info() tcpip.EndpointInfo {
e.LockUser()
// Make a copy of the endpoint info.
- ret := e.EndpointInfo
+ ret := e.TransportEndpointInfo
e.UnlockUser()
return &ret
}
@@ -3187,6 +3098,7 @@ func (e *endpoint) Wait() {
}
}
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index bb901c0f8..ba67176b5 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -321,21 +321,21 @@ func (e *endpoint) loadRecentTSTime(unix unixTime) {
}
// saveHardError is invoked by stateify.
-func (e *EndpointInfo) saveHardError() string {
- if e.HardError == nil {
+func (e *endpoint) saveHardError() string {
+ if e.hardError == nil {
return ""
}
- return e.HardError.String()
+ return e.hardError.String()
}
// loadHardError is invoked by stateify.
-func (e *EndpointInfo) loadHardError(s string) {
+func (e *endpoint) loadHardError(s string) {
if s == "" {
return
}
- e.HardError = tcpip.StringToError(s)
+ e.hardError = tcpip.StringToError(s)
}
// saveMeasureTime is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 2329aca4b..672159eed 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -250,7 +250,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error
ttl = route.DefaultTTL()
}
- return sendTCP(&route, tcpFields{
+ return sendTCP(route, tcpFields{
id: s.id,
ttl: ttl,
tos: tos,
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 8e0b7c843..405a6dce7 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -16,6 +16,7 @@ package tcp
import (
"container/heap"
+ "math"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -48,6 +49,10 @@ type receiver struct {
rcvWndScale uint8
+ // prevBufused is the snapshot of endpoint rcvBufUsed taken when we
+ // advertise a receive window.
+ prevBufUsed int
+
closed bool
// pendingRcvdSegments is bounded by the receive buffer size of the
@@ -80,9 +85,9 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
// outgoing packets, we should use what we have advertised for acceptability
// test.
scaledWindowSize := r.rcvWnd >> r.rcvWndScale
- if scaledWindowSize > 0xffff {
+ if scaledWindowSize > math.MaxUint16 {
// This is what we actually put in the Window field.
- scaledWindowSize = 0xffff
+ scaledWindowSize = math.MaxUint16
}
advertisedWindowSize := scaledWindowSize << r.rcvWndScale
return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize))
@@ -106,6 +111,34 @@ func (r *receiver) currentWindow() (curWnd seqnum.Size) {
func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
newWnd := r.ep.selectWindow()
curWnd := r.currentWindow()
+ unackLen := int(r.ep.snd.maxSentAck.Size(r.rcvNxt))
+ bufUsed := r.ep.receiveBufferUsed()
+
+ // Grow the right edge of the window only for payloads larger than the
+ // the segment overhead OR if the application is actively consuming data.
+ //
+ // Avoiding growing the right edge otherwise, addresses a situation below:
+ // An application has been slow in reading data and we have burst of
+ // incoming segments lengths < segment overhead. Here, our available free
+ // memory would reduce drastically when compared to the advertised receive
+ // window.
+ //
+ // For example: With incoming 512 bytes segments, segment overhead of
+ // 552 bytes (at the time of writing this comment), with receive window
+ // starting from 1MB and with rcvAdvWndScale being 1, buffer would reach 0
+ // when the curWnd is still 19436 bytes, because for every incoming segment
+ // newWnd would reduce by (552+512) >> rcvAdvWndScale (current value 1),
+ // while curWnd would reduce by 512 bytes.
+ // Such a situation causes us to keep tail dropping the incoming segments
+ // and never advertise zero receive window to the peer.
+ //
+ // Linux does a similar check for minimal sk_buff size (128):
+ // https://github.com/torvalds/linux/blob/d5beb3140f91b1c8a3d41b14d729aefa4dcc58bc/net/ipv4/tcp_input.c#L783
+ //
+ // Also, if the application is reading the data, we keep growing the right
+ // edge, as we are still advertising a window that we think can be serviced.
+ toGrow := unackLen >= SegSize || bufUsed <= r.prevBufUsed
+
// Update rcvAcc only if new window is > previously advertised window. We
// should never shrink the acceptable sequence space once it has been
// advertised the peer. If we shrink the acceptable sequence space then we
@@ -115,7 +148,7 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// rcvWUP rcvNxt rcvAcc new rcvAcc
// <=====curWnd ===>
// <========= newWnd > curWnd ========= >
- if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) {
+ if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) && toGrow {
// If the new window moves the right edge, then update rcvAcc.
r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd))
} else {
@@ -130,11 +163,22 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// receiver's estimated RTT.
r.rcvWnd = newWnd
r.rcvWUP = r.rcvNxt
+ r.prevBufUsed = bufUsed
scaledWnd := r.rcvWnd >> r.rcvWndScale
if scaledWnd == 0 {
// Increment a metric if we are advertising an actual zero window.
r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
}
+
+ // If we started off with a window larger than what can he held in
+ // the 16bit window field, we ceil the value to the max value.
+ if scaledWnd > math.MaxUint16 {
+ scaledWnd = seqnum.Size(math.MaxUint16)
+
+ // Ensure that the stashed receive window always reflects what
+ // is being advertised.
+ r.rcvWnd = scaledWnd << r.rcvWndScale
+ }
return r.rcvNxt, scaledWnd
}
diff --git a/pkg/tcpip/transport/tcp/reno_recovery.go b/pkg/tcpip/transport/tcp/reno_recovery.go
new file mode 100644
index 000000000..2aa708e97
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/reno_recovery.go
@@ -0,0 +1,67 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+// renoRecovery stores the variables related to TCP Reno loss recovery
+// algorithm.
+//
+// +stateify savable
+type renoRecovery struct {
+ s *sender
+}
+
+func newRenoRecovery(s *sender) *renoRecovery {
+ return &renoRecovery{s: s}
+}
+
+func (rr *renoRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) {
+ ack := rcvdSeg.ackNumber
+ snd := rr.s
+
+ // We are in fast recovery mode. Ignore the ack if it's out of range.
+ if !ack.InRange(snd.sndUna, snd.sndNxt+1) {
+ return
+ }
+
+ // Don't count this as a duplicate if it is carrying data or
+ // updating the window.
+ if rcvdSeg.logicalLen() != 0 || snd.sndWnd != rcvdSeg.window {
+ return
+ }
+
+ // Inflate the congestion window if we're getting duplicate acks
+ // for the packet we retransmitted.
+ if !fastRetransmit && ack == snd.fr.first {
+ // We received a dup, inflate the congestion window by 1 packet
+ // if we're not at the max yet. Only inflate the window if
+ // regular FastRecovery is in use, RFC6675 does not require
+ // inflating cwnd on duplicate ACKs.
+ if snd.sndCwnd < snd.fr.maxCwnd {
+ snd.sndCwnd++
+ }
+ return
+ }
+
+ // A partial ack was received. Retransmit this packet and remember it
+ // so that we don't retransmit it again.
+ //
+ // We don't inflate the window because we're putting the same packet
+ // back onto the wire.
+ //
+ // N.B. The retransmit timer will be reset by the caller.
+ snd.fr.first = ack
+ snd.dupAckCount = 0
+ snd.resendSegment()
+}
diff --git a/pkg/tcpip/transport/tcp/sack_recovery.go b/pkg/tcpip/transport/tcp/sack_recovery.go
new file mode 100644
index 000000000..7e813fa96
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/sack_recovery.go
@@ -0,0 +1,120 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+
+// sackRecovery stores the variables related to TCP SACK loss recovery
+// algorithm.
+//
+// +stateify savable
+type sackRecovery struct {
+ s *sender
+}
+
+func newSACKRecovery(s *sender) *sackRecovery {
+ return &sackRecovery{s: s}
+}
+
+// handleSACKRecovery implements the loss recovery phase as described in RFC6675
+// section 5, step C.
+func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) {
+ snd := sr.s
+ snd.SetPipe()
+
+ if smss := int(snd.ep.scoreboard.SMSS()); limit > smss {
+ // Cap segment size limit to s.smss as SACK recovery requires
+ // that all retransmissions or new segments send during recovery
+ // be of <= SMSS.
+ limit = smss
+ }
+
+ nextSegHint := snd.writeList.Front()
+ for snd.outstanding < snd.sndCwnd {
+ var nextSeg *segment
+ var rescueRtx bool
+ nextSeg, nextSegHint, rescueRtx = snd.NextSeg(nextSegHint)
+ if nextSeg == nil {
+ return dataSent
+ }
+ if !snd.isAssignedSequenceNumber(nextSeg) || snd.sndNxt.LessThanEq(nextSeg.sequenceNumber) {
+ // New data being sent.
+
+ // Step C.3 described below is handled by
+ // maybeSendSegment which increments sndNxt when
+ // a segment is transmitted.
+ //
+ // Step C.3 "If any of the data octets sent in
+ // (C.1) are above HighData, HighData must be
+ // updated to reflect the transmission of
+ // previously unsent data."
+ //
+ // We pass s.smss as the limit as the Step 2) requires that
+ // new data sent should be of size s.smss or less.
+ if sent := snd.maybeSendSegment(nextSeg, limit, end); !sent {
+ return dataSent
+ }
+ dataSent = true
+ snd.outstanding++
+ snd.writeNext = nextSeg.Next()
+ continue
+ }
+
+ // Now handle the retransmission case where we matched either step 1,3 or 4
+ // of the NextSeg algorithm.
+ // RFC 6675, Step C.4.
+ //
+ // "The estimate of the amount of data outstanding in the network
+ // must be updated by incrementing pipe by the number of octets
+ // transmitted in (C.1)."
+ snd.outstanding++
+ dataSent = true
+ snd.sendSegment(nextSeg)
+
+ segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen())
+ if rescueRtx {
+ // We do the last part of rule (4) of NextSeg here to update
+ // RescueRxt as until this point we don't know if we are going
+ // to use the rescue transmission.
+ snd.fr.rescueRxt = snd.fr.last
+ } else {
+ // RFC 6675, Step C.2
+ //
+ // "If any of the data octets sent in (C.1) are below
+ // HighData, HighRxt MUST be set to the highest sequence
+ // number of the retransmitted segment unless NextSeg ()
+ // rule (4) was invoked for this retransmission."
+ snd.fr.highRxt = segEnd - 1
+ }
+ }
+ return dataSent
+}
+
+func (sr *sackRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) {
+ snd := sr.s
+ if fastRetransmit {
+ snd.resendSegment()
+ }
+
+ // We are in fast recovery mode. Ignore the ack if it's out of range.
+ if ack := rcvdSeg.ackNumber; !ack.InRange(snd.sndUna, snd.sndNxt+1) {
+ return
+ }
+
+ // RFC 6675 recovery algorithm step C 1-5.
+ end := snd.sndUna.Add(snd.sndWnd)
+ dataSent := sr.handleSACKRecovery(snd.maxPayloadSize, end)
+ snd.postXmit(dataSent)
+}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 2091989cc..5ef73ec74 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -204,7 +204,7 @@ func (s *segment) payloadSize() int {
// segMemSize is the amount of memory used to hold the segment data and
// the associated metadata.
func (s *segment) segMemSize() int {
- return segSize + s.data.Size()
+ return SegSize + s.data.Size()
}
// parse populates the sequence & ack numbers, flags, and window fields of the
diff --git a/pkg/tcpip/transport/tcp/segment_unsafe.go b/pkg/tcpip/transport/tcp/segment_unsafe.go
index 0ab7b8f56..392ff0859 100644
--- a/pkg/tcpip/transport/tcp/segment_unsafe.go
+++ b/pkg/tcpip/transport/tcp/segment_unsafe.go
@@ -19,5 +19,6 @@ import (
)
const (
- segSize = int(unsafe.Sizeof(segment{}))
+ // SegSize is the minimal size of the segment overhead.
+ SegSize = int(unsafe.Sizeof(segment{}))
)
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 0e0fdf14c..cc991aba6 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -18,7 +18,6 @@ import (
"fmt"
"math"
"sort"
- "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sleep"
@@ -92,6 +91,17 @@ type congestionControl interface {
PostRecovery()
}
+// lossRecovery is an interface that must be implemented by any supported
+// loss recovery algorithm.
+type lossRecovery interface {
+ // DoRecovery is invoked when loss is detected and segments need
+ // to be retransmitted. The cumulative or selective ACK is passed along
+ // with the flag which identifies whether the connection entered fast
+ // retransmit with this ACK and to retransmit the first unacknowledged
+ // segment.
+ DoRecovery(rcvdSeg *segment, fastRetransmit bool)
+}
+
// sender holds the state necessary to send TCP segments.
//
// +stateify savable
@@ -108,6 +118,9 @@ type sender struct {
// fr holds state related to fast recovery.
fr fastRecovery
+ // lr is the loss recovery algorithm used by the sender.
+ lr lossRecovery
+
// sndCwnd is the congestion window, in packets.
sndCwnd int
@@ -124,6 +137,9 @@ type sender struct {
// that have been sent but not yet acknowledged.
outstanding int
+ // sackedOut is the number of packets which are selectively acked.
+ sackedOut int
+
// sndWnd is the send window size.
sndWnd seqnum.Size
@@ -276,6 +292,8 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
s.cc = s.initCongestionControl(ep.cc)
+ s.lr = s.initLossRecovery()
+
// A negative sndWndScale means that no scaling is in use, otherwise we
// store the scaling value.
if sndWndScale > 0 {
@@ -330,6 +348,14 @@ func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionCon
}
}
+// initLossRecovery initiates the loss recovery algorithm for the sender.
+func (s *sender) initLossRecovery() lossRecovery {
+ if s.ep.sackPermitted {
+ return newSACKRecovery(s)
+ }
+ return newRenoRecovery(s)
+}
+
// updateMaxPayloadSize updates the maximum payload size based on the given
// MTU. If this is in response to "packet too big" control packets (indicated
// by the count argument), it also reduces the number of outstanding packets and
@@ -349,6 +375,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
m = 1
}
+ oldMSS := s.maxPayloadSize
s.maxPayloadSize = m
if s.gso {
s.ep.gso.MSS = uint16(m)
@@ -371,6 +398,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
// Rewind writeNext to the first segment exceeding the MTU. Do nothing
// if it is already before such a packet.
+ nextSeg := s.writeNext
for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
if seg == s.writeNext {
// We got to writeNext before we could find a segment
@@ -378,16 +406,22 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
break
}
- if seg.data.Size() > m {
+ if nextSeg == s.writeNext && seg.data.Size() > m {
// We found a segment exceeding the MTU. Rewind
// writeNext and try to retransmit it.
- s.writeNext = seg
- break
+ nextSeg = seg
+ }
+
+ if s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ // Update sackedOut for new maximum payload size.
+ s.sackedOut -= s.pCount(seg, oldMSS)
+ s.sackedOut += s.pCount(seg, s.maxPayloadSize)
}
}
// Since we likely reduced the number of outstanding packets, we may be
// ready to send some more.
+ s.writeNext = nextSeg
s.sendData()
}
@@ -550,7 +584,7 @@ func (s *sender) retransmitTimerExpired() bool {
// We were attempting fast recovery but were not successful.
// Leave the state. We don't need to update ssthresh because it
// has already been updated when entered fast-recovery.
- s.leaveFastRecovery()
+ s.leaveRecovery()
}
s.state = RTORecovery
@@ -606,13 +640,13 @@ func (s *sender) retransmitTimerExpired() bool {
// pCount returns the number of packets in the segment. Due to GSO, a segment
// can be composed of multiple packets.
-func (s *sender) pCount(seg *segment) int {
+func (s *sender) pCount(seg *segment, maxPayloadSize int) int {
size := seg.data.Size()
if size == 0 {
return 1
}
- return (size-1)/s.maxPayloadSize + 1
+ return (size-1)/maxPayloadSize + 1
}
// splitSeg splits a given segment at the size specified and inserts the
@@ -789,7 +823,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
}
if !nextTooBig && seg.data.Size() < available {
// Segment is not full.
- if s.outstanding > 0 && atomic.LoadUint32(&s.ep.delay) != 0 {
+ if s.outstanding > 0 && s.ep.ops.GetDelayOption() {
// Nagle's algorithm. From Wikipedia:
// Nagle's algorithm works by
// combining a number of small
@@ -808,7 +842,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// send space and MSS.
// TODO(gvisor.dev/issue/2833): Drain the held segments after a
// timeout.
- if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 {
+ if seg.data.Size() < s.maxPayloadSize && s.ep.ops.GetCorkOption() {
return false
}
}
@@ -913,79 +947,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
return true
}
-// handleSACKRecovery implements the loss recovery phase as described in RFC6675
-// section 5, step C.
-func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) {
- s.SetPipe()
-
- if smss := int(s.ep.scoreboard.SMSS()); limit > smss {
- // Cap segment size limit to s.smss as SACK recovery requires
- // that all retransmissions or new segments send during recovery
- // be of <= SMSS.
- limit = smss
- }
-
- nextSegHint := s.writeList.Front()
- for s.outstanding < s.sndCwnd {
- var nextSeg *segment
- var rescueRtx bool
- nextSeg, nextSegHint, rescueRtx = s.NextSeg(nextSegHint)
- if nextSeg == nil {
- return dataSent
- }
- if !s.isAssignedSequenceNumber(nextSeg) || s.sndNxt.LessThanEq(nextSeg.sequenceNumber) {
- // New data being sent.
-
- // Step C.3 described below is handled by
- // maybeSendSegment which increments sndNxt when
- // a segment is transmitted.
- //
- // Step C.3 "If any of the data octets sent in
- // (C.1) are above HighData, HighData must be
- // updated to reflect the transmission of
- // previously unsent data."
- //
- // We pass s.smss as the limit as the Step 2) requires that
- // new data sent should be of size s.smss or less.
- if sent := s.maybeSendSegment(nextSeg, limit, end); !sent {
- return dataSent
- }
- dataSent = true
- s.outstanding++
- s.writeNext = nextSeg.Next()
- continue
- }
-
- // Now handle the retransmission case where we matched either step 1,3 or 4
- // of the NextSeg algorithm.
- // RFC 6675, Step C.4.
- //
- // "The estimate of the amount of data outstanding in the network
- // must be updated by incrementing pipe by the number of octets
- // transmitted in (C.1)."
- s.outstanding++
- dataSent = true
- s.sendSegment(nextSeg)
-
- segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen())
- if rescueRtx {
- // We do the last part of rule (4) of NextSeg here to update
- // RescueRxt as until this point we don't know if we are going
- // to use the rescue transmission.
- s.fr.rescueRxt = s.fr.last
- } else {
- // RFC 6675, Step C.2
- //
- // "If any of the data octets sent in (C.1) are below
- // HighData, HighRxt MUST be set to the highest sequence
- // number of the retransmitted segment unless NextSeg ()
- // rule (4) was invoked for this retransmission."
- s.fr.highRxt = segEnd - 1
- }
- }
- return dataSent
-}
-
func (s *sender) sendZeroWindowProbe() {
ack, win := s.ep.rcv.getSendParams()
s.unackZeroWindowProbes++
@@ -1014,6 +975,30 @@ func (s *sender) disableZeroWindowProbing() {
s.resendTimer.disable()
}
+func (s *sender) postXmit(dataSent bool) {
+ if dataSent {
+ // We sent data, so we should stop the keepalive timer to ensure
+ // that no keepalives are sent while there is pending data.
+ s.ep.disableKeepaliveTimer()
+ }
+
+ // If the sender has advertized zero receive window and we have
+ // data to be sent out, start zero window probing to query the
+ // the remote for it's receive window size.
+ if s.writeNext != nil && s.sndWnd == 0 {
+ s.enableZeroWindowProbing()
+ }
+
+ // Enable the timer if we have pending data and it's not enabled yet.
+ if !s.resendTimer.enabled() && s.sndUna != s.sndNxt {
+ s.resendTimer.enable(s.rto)
+ }
+ // If we have no more pending data, start the keepalive timer.
+ if s.sndUna == s.sndNxt {
+ s.ep.resetKeepaliveTimer(false)
+ }
+}
+
// sendData sends new data segments. It is called when data becomes available or
// when the send window opens up.
func (s *sender) sendData() {
@@ -1034,55 +1019,29 @@ func (s *sender) sendData() {
}
var dataSent bool
-
- // RFC 6675 recovery algorithm step C 1-5.
- if s.fr.active && s.ep.sackPermitted {
- dataSent = s.handleSACKRecovery(s.maxPayloadSize, end)
- } else {
- for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() {
- cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize
- if cwndLimit < limit {
- limit = cwndLimit
- }
- if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
- // Move writeNext along so that we don't try and scan data that
- // has already been SACKED.
- s.writeNext = seg.Next()
- continue
- }
- if sent := s.maybeSendSegment(seg, limit, end); !sent {
- break
- }
- dataSent = true
- s.outstanding += s.pCount(seg)
+ for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() {
+ cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize
+ if cwndLimit < limit {
+ limit = cwndLimit
+ }
+ if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ // Move writeNext along so that we don't try and scan data that
+ // has already been SACKED.
s.writeNext = seg.Next()
+ continue
}
+ if sent := s.maybeSendSegment(seg, limit, end); !sent {
+ break
+ }
+ dataSent = true
+ s.outstanding += s.pCount(seg, s.maxPayloadSize)
+ s.writeNext = seg.Next()
}
- if dataSent {
- // We sent data, so we should stop the keepalive timer to ensure
- // that no keepalives are sent while there is pending data.
- s.ep.disableKeepaliveTimer()
- }
-
- // If the sender has advertized zero receive window and we have
- // data to be sent out, start zero window probing to query the
- // the remote for it's receive window size.
- if s.writeNext != nil && s.sndWnd == 0 {
- s.enableZeroWindowProbing()
- }
-
- // Enable the timer if we have pending data and it's not enabled yet.
- if !s.resendTimer.enabled() && s.sndUna != s.sndNxt {
- s.resendTimer.enable(s.rto)
- }
- // If we have no more pending data, start the keepalive timer.
- if s.sndUna == s.sndNxt {
- s.ep.resetKeepaliveTimer(false)
- }
+ s.postXmit(dataSent)
}
-func (s *sender) enterFastRecovery() {
+func (s *sender) enterRecovery() {
s.fr.active = true
// Save state to reflect we're now in fast recovery.
//
@@ -1090,6 +1049,7 @@ func (s *sender) enterFastRecovery() {
// We inflate the cwnd by 3 to account for the 3 packets which triggered
// the 3 duplicate ACKs and are now not in flight.
s.sndCwnd = s.sndSsthresh + 3
+ s.sackedOut = 0
s.fr.first = s.sndUna
s.fr.last = s.sndNxt - 1
s.fr.maxCwnd = s.sndCwnd + s.outstanding
@@ -1104,7 +1064,7 @@ func (s *sender) enterFastRecovery() {
s.ep.stack.Stats().TCP.FastRecovery.Increment()
}
-func (s *sender) leaveFastRecovery() {
+func (s *sender) leaveRecovery() {
s.fr.active = false
s.fr.maxCwnd = 0
s.dupAckCount = 0
@@ -1115,57 +1075,6 @@ func (s *sender) leaveFastRecovery() {
s.cc.PostRecovery()
}
-func (s *sender) handleFastRecovery(seg *segment) (rtx bool) {
- ack := seg.ackNumber
- // We are in fast recovery mode. Ignore the ack if it's out of
- // range.
- if !ack.InRange(s.sndUna, s.sndNxt+1) {
- return false
- }
-
- // Leave fast recovery if it acknowledges all the data covered by
- // this fast recovery session.
- if s.fr.last.LessThan(ack) {
- s.leaveFastRecovery()
- return false
- }
-
- if s.ep.sackPermitted {
- // When SACK is enabled we let retransmission be governed by
- // the SACK logic.
- return false
- }
-
- // Don't count this as a duplicate if it is carrying data or
- // updating the window.
- if seg.logicalLen() != 0 || s.sndWnd != seg.window {
- return false
- }
-
- // Inflate the congestion window if we're getting duplicate acks
- // for the packet we retransmitted.
- if ack == s.fr.first {
- // We received a dup, inflate the congestion window by 1 packet
- // if we're not at the max yet. Only inflate the window if
- // regular FastRecovery is in use, RFC6675 does not require
- // inflating cwnd on duplicate ACKs.
- if s.sndCwnd < s.fr.maxCwnd {
- s.sndCwnd++
- }
- return false
- }
-
- // A partial ack was received. Retransmit this packet and
- // remember it so that we don't retransmit it again. We don't
- // inflate the window because we're putting the same packet back
- // onto the wire.
- //
- // N.B. The retransmit timer will be reset by the caller.
- s.fr.first = ack
- s.dupAckCount = 0
- return true
-}
-
// isAssignedSequenceNumber relies on the fact that we only set flags once a
// sequencenumber is assigned and that is only done right before we send the
// segment. As a result any segment that has a non-zero flag has a valid
@@ -1228,14 +1137,11 @@ func (s *sender) SetPipe() {
s.outstanding = pipe
}
-// checkDuplicateAck is called when an ack is received. It manages the state
-// related to duplicate acks and determines if a retransmit is needed according
-// to the rules in RFC 6582 (NewReno).
-func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
+// detectLoss is called when an ack is received and returns whether a loss is
+// detected. It manages the state related to duplicate acks and determines if
+// a retransmit is needed according to the rules in RFC 6582 (NewReno).
+func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) {
ack := seg.ackNumber
- if s.fr.active {
- return s.handleFastRecovery(seg)
- }
// We're not in fast recovery yet. A segment is considered a duplicate
// only if it doesn't carry any data and doesn't update the send window,
@@ -1266,14 +1172,14 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 2
//
// We only do the check here, the incrementing of last to the highest
- // sequence number transmitted till now is done when enterFastRecovery
+ // sequence number transmitted till now is done when enterRecovery
// is invoked.
if !s.fr.last.LessThan(seg.ackNumber) {
s.dupAckCount = 0
return false
}
s.cc.HandleNDupAcks()
- s.enterFastRecovery()
+ s.enterRecovery()
s.dupAckCount = 0
return true
}
@@ -1313,6 +1219,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
s.rc.detectReorder(seg)
seg.acked = true
+ s.sackedOut += s.pCount(seg, s.maxPayloadSize)
}
seg = seg.Next()
}
@@ -1415,14 +1322,23 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
s.SetPipe()
}
- // Count the duplicates and do the fast retransmit if needed.
- rtx := s.checkDuplicateAck(rcvdSeg)
+ ack := rcvdSeg.ackNumber
+ fastRetransmit := false
+ // Do not leave fast recovery, if the ACK is out of range.
+ if s.fr.active {
+ // Leave fast recovery if it acknowledges all the data covered by
+ // this fast recovery session.
+ if ack.InRange(s.sndUna, s.sndNxt+1) && s.fr.last.LessThan(ack) {
+ s.leaveRecovery()
+ }
+ } else {
+ // Detect loss by counting the duplicates and enter recovery.
+ fastRetransmit = s.detectLoss(rcvdSeg)
+ }
// Stash away the current window size.
s.sndWnd = rcvdSeg.window
- ack := rcvdSeg.ackNumber
-
// Disable zero window probing if remote advertizes a non-zero receive
// window. This can be with an ACK to the zero window probe (where the
// acknumber refers to the already acknowledged byte) OR to any previously
@@ -1477,10 +1393,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
datalen := seg.logicalLen()
if datalen > ackLeft {
- prevCount := s.pCount(seg)
+ prevCount := s.pCount(seg, s.maxPayloadSize)
seg.data.TrimFront(int(ackLeft))
seg.sequenceNumber.UpdateForward(ackLeft)
- s.outstanding -= prevCount - s.pCount(seg)
+ s.outstanding -= prevCount - s.pCount(seg, s.maxPayloadSize)
break
}
@@ -1496,11 +1412,13 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
s.writeList.Remove(seg)
- // If SACK is enabled then Only reduce outstanding if
+ // If SACK is enabled then only reduce outstanding if
// the segment was not previously SACKED as these have
// already been accounted for in SetPipe().
if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
- s.outstanding -= s.pCount(seg)
+ s.outstanding -= s.pCount(seg, s.maxPayloadSize)
+ } else {
+ s.sackedOut -= s.pCount(seg, s.maxPayloadSize)
}
seg.decRef()
ackLeft -= datalen
@@ -1539,19 +1457,24 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
s.resendTimer.disable()
}
}
+
// Now that we've popped all acknowledged data from the retransmit
// queue, retransmit if needed.
- if rtx {
- s.resendSegment()
+ if s.fr.active {
+ s.lr.DoRecovery(rcvdSeg, fastRetransmit)
+ // When SACK is enabled data sending is governed by steps in
+ // RFC 6675 Section 5 recovery steps A-C.
+ // See: https://tools.ietf.org/html/rfc6675#section-5.
+ if s.ep.sackPermitted {
+ return
+ }
}
// Send more data now that some of the pending data has been ack'd, or
// that the window opened up, or the congestion window was inflated due
// to a duplicate ack during fast recovery. This will also re-enable
// the retransmit timer if needed.
- if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || rcvdSeg.hasNewSACKInfo {
- s.sendData()
- }
+ s.sendData()
}
// sendSegment sends the specified segment.
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index ef7f5719f..faf0c0ad7 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -590,3 +590,45 @@ func TestSACKRecovery(t *testing.T) {
expected++
}
}
+
+// TestSACKUpdateSackedOut tests the sacked out field is updated when a SACK
+// is received.
+func TestSACKUpdateSackedOut(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan struct{})
+ ackNum := 0
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that the endpoint Sender.SackedOut is what we expect.
+ if state.Sender.SackedOut != 2 && ackNum == 0 {
+ t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut)
+ }
+
+ if state.Sender.SackedOut != 0 && ackNum == 1 {
+ t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut)
+ }
+ if ackNum > 0 {
+ close(probeDone)
+ }
+ ackNum++
+ })
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+
+ sendAndReceive(t, c, 8)
+
+ // ACK for [3-5] packets.
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload))
+ bytesRead := 2 * maxPayload
+ end := start.Add(seqnum.Size(bytesRead))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ bytesRead += 3 * maxPayload
+ c.SendAck(seq, bytesRead)
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ <-probeDone
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 9f0fb41e3..351a5e4f5 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -75,9 +75,6 @@ func TestGiveUpConnect(t *testing.T) {
// Wait for ep to become writable.
<-notifyCh
- if err := ep.LastError(); err != tcpip.ErrAborted {
- t.Fatalf("got ep.LastError() = %s, want = %s", err, tcpip.ErrAborted)
- }
// Call Connect again to retreive the handshake failure status
// and stats updates.
@@ -267,7 +264,7 @@ func TestTCPResetsSentNoICMP(t *testing.T) {
}
// Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded.
- sent := stats.ICMP.V4PacketsSent
+ sent := stats.ICMP.V4.PacketsSent
if got, want := sent.DstUnreachable.Value(), uint64(0); got != want {
t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want)
}
@@ -1935,6 +1932,84 @@ func TestFullWindowReceive(t *testing.T) {
)
}
+// Test the stack receive window advertisement on receiving segments smaller than
+// segment overhead. It tests for the right edge of the window to not grow when
+// the endpoint is not being read from.
+func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ opt := tcpip.TCPReceiveBufferSizeRangeOption{
+ Min: 1,
+ Default: tcp.DefaultReceiveBufferSize,
+ Max: tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)),
+ }
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
+
+ c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+ // Bump up the receive buffer size such that, when the receive window grows,
+ // the scaled window exceeds maxUint16.
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, opt.Max); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", opt.Max, err)
+ }
+
+ // Keep the payload size < segment overhead and such that it is a multiple
+ // of the window scaled value. This enables the test to perform equality
+ // checks on the incoming receive window.
+ payload := generateRandomPayload(t, (tcp.SegSize-1)&(1<<c.RcvdWindowScale))
+ payloadLen := seqnum.Size(len(payload))
+ iss := seqnum.Value(789)
+ seqNum := iss.Add(1)
+
+ // Send payload to the endpoint and return the advertised receive window
+ // from the endpoint.
+ getIncomingRcvWnd := func() uint32 {
+ c.SendPacket(payload, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ SeqNum: seqNum,
+ AckNum: c.IRS.Add(1),
+ Flags: header.TCPFlagAck,
+ RcvWnd: 30000,
+ })
+ seqNum = seqNum.Add(payloadLen)
+
+ pkt := c.GetPacket()
+ return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale
+ }
+
+ // Read the advertised receive window with the ACK for payload.
+ rcvWnd := getIncomingRcvWnd()
+
+ // Check if the subsequent ACK to our send has not grown the right edge of
+ // the window.
+ if got, want := getIncomingRcvWnd(), rcvWnd-uint32(len(payload)); got != want {
+ t.Fatalf("got incomingRcvwnd %d want %d", got, want)
+ }
+
+ // Read the data so that the subsequent ACK from the endpoint
+ // grows the right edge of the window.
+ if _, _, err := c.EP.Read(nil); err != nil {
+ t.Fatalf("got Read(nil) = %s", err)
+ }
+
+ // Check if we have received max uint16 as our advertised
+ // scaled window now after a read above.
+ maxRcv := uint32(math.MaxUint16 << c.RcvdWindowScale)
+ if got, want := getIncomingRcvWnd(), maxRcv; got != want {
+ t.Fatalf("got incomingRcvwnd %d want %d", got, want)
+ }
+
+ // Check if the subsequent ACK to our send has not grown the right edge of
+ // the window.
+ if got, want := getIncomingRcvWnd(), maxRcv-uint32(len(payload)); got != want {
+ t.Fatalf("got incomingRcvwnd %d want %d", got, want)
+ }
+}
+
func TestNoWindowShrinking(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -2532,10 +2607,10 @@ func TestSegmentMerging(t *testing.T) {
{
"cork",
func(ep tcpip.Endpoint) {
- ep.SetSockOptBool(tcpip.CorkOption, true)
+ ep.SocketOptions().SetCorkOption(true)
},
func(ep tcpip.Endpoint) {
- ep.SetSockOptBool(tcpip.CorkOption, false)
+ ep.SocketOptions().SetCorkOption(false)
},
},
}
@@ -2627,7 +2702,7 @@ func TestDelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOptBool(tcpip.DelayOption, true)
+ c.EP.SocketOptions().SetDelayOption(true)
var allData []byte
for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
@@ -2675,7 +2750,7 @@ func TestUndelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOptBool(tcpip.DelayOption, true)
+ c.EP.SocketOptions().SetDelayOption(true)
allData := [][]byte{{0}, {1, 2, 3}}
for i, data := range allData {
@@ -2708,7 +2783,7 @@ func TestUndelay(t *testing.T) {
// Check that we don't get the second packet yet.
c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond)
- c.EP.SetSockOptBool(tcpip.DelayOption, false)
+ c.EP.SocketOptions().SetDelayOption(false)
// Check that data is received.
second := c.GetPacket()
@@ -2745,8 +2820,8 @@ func TestMSSNotDelayed(t *testing.T) {
fn func(tcpip.Endpoint)
}{
{"no-op", func(tcpip.Endpoint) {}},
- {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.DelayOption, true) }},
- {"cork", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.CorkOption, true) }},
+ {"delay", func(ep tcpip.Endpoint) { ep.SocketOptions().SetDelayOption(true) }},
+ {"cork", func(ep tcpip.Endpoint) { ep.SocketOptions().SetCorkOption(true) }},
}
for _, test := range tests {
@@ -3198,6 +3273,11 @@ loop:
case tcpip.ErrWouldBlock:
select {
case <-ch:
+ // Expect the state to be StateError and subsequent Reads to fail with HardError.
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
+ }
+ break loop
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for reset to arrive")
}
@@ -3207,14 +3287,10 @@ loop:
t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
}
- // Expect the state to be StateError and subsequent Reads to fail with HardError.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
- }
+
if tcp.EndpointState(c.EP.State()) != tcp.StateError {
t.Fatalf("got EP state is not StateError")
}
-
if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)
}
@@ -4150,7 +4226,7 @@ func TestReadAfterClosedState(t *testing.T) {
// Check that peek works.
peekBuf := make([]byte, 10)
- n, _, err := c.EP.Peek([][]byte{peekBuf})
+ n, err := c.EP.Peek([][]byte{peekBuf})
if err != nil {
t.Fatalf("Peek failed: %s", err)
}
@@ -4176,7 +4252,7 @@ func TestReadAfterClosedState(t *testing.T) {
t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
- if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
+ if _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
}
@@ -4193,9 +4269,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4205,9 +4279,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4218,9 +4290,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4233,9 +4303,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4246,9 +4314,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4261,9 +4327,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4656,13 +4720,9 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
switch network {
case "ipv4":
case "ipv6":
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOptBool(V6OnlyOption(true)) failed: %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
case "dual":
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil {
- t.Fatalf("SetSockOptBool(V6OnlyOption(false)) failed: %s", err)
- }
+ ep.SocketOptions().SetV6Only(false)
default:
t.Fatalf("unknown network: '%s'", network)
}
@@ -4998,9 +5058,7 @@ func TestKeepalive(t *testing.T) {
if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil {
t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil {
- t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err)
- }
+ c.EP.SocketOptions().SetKeepAlive(true)
// 5 unacked keepalives are sent. ACK each one, and check that the
// connection stays alive after 5.
@@ -6118,10 +6176,13 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Introduce a 25ms latency by delaying the first byte.
latency := 25 * time.Millisecond
time.Sleep(latency)
- rawEP.SendPacketWithTS([]byte{1}, tsVal)
+ // Send an initial payload with atleast segment overhead size. The receive
+ // window would not grow for smaller segments.
+ rawEP.SendPacketWithTS(make([]byte, tcp.SegSize), tsVal)
pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize()
+
time.Sleep(25 * time.Millisecond)
// Allocate a large enough payload for the test.
@@ -6394,10 +6455,7 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.T
if err != nil {
t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err)
}
- gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption)
- if err != nil {
- t.Fatalf("ep.GetSockOptBool(tcpip.DelayOption) failed: %s", err)
- }
+ gotDelayOption := ep.SocketOptions().GetDelayOption()
if gotDelayOption != wantDelayOption {
t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption)
}
@@ -7250,9 +7308,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil {
t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil {
- t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err)
- }
+ c.EP.SocketOptions().SetKeepAlive(true)
// Set userTimeout to be the duration to be 1 keepalive
// probes. Which means that after the first probe is sent
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index e6aa4fc4b..ee55f030c 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -592,9 +592,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
c.t.Fatalf("NewEndpoint failed: %v", err)
}
- if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
+ c.EP.SocketOptions().SetV6Only(v6only)
}
// GetV6Packet reads a single packet from the link layer endpoint of the context
@@ -637,11 +635,11 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
- NextHeader: uint8(tcp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: src,
- DstAddr: dst,
+ PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
+ TransportProtocol: tcp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: src,
+ DstAddr: dst,
})
// Initialize the TCP header.
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index c78549424..153e8c950 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -56,6 +56,8 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
"//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 57976d4e3..763d1d654 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -16,8 +16,8 @@ package udp
import (
"fmt"
+ "sync/atomic"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -30,10 +30,11 @@ import (
// +stateify savable
type udpPacket struct {
udpPacketEntry
- senderAddress tcpip.FullAddress
- packetInfo tcpip.IPPacketInfo
- data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- timestamp int64
+ senderAddress tcpip.FullAddress
+ destinationAddress tcpip.FullAddress
+ packetInfo tcpip.IPPacketInfo
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ timestamp int64
// tos stores either the receiveTOS or receiveTClass value.
tos uint8
}
@@ -77,6 +78,7 @@ func (s EndpointState) String() string {
// +stateify savable
type endpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
@@ -94,21 +96,20 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by the mu mutex.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- sndBufSizeMax int
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ // state must be read/set using the EndpointState()/setEndpointState()
+ // methods.
state EndpointState
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
dstPort uint16
- v6only bool
ttl uint8
multicastTTL uint8
multicastAddr tcpip.Address
multicastNICID tcpip.NICID
- multicastLoop bool
portFlags ports.Flags
bindToDevice tcpip.NICID
- noChecksum bool
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
@@ -122,17 +123,6 @@ type endpoint struct {
// applied while sending packets. Defaults to 0 as on Linux.
sendTOS uint8
- // receiveTOS determines if the incoming IPv4 TOS header field is passed
- // as ancillary data to ControlMessages on Read.
- receiveTOS bool
-
- // receiveTClass determines if the incoming IPv6 TClass header field is
- // passed as ancillary data to ControlMessages on Read.
- receiveTClass bool
-
- // receiveIPPacketInfo determines if the packet info is returned by Read.
- receiveIPPacketInfo bool
-
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -154,9 +144,6 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
-
// ops is used to get socket level options.
ops tcpip.SocketOptions
}
@@ -188,13 +175,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
//
// Linux defaults to TTL=1.
multicastTTL: 1,
- multicastLoop: true,
rcvBufSizeMax: 32 * 1024,
sndBufSizeMax: 32 * 1024,
multicastMemberships: make(map[multicastMembership]struct{}),
state: StateInitial,
uniqueID: s.UniqueID(),
}
+ e.ops.InitHandler(e)
+ e.ops.SetMulticastLoop(true)
// Override with stack defaults.
var ss stack.SendBufferSizeOption
@@ -210,6 +198,20 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
return e
}
+// setEndpointState updates the state of the endpoint to state atomically. This
+// method is unexported as the only place we should update the state is in this
+// package but we allow the state to be read freely without holding e.mu.
+//
+// Precondition: e.mu must be held to call this method.
+func (e *endpoint) setEndpointState(state EndpointState) {
+ atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+}
+
+// EndpointState() returns the current state of the endpoint.
+func (e *endpoint) EndpointState() EndpointState {
+ return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+}
+
// UniqueID implements stack.TransportEndpoint.UniqueID.
func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
@@ -235,7 +237,7 @@ func (e *endpoint) Close() {
e.mu.Lock()
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
- switch e.state {
+ switch e.EndpointState() {
case StateBound, StateConnected:
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
@@ -258,10 +260,13 @@ func (e *endpoint) Close() {
}
e.rcvMu.Unlock()
- e.route.Release()
+ if e.route != nil {
+ e.route.Release()
+ e.route = nil
+ }
// Update the state.
- e.state = StateClosed
+ e.setEndpointState(StateClosed)
e.mu.Unlock()
@@ -303,24 +308,23 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
HasTimestamp: true,
Timestamp: p.timestamp,
}
- e.mu.RLock()
- receiveTOS := e.receiveTOS
- receiveTClass := e.receiveTClass
- receiveIPPacketInfo := e.receiveIPPacketInfo
- e.mu.RUnlock()
- if receiveTOS {
+ if e.ops.GetReceiveTOS() {
cm.HasTOS = true
cm.TOS = p.tos
}
- if receiveTClass {
+ if e.ops.GetReceiveTClass() {
cm.HasTClass = true
// Although TClass is an 8-bit value it's read in the CMsg as a uint32.
cm.TClass = uint32(p.tos)
}
- if receiveIPPacketInfo {
+ if e.ops.GetReceivePacketInfo() {
cm.HasIPPacketInfo = true
cm.PacketInfo = p.packetInfo
}
+ if e.ops.GetReceiveOriginalDstAddress() {
+ cm.HasOriginalDstAddress = true
+ cm.OriginalDstAddress = p.destinationAddress
+ }
return p.data.ToView(), cm, nil
}
@@ -330,7 +334,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
//
// Returns true for retry if preparation should be retried.
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
- switch e.state {
+ switch e.EndpointState() {
case StateInitial:
case StateConnected:
return false, nil
@@ -352,7 +356,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
- if e.state != StateInitial {
+ if e.EndpointState() != StateInitial {
return true, nil
}
@@ -367,9 +371,9 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// connectRoute establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
-func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
+func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, *tcpip.Error) {
localAddr := e.ID.LocalAddress
- if isBroadcastOrMulticast(localAddr) {
+ if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
localAddr = ""
}
@@ -384,9 +388,9 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop)
+ r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop())
if err != nil {
- return stack.Route{}, 0, err
+ return nil, 0, err
}
return r, nicID, nil
}
@@ -429,7 +433,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
to := opts.To
e.mu.RLock()
- defer e.mu.RUnlock()
+ lockReleased := false
+ defer func() {
+ if lockReleased {
+ return
+ }
+ e.mu.RUnlock()
+ }()
// If we've shutdown with SHUT_WR we are in an invalid state for sending.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
@@ -448,36 +458,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
}
- var route *stack.Route
- var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error)
- var dstPort uint16
- if to == nil {
- route = &e.route
- dstPort = e.dstPort
- resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) {
- // Promote lock to exclusive if using a shared route, given that it may
- // need to change in Route.Resolve() call below.
- e.mu.RUnlock()
- e.mu.Lock()
-
- // Recheck state after lock was re-acquired.
- if e.state != StateConnected {
- err = tcpip.ErrInvalidEndpointState
- }
- if err == nil && route.IsResolutionRequired() {
- ch, err = route.Resolve(waker)
- }
-
- e.mu.Unlock()
- e.mu.RLock()
-
- // Recheck state after lock was re-acquired.
- if e.state != StateConnected {
- err = tcpip.ErrInvalidEndpointState
- }
- return
- }
- } else {
+ route := e.route
+ dstPort := e.dstPort
+ if to != nil {
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicID := to.NIC
@@ -505,9 +488,8 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
defer r.Release()
- route = &r
+ route = r
dstPort = dst.Port
- resolve = route.Resolve
}
if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
@@ -515,7 +497,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
if route.IsResolutionRequired() {
- if ch, err := resolve(nil); err != nil {
+ if ch, err := route.Resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
return 0, ch, tcpip.ErrNoLinkAddress
}
@@ -541,77 +523,46 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
useDefaultTTL = false
}
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil {
+ localPort := e.ID.LocalPort
+ sendTOS := e.sendTOS
+ owner := e.owner
+ noChecksum := e.SocketOptions().GetNoChecksum()
+ lockReleased = true
+ e.mu.RUnlock()
+
+ // Do not hold lock when sending as loopback is synchronous and if the UDP
+ // datagram ends up generating an ICMP response then it can result in a
+ // deadlock where the ICMP response handling ends up acquiring this endpoint's
+ // mutex using e.mu.RLock() in endpoint.HandleControlPacket which can cause a
+ // deadlock if another caller is trying to acquire e.mu in exclusive mode w/
+ // e.mu.Lock(). Since e.mu.Lock() prevents any new read locks to ensure the
+ // lock can be eventually acquired.
+ //
+ // See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read
+ // locking is prohibited.
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil {
return 0, nil, err
}
return int64(len(v)), nil, nil
}
// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
+func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
+ return 0, nil
}
-// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
-func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- switch opt {
- case tcpip.MulticastLoopOption:
- e.mu.Lock()
- e.multicastLoop = v
- e.mu.Unlock()
-
- case tcpip.NoChecksumOption:
- e.mu.Lock()
- e.noChecksum = v
- e.mu.Unlock()
-
- case tcpip.ReceiveTOSOption:
- e.mu.Lock()
- e.receiveTOS = v
- e.mu.Unlock()
-
- case tcpip.ReceiveTClassOption:
- // We only support this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrNotSupported
- }
-
- e.mu.Lock()
- e.receiveTClass = v
- e.mu.Unlock()
-
- case tcpip.ReceiveIPPacketInfoOption:
- e.mu.Lock()
- e.receiveIPPacketInfo = v
- e.mu.Unlock()
-
- case tcpip.ReuseAddressOption:
- e.mu.Lock()
- e.portFlags.MostRecent = v
- e.mu.Unlock()
-
- case tcpip.ReusePortOption:
- e.mu.Lock()
- e.portFlags.LoadBalanced = v
- e.mu.Unlock()
-
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrInvalidEndpointState
- }
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- // We only allow this to be set when we're in the initial state.
- if e.state != StateInitial {
- return tcpip.ErrInvalidEndpointState
- }
+// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet.
+func (e *endpoint) OnReuseAddressSet(v bool) {
+ e.mu.Lock()
+ e.portFlags.MostRecent = v
+ e.mu.Unlock()
+}
- e.v6only = v
- }
- return nil
+// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet.
+func (e *endpoint) OnReusePortSet(v bool) {
+ e.mu.Lock()
+ e.portFlags.LoadBalanced = v
+ e.mu.Unlock()
}
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
@@ -814,90 +765,10 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
case *tcpip.SocketDetachFilterOption:
return nil
-
- case *tcpip.LingerOption:
- e.mu.Lock()
- e.linger = *v
- e.mu.Unlock()
}
return nil
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.KeepaliveEnabledOption:
- return false, nil
-
- case tcpip.MulticastLoopOption:
- e.mu.RLock()
- v := e.multicastLoop
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.NoChecksumOption:
- e.mu.RLock()
- v := e.noChecksum
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.ReceiveTOSOption:
- e.mu.RLock()
- v := e.receiveTOS
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.ReceiveTClassOption:
- // We only support this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return false, tcpip.ErrNotSupported
- }
-
- e.mu.RLock()
- v := e.receiveTClass
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.ReceiveIPPacketInfoOption:
- e.mu.RLock()
- v := e.receiveIPPacketInfo
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.ReuseAddressOption:
- e.mu.RLock()
- v := e.portFlags.MostRecent
- e.mu.RUnlock()
-
- return v, nil
-
- case tcpip.ReusePortOption:
- e.mu.RLock()
- v := e.portFlags.LoadBalanced
- e.mu.RUnlock()
-
- return v, nil
-
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return false, tcpip.ErrUnknownProtocolOption
- }
-
- e.mu.RLock()
- v := e.v6only
- e.mu.RUnlock()
-
- return v, nil
-
- case tcpip.AcceptConnOption:
- return false, nil
-
- default:
- return false, tcpip.ErrUnknownProtocolOption
- }
-}
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
@@ -972,11 +843,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
*o = tcpip.BindToDeviceOption(e.bindToDevice)
e.mu.RUnlock()
- case *tcpip.LingerOption:
- e.mu.RLock()
- *o = e.linger
- e.mu.RUnlock()
-
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -1036,7 +902,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// checkV4MappedLocked determines the effective network protocol and converts
// addr to its canonical form.
func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
}
@@ -1048,7 +914,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if e.state != StateConnected {
+ if e.EndpointState() != StateConnected {
return nil
}
var (
@@ -1071,7 +937,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
if err != nil {
return err
}
- e.state = StateBound
+ e.setEndpointState(StateBound)
boundPortFlags = e.boundPortFlags
} else {
if e.ID.LocalPort != 0 {
@@ -1079,14 +945,14 @@ func (e *endpoint) Disconnect() *tcpip.Error {
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
}
- e.state = StateInitial
+ e.setEndpointState(StateInitial)
}
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
e.ID = id
e.boundBindToDevice = btd
e.route.Release()
- e.route = stack.Route{}
+ e.route = nil
e.dstPort = 0
return nil
@@ -1104,7 +970,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
nicID := addr.NIC
var localPort uint16
- switch e.state {
+ switch e.EndpointState() {
case StateInitial:
case StateBound, StateConnected:
localPort = e.ID.LocalPort
@@ -1139,7 +1005,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
RemoteAddress: r.RemoteAddress,
}
- if e.state == StateInitial {
+ if e.EndpointState() == StateInitial {
id.LocalAddress = r.LocalAddress
}
@@ -1147,7 +1013,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// packets on a different network protocol, so we register both even if
// v6only is set to false and this is an ipv6 endpoint.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.v6only {
+ if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() {
netProtos = []tcpip.NetworkProtocolNumber{
header.IPv4ProtocolNumber,
header.IPv6ProtocolNumber,
@@ -1173,7 +1039,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.RegisterNICID = nicID
e.effectiveNetProtos = netProtos
- e.state = StateConnected
+ e.setEndpointState(StateConnected)
e.rcvMu.Lock()
e.rcvReady = true
@@ -1195,7 +1061,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// A socket in the bound state can still receive multicast messages,
// so we need to notify waiters on shutdown.
- if e.state != StateBound && e.state != StateConnected {
+ if state := e.EndpointState(); state != StateBound && state != StateConnected {
return tcpip.ErrNotConnected
}
@@ -1246,7 +1112,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ
func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.state != StateInitial {
+ if e.EndpointState() != StateInitial {
return tcpip.ErrInvalidEndpointState
}
@@ -1259,7 +1125,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// wildcard (empty) address, and this is an IPv6 endpoint with v6only
// set to false.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
+ if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && addr.Addr == "" {
netProtos = []tcpip.NetworkProtocolNumber{
header.IPv6ProtocolNumber,
header.IPv4ProtocolNumber,
@@ -1267,7 +1133,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
}
nicID := addr.NIC
- if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) {
+ if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) {
// A local unicast address was specified, verify that it's valid.
nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
if nicID == 0 {
@@ -1290,7 +1156,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
e.effectiveNetProtos = netProtos
// Mark endpoint as bound.
- e.state = StateBound
+ e.setEndpointState(StateBound)
e.rcvMu.Lock()
e.rcvReady = true
@@ -1322,7 +1188,7 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
defer e.mu.RUnlock()
addr := e.ID.LocalAddress
- if e.state == StateConnected {
+ if e.EndpointState() == StateConnected {
addr = e.route.LocalAddress
}
@@ -1338,7 +1204,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.state != StateConnected {
+ if e.EndpointState() != StateConnected {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -1393,7 +1259,6 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
- // Get the header then trim it from the view.
hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet.
@@ -1402,6 +1267,10 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
return
}
+ // TODO(gvisor.dev/issues/5033): We should mirror the Network layer and cap
+ // packets at "Parse" instead of when handling a packet.
+ pkt.Data.CapLength(int(hdr.PayloadLength()))
+
if !verifyChecksum(hdr, pkt) {
// Checksum Error.
e.stack.Stats().UDP.ChecksumErrors.Increment()
@@ -1435,7 +1304,12 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
senderAddress: tcpip.FullAddress{
NIC: pkt.NICID,
Addr: id.RemoteAddress,
- Port: header.UDP(hdr).SourcePort(),
+ Port: hdr.SourcePort(),
+ },
+ destinationAddress: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.LocalAddress,
+ Port: header.UDP(hdr).DestinationPort(),
},
}
packet.data = pkt.Data
@@ -1470,25 +1344,20 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
if typ == stack.ControlPortUnreachable {
- e.mu.RLock()
- if e.state == StateConnected {
+ if e.EndpointState() == StateConnected {
e.lastErrorMu.Lock()
e.lastError = tcpip.ErrConnectionRefused
e.lastErrorMu.Unlock()
- e.mu.RUnlock()
e.waiterQueue.Notify(waiter.EventErr)
return
}
- e.mu.RUnlock()
}
}
// State implements tcpip.Endpoint.State.
func (e *endpoint) State() uint32 {
- e.mu.Lock()
- defer e.mu.Unlock()
- return uint32(e.state)
+ return uint32(e.EndpointState())
}
// Info returns a copy of the endpoint info.
@@ -1508,14 +1377,16 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
// Wait implements tcpip.Endpoint.Wait.
func (*endpoint) Wait() {}
-func isBroadcastOrMulticast(a tcpip.Address) bool {
- return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a)
+func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr)
}
+// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 858c99a45..13b72dc88 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -98,7 +98,8 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
}
- if e.state != StateBound && e.state != StateConnected {
+ state := e.EndpointState()
+ if state != StateBound && state != StateConnected {
return
}
@@ -113,12 +114,12 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
var err *tcpip.Error
- if e.state == StateConnected {
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop)
+ if state == StateConnected {
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop())
if err != nil {
panic(err)
}
- } else if len(e.ID.LocalAddress) != 0 && !isBroadcastOrMulticast(e.ID.LocalAddress) { // stateBound
+ } else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound
// A local unicast address is specified, verify that it's valid.
if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 {
panic(tcpip.ErrBadLocalAddress)
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 764ad0857..08980c298 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -22,6 +22,7 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -32,6 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -54,6 +56,7 @@ const (
stackPort = 1234
testAddr = "\x0a\x00\x00\x02"
testPort = 4096
+ invalidPort = 8192
multicastAddr = "\xe8\x2b\xd3\xea"
multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
broadcastAddr = header.IPv4Broadcast
@@ -295,7 +298,8 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
return newDualTestContextWithOptions(t, mtu, stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
+ HandleLocal: true,
})
}
@@ -360,9 +364,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
c.createEndpoint(flow.sockProto())
if flow.isV6Only() {
- if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- c.t.Fatalf("SetSockOptBool failed: %s", err)
- }
+ c.ep.SocketOptions().SetV6Only(true)
} else if flow.isBroadcast() {
c.ep.SocketOptions().SetBroadcast(true)
}
@@ -451,12 +453,12 @@ func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- TrafficClass: testTOS,
- PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
+ TrafficClass: testTOS,
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
})
// Initialize the UDP header.
@@ -972,7 +974,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
// provided.
func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
c.t.Helper()
- return testWriteInternal(c, flow, true, checkers...)
+ return testWriteAndVerifyInternal(c, flow, true, checkers...)
}
// testWriteWithoutDestination sends a packet of the given test flow from the
@@ -981,10 +983,10 @@ func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker
// checker functions provided.
func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
c.t.Helper()
- return testWriteInternal(c, flow, false, checkers...)
+ return testWriteAndVerifyInternal(c, flow, false, checkers...)
}
-func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
+func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View {
c.t.Helper()
// Take a snapshot of the stats to validate them at the end of the test.
epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
@@ -1006,6 +1008,12 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
}
c.checkEndpointWriteStats(1, epstats, err)
+ return payload
+}
+
+func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+ payload := testWriteNoVerify(c, flow, setDest)
// Received the packet and check the payload.
b := c.getPacketAndVerify(flow, checkers...)
var udp header.UDP
@@ -1150,6 +1158,39 @@ func TestV4WriteOnConnected(t *testing.T) {
testWriteWithoutDestination(c, unicastV4)
}
+func TestWriteOnConnectedInvalidPort(t *testing.T) {
+ protocols := map[string]tcpip.NetworkProtocolNumber{
+ "ipv4": ipv4.ProtocolNumber,
+ "ipv6": ipv6.ProtocolNumber,
+ }
+ for name, pn := range protocols {
+ t.Run(name, func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(pn)
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}); err != nil {
+ c.t.Fatalf("Connect failed: %s", err)
+ }
+ writeOpts := tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort},
+ }
+ payload := buffer.View(newPayload())
+ n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
+ if err != nil {
+ c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err)
+ }
+ if got, want := n, int64(len(payload)); got != want {
+ c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want)
+ }
+
+ if err := c.ep.LastError(); err != tcpip.ErrConnectionRefused {
+ c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err)
+ }
+ })
+ }
+}
+
// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket
// that is bound to a V4 multicast address.
func TestWriteOnBoundToV4Multicast(t *testing.T) {
@@ -1372,9 +1413,7 @@ func TestReadIPPacketInfo(t *testing.T) {
}
}
- if err := c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true); err != nil {
- t.Fatalf("c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true): %s", err)
- }
+ c.ep.SocketOptions().SetReceivePacketInfo(true)
testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
NIC: 1,
@@ -1389,6 +1428,93 @@ func TestReadIPPacketInfo(t *testing.T) {
}
}
+func TestReadRecvOriginalDstAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ flow testFlow
+ expectedOriginalDstAddr tcpip.FullAddress
+ }{
+ {
+ name: "IPv4 unicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: unicastV4,
+ expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort},
+ },
+ {
+ name: "IPv4 multicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: multicastV4,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort},
+ },
+ {
+ name: "IPv4 broadcast",
+ proto: header.IPv4ProtocolNumber,
+ flow: broadcast,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort},
+ },
+ {
+ name: "IPv6 unicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: unicastV6,
+ expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort},
+ },
+ {
+ name: "IPv6 multicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: multicastV6,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(test.proto)
+
+ bindAddr := tcpip.FullAddress{Port: stackPort}
+ if err := c.ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%#v): %s", bindAddr, err)
+ }
+
+ if test.flow.isMulticast() {
+ ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()}
+ if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err)
+ }
+ }
+
+ c.ep.SocketOptions().SetReceiveOriginalDstAddress(true)
+
+ testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr))
+
+ if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
+ t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
+ }
+ })
+ }
+}
+
func TestWriteIncrementsPacketsSent(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1412,16 +1538,12 @@ func TestNoChecksum(t *testing.T) {
c.createEndpointForFlow(flow)
// Disable the checksum generation.
- if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil {
- t.Fatalf("SetSockOptBool failed: %s", err)
- }
+ c.ep.SocketOptions().SetNoChecksum(true)
// This option is effective on IPv4 only.
testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4())))
// Enable the checksum generation.
- if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil {
- t.Fatalf("SetSockOptBool failed: %s", err)
- }
+ c.ep.SocketOptions().SetNoChecksum(false)
testWrite(c, flow, checker.UDP(checker.NoChecksum(false)))
})
}
@@ -1591,13 +1713,15 @@ func TestSetTClass(t *testing.T) {
}
func TestReceiveTosTClass(t *testing.T) {
+ const RcvTOSOpt = "ReceiveTosOption"
+ const RcvTClassOpt = "ReceiveTClassOption"
+
testCases := []struct {
- name string
- getReceiveOption tcpip.SockOptBool
- tests []testFlow
+ name string
+ tests []testFlow
}{
- {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}},
- {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
+ {RcvTOSOpt, []testFlow{unicastV4, broadcast}},
+ {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
}
for _, testCase := range testCases {
for _, flow := range testCase.tests {
@@ -1606,29 +1730,32 @@ func TestReceiveTosTClass(t *testing.T) {
defer c.cleanup()
c.createEndpointForFlow(flow)
- option := testCase.getReceiveOption
name := testCase.name
- // Verify that setting and reading the option works.
- v, err := c.ep.GetSockOptBool(option)
- if err != nil {
- c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
+ var optionGetter func() bool
+ var optionSetter func(bool)
+ switch name {
+ case RcvTOSOpt:
+ optionGetter = c.ep.SocketOptions().GetReceiveTOS
+ optionSetter = c.ep.SocketOptions().SetReceiveTOS
+ case RcvTClassOpt:
+ optionGetter = c.ep.SocketOptions().GetReceiveTClass
+ optionSetter = c.ep.SocketOptions().SetReceiveTClass
+ default:
+ t.Fatalf("unkown test variant: %s", name)
}
+
+ // Verify that setting and reading the option works.
+ v := optionGetter()
// Test for expected default value.
if v != false {
c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false)
}
want := true
- if err := c.ep.SetSockOptBool(option, want); err != nil {
- c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err)
- }
-
- got, err := c.ep.GetSockOptBool(option)
- if err != nil {
- c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
- }
+ optionSetter(want)
+ got := optionGetter()
if got != want {
c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want)
}
@@ -1638,10 +1765,10 @@ func TestReceiveTosTClass(t *testing.T) {
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
c.t.Fatalf("Bind failed: %s", err)
}
- switch option {
- case tcpip.ReceiveTClassOption:
+ switch name {
+ case RcvTClassOpt:
testRead(c, flow, checker.ReceiveTClass(testTOS))
- case tcpip.ReceiveTOSOption:
+ case RcvTOSOpt:
testRead(c, flow, checker.ReceiveTOS(testTOS))
default:
t.Fatalf("unknown test variant: %s", name)
@@ -1788,27 +1915,31 @@ func TestV4UnknownDestination(t *testing.T) {
icmpPkt := header.ICMPv4(hdr.Payload())
payloadIPHeader := header.IPv4(icmpPkt.Payload())
incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize
- wantLen := len(payload)
+ wantPayloadLen := len(payload)
if tc.largePayload {
// To work out the data size we need to simulate what the sender would
// have done. The wanted size is the total available minus the sum of
// the headers in the UDP AND ICMP packets, given that we know the test
// had only a minimal IP header but the ICMP sender will have allowed
// for a maximally sized packet header.
- wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
+ wantPayloadLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
}
// In the case of large payloads the IP packet may be truncated. Update
// the length field before retrieving the udp datagram payload.
// Add back the two headers within the payload.
- payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength))
-
+ payloadIPHeader.SetTotalLength(uint16(wantPayloadLen + incomingHeaderLength))
origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ wantDgramLen := wantPayloadLen + header.UDPMinimumSize
+
+ if got, want := len(origDgram), wantDgramLen; got != want {
+ t.Fatalf("got len(origDgram) = %d, want = %d", got, want)
}
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %d, want: %d", got, want)
+ // Correct UDP length to access payload.
+ origDgram.SetLength(uint16(wantDgramLen))
+
+ if got, want := origDgram.Payload(), payload[:wantPayloadLen]; !bytes.Equal(got, want) {
+ t.Fatalf("got origDgram.Payload() = %x, want = %x", got, want)
}
})
}
@@ -1883,20 +2014,23 @@ func TestV6UnknownDestination(t *testing.T) {
icmpPkt := header.ICMPv6(hdr.Payload())
payloadIPHeader := header.IPv6(icmpPkt.Payload())
- wantLen := len(payload)
+ wantPayloadLen := len(payload)
if tc.largePayload {
- wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
+ wantPayloadLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
}
+ wantDgramLen := wantPayloadLen + header.UDPMinimumSize
// In case of large payloads the IP packet may be truncated. Update
// the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
+ payloadIPHeader.SetPayloadLength(uint16(wantDgramLen))
origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ if got, want := len(origDgram), wantPayloadLen+header.UDPMinimumSize; got != want {
+ t.Fatalf("got len(origDgram) = %d, want = %d", got, want)
}
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %v, want: %v", got, want)
+ // Correct UDP length to access payload.
+ origDgram.SetLength(uint16(wantPayloadLen + header.UDPMinimumSize))
+ if diff := cmp.Diff(payload[:wantPayloadLen], origDgram.Payload()); diff != "" {
+ t.Fatalf("origDgram.Payload() mismatch (-want +got):\n%s", diff)
}
})
}
@@ -1955,12 +2089,12 @@ func TestShortHeader(t *testing.T) {
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- TrafficClass: testTOS,
- PayloadLength: uint16(udpSize),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
+ TrafficClass: testTOS,
+ PayloadLength: uint16(udpSize),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
})
// Initialize the UDP header.
@@ -2409,3 +2543,67 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
})
}
}
+
+func TestReceiveShortLength(t *testing.T) {
+ flows := []testFlow{unicastV4, unicastV6}
+ for _, flow := range flows {
+ t.Run(flow.String(), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to wildcard.
+ bindAddr := tcpip.FullAddress{Port: stackPort}
+ if err := c.ep.Bind(bindAddr); err != nil {
+ c.t.Fatalf("c.ep.Bind(%#v): %s", bindAddr, err)
+ }
+
+ payload := newPayload()
+ extraBytes := []byte{1, 2, 3, 4}
+ h := flow.header4Tuple(incoming)
+ var buf buffer.View
+ var proto tcpip.NetworkProtocolNumber
+
+ // Build packets with extra bytes not accounted for in the UDP length
+ // field.
+ var udp header.UDP
+ if flow.isV4() {
+ buf = c.buildV4Packet(payload, &h)
+ buf = append(buf, extraBytes...)
+ ip := header.IPv4(buf)
+ ip.SetTotalLength(ip.TotalLength() + uint16(len(extraBytes)))
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+ proto = ipv4.ProtocolNumber
+ udp = ip.Payload()
+ } else {
+ buf = c.buildV6Packet(payload, &h)
+ buf = append(buf, extraBytes...)
+ ip := header.IPv6(buf)
+ ip.SetPayloadLength(ip.PayloadLength() + uint16(len(extraBytes)))
+ proto = ipv6.ProtocolNumber
+ udp = ip.Payload()
+ }
+
+ if diff := cmp.Diff(payload, udp.Payload()); diff != "" {
+ t.Errorf("udp.Payload() mismatch (-want +got):\n%s", diff)
+ }
+
+ c.linkEP.InjectInbound(proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ // Try to receive the data.
+ v, _, err := c.ep.Read(nil)
+ if err != nil {
+ t.Fatalf("c.ep.Read(nil): %s", err)
+ }
+
+ // Check the payload is read back without extra bytes.
+ if diff := cmp.Diff(buffer.View(payload), v); diff != "" {
+ t.Errorf("c.ep.Read(nil) mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/test/criutil/criutil.go b/pkg/test/criutil/criutil.go
index 70945f234..e41769017 100644
--- a/pkg/test/criutil/criutil.go
+++ b/pkg/test/criutil/criutil.go
@@ -54,14 +54,20 @@ func ResolvePath(executable string) string {
}
}
+ // Favor /usr/local/bin, if it exists.
+ localBin := fmt.Sprintf("/usr/local/bin/%s", executable)
+ if _, err := os.Stat(localBin); err == nil {
+ return localBin
+ }
+
// Try to find via the path.
- guess, err := exec.LookPath(executable)
+ guess, _ := exec.LookPath(executable)
if err == nil {
return guess
}
- // Return a default path.
- return fmt.Sprintf("/usr/local/bin/%s", executable)
+ // Return a bare path; this generates a suitable error.
+ return executable
}
// NewCrictl returns a Crictl configured with a timeout and an endpoint over
diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go
index 64d17f661..2bf0a22ff 100644
--- a/pkg/test/dockerutil/container.go
+++ b/pkg/test/dockerutil/container.go
@@ -17,6 +17,7 @@ package dockerutil
import (
"bytes"
"context"
+ "errors"
"fmt"
"io/ioutil"
"net"
@@ -351,6 +352,9 @@ func (c *Container) SandboxPid(ctx context.Context) (int, error) {
return resp.ContainerJSONBase.State.Pid, nil
}
+// ErrNoIP indicates that no IP address is available.
+var ErrNoIP = errors.New("no IP available")
+
// FindIP returns the IP address of the container.
func (c *Container) FindIP(ctx context.Context, ipv6 bool) (net.IP, error) {
resp, err := c.client.ContainerInspect(ctx, c.id)
@@ -365,7 +369,7 @@ func (c *Container) FindIP(ctx context.Context, ipv6 bool) (net.IP, error) {
ip = net.ParseIP(resp.NetworkSettings.DefaultNetworkSettings.IPAddress)
}
if ip == nil {
- return net.IP{}, fmt.Errorf("invalid IP: %q", ip)
+ return net.IP{}, ErrNoIP
}
return ip, nil
}
diff --git a/pkg/test/dockerutil/exec.go b/pkg/test/dockerutil/exec.go
index 4c739c9e9..bf968acec 100644
--- a/pkg/test/dockerutil/exec.go
+++ b/pkg/test/dockerutil/exec.go
@@ -77,11 +77,6 @@ func (c *Container) doExec(ctx context.Context, r ExecOpts, args []string) (Proc
return Process{}, fmt.Errorf("exec attach failed with err: %v", err)
}
- if err := c.client.ContainerExecStart(ctx, resp.ID, types.ExecStartCheck{}); err != nil {
- hijack.Close()
- return Process{}, fmt.Errorf("exec start failed with err: %v", err)
- }
-
return Process{
container: c,
execid: resp.ID,
diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go
index 49ab87c58..fdd416b5e 100644
--- a/pkg/test/testutil/testutil.go
+++ b/pkg/test/testutil/testutil.go
@@ -36,7 +36,6 @@ import (
"path/filepath"
"strconv"
"strings"
- "sync/atomic"
"syscall"
"testing"
"time"
@@ -49,7 +48,10 @@ import (
)
var (
- checkpoint = flag.Bool("checkpoint", true, "control checkpoint/restore support")
+ checkpoint = flag.Bool("checkpoint", true, "control checkpoint/restore support")
+ partition = flag.Int("partition", 1, "partition number, this is 1-indexed")
+ totalPartitions = flag.Int("total_partitions", 1, "total number of partitions")
+ isRunningWithHostNet = flag.Bool("hostnet", false, "whether test is running with hostnet")
)
// IsCheckpointSupported returns the relevant command line flag.
@@ -57,6 +59,11 @@ func IsCheckpointSupported() bool {
return *checkpoint
}
+// IsRunningWithHostNet returns the relevant command line flag.
+func IsRunningWithHostNet() bool {
+ return *isRunningWithHostNet
+}
+
// ImageByName mangles the image name used locally. This depends on the image
// build infrastructure in images/ and tools/vm.
func ImageByName(name string) string {
@@ -249,14 +256,25 @@ func writeSpec(dir string, spec *specs.Spec) error {
// idRandomSrc is a pseudo random generator used to in RandomID.
var idRandomSrc = rand.New(rand.NewSource(time.Now().UnixNano()))
+// idRandomSrcMtx is the mutex protecting idRandomSrc.Read from being used
+// concurrently in differnt goroutines.
+var idRandomSrcMtx sync.Mutex
+
// RandomID returns 20 random bytes following the given prefix.
func RandomID(prefix string) string {
// Read 20 random bytes.
b := make([]byte, 20)
+ // Rand.Read is not safe for concurrent use. Packetimpact tests can be run in
+ // parallel now, so we have to protect the Read with a mutex. Otherwise we'll
+ // run into name conflicts.
+ // https://golang.org/pkg/math/rand/#Rand.Read
+ idRandomSrcMtx.Lock()
// "[Read] always returns len(p) and a nil error." --godoc
if _, err := idRandomSrc.Read(b); err != nil {
+ idRandomSrcMtx.Unlock()
panic("rand.Read failed: " + err.Error())
}
+ idRandomSrcMtx.Unlock()
if prefix != "" {
prefix = prefix + "-"
}
@@ -417,33 +435,35 @@ func StartReaper() func() {
// WaitUntilRead reads from the given reader until the wanted string is found
// or until timeout.
-func WaitUntilRead(r io.Reader, want string, split bufio.SplitFunc, timeout time.Duration) error {
+func WaitUntilRead(r io.Reader, want string, timeout time.Duration) error {
sc := bufio.NewScanner(r)
- if split != nil {
- sc.Split(split)
- }
// done must be accessed atomically. A value greater than 0 indicates
// that the read loop can exit.
- var done uint32
- doneCh := make(chan struct{})
+ doneCh := make(chan bool)
+ defer close(doneCh)
go func() {
for sc.Scan() {
t := sc.Text()
if strings.Contains(t, want) {
- atomic.StoreUint32(&done, 1)
- close(doneCh)
- break
+ doneCh <- true
+ return
}
- if atomic.LoadUint32(&done) > 0 {
- break
+ select {
+ case <-doneCh:
+ return
+ default:
}
}
+ doneCh <- false
}()
+
select {
case <-time.After(timeout):
- atomic.StoreUint32(&done, 1)
return fmt.Errorf("timeout waiting to read %q", want)
- case <-doneCh:
+ case res := <-doneCh:
+ if !res {
+ return fmt.Errorf("reader closed while waiting to read %q", want)
+ }
return nil
}
}
@@ -509,7 +529,8 @@ func TouchShardStatusFile() error {
}
// TestIndicesForShard returns indices for this test shard based on the
-// TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars.
+// TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars, as well as
+// the passed partition flags.
//
// If either of the env vars are not present, then the function will return all
// tests. If there are more shards than there are tests, then the returned list
@@ -534,6 +555,11 @@ func TestIndicesForShard(numTests int) ([]int, error) {
}
}
+ // Combine with the partitions.
+ partitionSize := shardTotal
+ shardTotal = (*totalPartitions) * shardTotal
+ shardIndex = partitionSize*(*partition-1) + shardIndex
+
// Calculate!
var indices []int
numBlocks := int(math.Ceil(float64(numTests) / float64(shardTotal)))
diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go
index 9b1e7a085..79db8895b 100644
--- a/pkg/usermem/usermem.go
+++ b/pkg/usermem/usermem.go
@@ -167,7 +167,7 @@ func (rw *IOReadWriter) Read(dst []byte) (int, error) {
return n, err
}
-// Writer implements io.Writer.Write.
+// Write implements io.Writer.Write.
func (rw *IOReadWriter) Write(src []byte) (int, error) {
n, err := rw.IO.CopyOut(rw.Ctx, rw.Addr, src, rw.Opts)
end, ok := rw.Addr.AddLength(uint64(n))
diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go
index 08519d986..83d4f893a 100644
--- a/pkg/waiter/waiter.go
+++ b/pkg/waiter/waiter.go
@@ -119,7 +119,10 @@ type EntryCallback interface {
// The callback is supposed to perform minimal work, and cannot call
// any method on the queue itself because it will be locked while the
// callback is running.
- Callback(e *Entry)
+ //
+ // The mask indicates the events that occurred and that the entry is
+ // interested in.
+ Callback(e *Entry, mask EventMask)
}
// Entry represents a waiter that can be add to the a wait queue. It can
@@ -140,7 +143,7 @@ type channelCallback struct {
}
// Callback implements EntryCallback.Callback.
-func (c *channelCallback) Callback(*Entry) {
+func (c *channelCallback) Callback(*Entry, EventMask) {
select {
case c.ch <- struct{}{}:
default:
@@ -193,8 +196,8 @@ func (q *Queue) EventUnregister(e *Entry) {
func (q *Queue) Notify(mask EventMask) {
q.mu.RLock()
for e := q.list.Front(); e != nil; e = e.Next() {
- if mask&e.mask != 0 {
- e.Callback.Callback(e)
+ if m := mask & e.mask; m != 0 {
+ e.Callback.Callback(e, m)
}
}
q.mu.RUnlock()
diff --git a/pkg/waiter/waiter_test.go b/pkg/waiter/waiter_test.go
index c1b94a4f3..6928f28b4 100644
--- a/pkg/waiter/waiter_test.go
+++ b/pkg/waiter/waiter_test.go
@@ -20,12 +20,12 @@ import (
)
type callbackStub struct {
- f func(e *Entry)
+ f func(e *Entry, m EventMask)
}
// Callback implements EntryCallback.Callback.
-func (c *callbackStub) Callback(e *Entry) {
- c.f(e)
+func (c *callbackStub) Callback(e *Entry, m EventMask) {
+ c.f(e, m)
}
func TestEmptyQueue(t *testing.T) {
@@ -36,7 +36,7 @@ func TestEmptyQueue(t *testing.T) {
// Register then unregister a waiter, then notify the queue.
cnt := 0
- e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}}
+ e := Entry{Callback: &callbackStub{func(*Entry, EventMask) { cnt++ }}}
q.EventRegister(&e, EventIn)
q.EventUnregister(&e)
q.Notify(EventIn)
@@ -49,7 +49,7 @@ func TestMask(t *testing.T) {
// Register a waiter.
var q Queue
var cnt int
- e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}}
+ e := Entry{Callback: &callbackStub{func(*Entry, EventMask) { cnt++ }}}
q.EventRegister(&e, EventIn|EventErr)
// Notify with an overlapping mask.
@@ -101,11 +101,14 @@ func TestConcurrentRegistration(t *testing.T) {
for i := 0; i < concurrency; i++ {
go func() {
var e Entry
- e.Callback = &callbackStub{func(entry *Entry) {
+ e.Callback = &callbackStub{func(entry *Entry, mask EventMask) {
cnt++
if entry != &e {
t.Errorf("entry = %p, want %p", entry, &e)
}
+ if mask != EventIn {
+ t.Errorf("mask = %#x want %#x", mask, EventIn)
+ }
}}
// Wait for notification, then register.
@@ -158,11 +161,14 @@ func TestConcurrentNotification(t *testing.T) {
// Register waiters.
for i := 0; i < waiterCount; i++ {
var e Entry
- e.Callback = &callbackStub{func(entry *Entry) {
+ e.Callback = &callbackStub{func(entry *Entry, mask EventMask) {
atomic.AddInt32(&cnt, 1)
if entry != &e {
t.Errorf("entry = %p, want %p", entry, &e)
}
+ if mask != EventIn {
+ t.Errorf("mask = %#x want %#x", mask, EventIn)
+ }
}}
q.EventRegister(&e, EventIn|EventErr)
diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go
index 7076ae2e2..a3a76b609 100644
--- a/runsc/boot/compat.go
+++ b/runsc/boot/compat.go
@@ -53,7 +53,7 @@ type compatEmitter struct {
func newCompatEmitter(logFD int) (*compatEmitter, error) {
nameMap, ok := getSyscallNameMap()
if !ok {
- return nil, fmt.Errorf("Linux syscall table not found")
+ return nil, fmt.Errorf("syscall table not found")
}
c := &compatEmitter{
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
index fdf13c8e1..865126ac5 100644
--- a/runsc/boot/controller.go
+++ b/runsc/boot/controller.go
@@ -211,10 +211,31 @@ func (cm *containerManager) Processes(cid *string, out *[]*control.Process) erro
return control.Processes(cm.l.k, *cid, out)
}
+// CreateArgs contains arguments to the Create method.
+type CreateArgs struct {
+ // CID is the ID of the container to start.
+ CID string
+
+ // FilePayload may contain a TTY file for the terminal, if enabled.
+ urpc.FilePayload
+}
+
// Create creates a container within a sandbox.
-func (cm *containerManager) Create(cid *string, _ *struct{}) error {
- log.Debugf("containerManager.Create, cid: %s", *cid)
- return cm.l.createContainer(*cid)
+func (cm *containerManager) Create(args *CreateArgs, _ *struct{}) error {
+ log.Debugf("containerManager.Create: %s", args.CID)
+
+ if len(args.Files) > 1 {
+ return fmt.Errorf("start arguments must have at most 1 files for TTY")
+ }
+ var tty *fd.FD
+ if len(args.Files) == 1 {
+ var err error
+ tty, err = fd.NewFromFile(args.Files[0])
+ if err != nil {
+ return fmt.Errorf("error dup'ing TTY file: %w", err)
+ }
+ }
+ return cm.l.createContainer(args.CID, tty)
}
// StartArgs contains arguments to the Start method.
@@ -229,9 +250,8 @@ type StartArgs struct {
CID string
// FilePayload contains, in order:
- // * stdin, stdout, and stderr.
- // * the file descriptor over which the sandbox will
- // request files from its root filesystem.
+ // * stdin, stdout, and stderr (optional: if terminal is disabled).
+ // * file descriptors to connect to gofer to serve the root filesystem.
urpc.FilePayload
}
@@ -251,23 +271,45 @@ func (cm *containerManager) Start(args *StartArgs, _ *struct{}) error {
if args.CID == "" {
return errors.New("start argument missing container ID")
}
- if len(args.FilePayload.Files) < 4 {
- return fmt.Errorf("start arguments must contain stdin, stderr, and stdout followed by at least one file for the container root gofer")
+ if len(args.Files) < 1 {
+ return fmt.Errorf("start arguments must contain at least one file for the container root gofer")
}
// All validation passed, logs the spec for debugging.
specutils.LogSpec(args.Spec)
- fds, err := fd.NewFromFiles(args.FilePayload.Files)
+ goferFiles := args.Files
+ var stdios []*fd.FD
+ if !args.Spec.Process.Terminal {
+ // When not using a terminal, stdios come as the first 3 files in the
+ // payload.
+ if l := len(args.Files); l < 4 {
+ return fmt.Errorf("start arguments (len: %d) must contain stdios and files for the container root gofer", l)
+ }
+ var err error
+ stdios, err = fd.NewFromFiles(goferFiles[:3])
+ if err != nil {
+ return fmt.Errorf("error dup'ing stdio files: %w", err)
+ }
+ goferFiles = goferFiles[3:]
+ }
+ defer func() {
+ for _, fd := range stdios {
+ _ = fd.Close()
+ }
+ }()
+
+ goferFDs, err := fd.NewFromFiles(goferFiles)
if err != nil {
- return err
+ return fmt.Errorf("error dup'ing gofer files: %w", err)
}
defer func() {
- for _, fd := range fds {
+ for _, fd := range goferFDs {
_ = fd.Close()
}
}()
- if err := cm.l.startContainer(args.Spec, args.Conf, args.CID, fds); err != nil {
+
+ if err := cm.l.startContainer(args.Spec, args.Conf, args.CID, stdios, goferFDs); err != nil {
log.Debugf("containerManager.Start failed, cid: %s, args: %+v, err: %v", args.CID, args, err)
return err
}
@@ -330,18 +372,18 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
log.Debugf("containerManager.Restore")
var specFile, deviceFile *os.File
- switch numFiles := len(o.FilePayload.Files); numFiles {
+ switch numFiles := len(o.Files); numFiles {
case 2:
// The device file is donated to the platform.
// Can't take ownership away from os.File. dup them to get a new FD.
- fd, err := syscall.Dup(int(o.FilePayload.Files[1].Fd()))
+ fd, err := syscall.Dup(int(o.Files[1].Fd()))
if err != nil {
return fmt.Errorf("failed to dup file: %v", err)
}
deviceFile = os.NewFile(uintptr(fd), "platform device")
fallthrough
case 1:
- specFile = o.FilePayload.Files[0]
+ specFile = o.Files[0]
case 0:
return fmt.Errorf("at least one file must be passed to Restore")
default:
diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go
index a7c4ebb0c..4e3bb9ac7 100644
--- a/runsc/boot/filter/config.go
+++ b/runsc/boot/filter/config.go
@@ -343,6 +343,16 @@ func hostInetFilters() seccomp.SyscallRules {
},
{
seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IP),
+ seccomp.EqualTo(syscall.IP_PKTINFO),
+ },
+ {
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IP),
+ seccomp.EqualTo(syscall.IP_RECVORIGDSTADDR),
+ },
+ {
+ seccomp.MatchAny{},
seccomp.EqualTo(syscall.SOL_IPV6),
seccomp.EqualTo(syscall.IPV6_TCLASS),
},
@@ -358,6 +368,11 @@ func hostInetFilters() seccomp.SyscallRules {
},
{
seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IPV6),
+ seccomp.EqualTo(linux.IPV6_RECVORIGDSTADDR),
+ },
+ {
+ seccomp.MatchAny{},
seccomp.EqualTo(syscall.SOL_SOCKET),
seccomp.EqualTo(syscall.SO_ERROR),
},
@@ -393,6 +408,11 @@ func hostInetFilters() seccomp.SyscallRules {
},
{
seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_SOCKET),
+ seccomp.EqualTo(syscall.SO_TIMESTAMP),
+ },
+ {
+ seccomp.MatchAny{},
seccomp.EqualTo(syscall.SOL_TCP),
seccomp.EqualTo(syscall.TCP_NODELAY),
},
@@ -401,6 +421,11 @@ func hostInetFilters() seccomp.SyscallRules {
seccomp.EqualTo(syscall.SOL_TCP),
seccomp.EqualTo(syscall.TCP_INFO),
},
+ {
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_TCP),
+ seccomp.EqualTo(linux.TCP_INQ),
+ },
},
syscall.SYS_IOCTL: []seccomp.Rule{
{
@@ -449,6 +474,13 @@ func hostInetFilters() seccomp.SyscallRules {
},
{
seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_SOCKET),
+ seccomp.EqualTo(syscall.SO_TIMESTAMP),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
+ },
+ {
+ seccomp.MatchAny{},
seccomp.EqualTo(syscall.SOL_TCP),
seccomp.EqualTo(syscall.TCP_NODELAY),
seccomp.MatchAny{},
@@ -456,6 +488,13 @@ func hostInetFilters() seccomp.SyscallRules {
},
{
seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_TCP),
+ seccomp.EqualTo(linux.TCP_INQ),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
+ },
+ {
+ seccomp.MatchAny{},
seccomp.EqualTo(syscall.SOL_IP),
seccomp.EqualTo(syscall.IP_TOS),
seccomp.MatchAny{},
@@ -470,6 +509,20 @@ func hostInetFilters() seccomp.SyscallRules {
},
{
seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IP),
+ seccomp.EqualTo(syscall.IP_PKTINFO),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
+ },
+ {
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IP),
+ seccomp.EqualTo(syscall.IP_RECVORIGDSTADDR),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
+ },
+ {
+ seccomp.MatchAny{},
seccomp.EqualTo(syscall.SOL_IPV6),
seccomp.EqualTo(syscall.IPV6_TCLASS),
seccomp.MatchAny{},
@@ -482,6 +535,13 @@ func hostInetFilters() seccomp.SyscallRules {
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
+ {
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IPV6),
+ seccomp.EqualTo(linux.IPV6_RECVORIGDSTADDR),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
+ },
},
syscall.SYS_SHUTDOWN: []seccomp.Rule{
{
diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go
index 6b6ae98d7..2b0d2cd51 100644
--- a/runsc/boot/fs.go
+++ b/runsc/boot/fs.go
@@ -22,15 +22,6 @@ import (
"strings"
"syscall"
- // Include filesystem types that OCI spec might mount.
- _ "gvisor.dev/gvisor/pkg/sentry/fs/dev"
- _ "gvisor.dev/gvisor/pkg/sentry/fs/host"
- _ "gvisor.dev/gvisor/pkg/sentry/fs/proc"
- _ "gvisor.dev/gvisor/pkg/sentry/fs/sys"
- _ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
- _ "gvisor.dev/gvisor/pkg/sentry/fs/tty"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
-
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -48,9 +39,18 @@ import (
tmpfsvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/runsc/config"
"gvisor.dev/gvisor/runsc/specutils"
+
+ // Include filesystem types that OCI spec might mount.
+ _ "gvisor.dev/gvisor/pkg/sentry/fs/dev"
+ _ "gvisor.dev/gvisor/pkg/sentry/fs/host"
+ _ "gvisor.dev/gvisor/pkg/sentry/fs/proc"
+ _ "gvisor.dev/gvisor/pkg/sentry/fs/sys"
+ _ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
+ _ "gvisor.dev/gvisor/pkg/sentry/fs/tty"
)
const (
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index ebdd518d0..3df013d34 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -157,6 +157,11 @@ type execProcess struct {
// pidnsPath is the pid namespace path in spec
pidnsPath string
+
+ // hostTTY is present when creating a sub-container with terminal enabled.
+ // TTY file is passed during container create and must be saved until
+ // container start.
+ hostTTY *fd.FD
}
func init() {
@@ -588,7 +593,9 @@ func (l *Loader) run() error {
// Create the root container init task. It will begin running
// when the kernel is started.
- if _, err := l.createContainerProcess(true, l.sandboxID, &l.root, ep); err != nil {
+ var err error
+ _, ep.tty, ep.ttyVFS2, err = l.createContainerProcess(true, l.sandboxID, &l.root)
+ if err != nil {
return err
}
@@ -627,7 +634,7 @@ func (l *Loader) run() error {
}
// createContainer creates a new container inside the sandbox.
-func (l *Loader) createContainer(cid string) error {
+func (l *Loader) createContainer(cid string, tty *fd.FD) error {
l.mu.Lock()
defer l.mu.Unlock()
@@ -635,14 +642,14 @@ func (l *Loader) createContainer(cid string) error {
if _, ok := l.processes[eid]; ok {
return fmt.Errorf("container %q already exists", cid)
}
- l.processes[eid] = &execProcess{}
+ l.processes[eid] = &execProcess{hostTTY: tty}
return nil
}
// startContainer starts a child container. It returns the thread group ID of
// the newly created process. Used FDs are either closed or released. It's safe
// for the caller to close any remaining files upon return.
-func (l *Loader) startContainer(spec *specs.Spec, conf *config.Config, cid string, files []*fd.FD) error {
+func (l *Loader) startContainer(spec *specs.Spec, conf *config.Config, cid string, stdioFDs, goferFDs []*fd.FD) error {
// Create capabilities.
caps, err := specutils.Capabilities(conf.EnableRaw, spec.Process.Capabilities)
if err != nil {
@@ -695,36 +702,41 @@ func (l *Loader) startContainer(spec *specs.Spec, conf *config.Config, cid strin
info := &containerInfo{
conf: conf,
spec: spec,
- stdioFDs: files[:3],
- goferFDs: files[3:],
+ goferFDs: goferFDs,
}
info.procArgs, err = createProcessArgs(cid, spec, creds, l.k, pidns)
if err != nil {
return fmt.Errorf("creating new process: %v", err)
}
- tg, err := l.createContainerProcess(false, cid, info, ep)
+
+ // Use stdios or TTY depending on the spec configuration.
+ if spec.Process.Terminal {
+ if len(stdioFDs) > 0 {
+ return fmt.Errorf("using TTY, stdios not expected: %v", stdioFDs)
+ }
+ if ep.hostTTY == nil {
+ return fmt.Errorf("terminal enabled but no TTY provided. Did you set --console-socket on create?")
+ }
+ info.stdioFDs = []*fd.FD{ep.hostTTY, ep.hostTTY, ep.hostTTY}
+ ep.hostTTY = nil
+ } else {
+ info.stdioFDs = stdioFDs
+ }
+
+ ep.tg, ep.tty, ep.ttyVFS2, err = l.createContainerProcess(false, cid, info)
if err != nil {
return err
}
-
- // Success!
- l.k.StartProcess(tg)
- ep.tg = tg
+ l.k.StartProcess(ep.tg)
return nil
}
-func (l *Loader) createContainerProcess(root bool, cid string, info *containerInfo, ep *execProcess) (*kernel.ThreadGroup, error) {
- console := false
- if root {
- // Only root container supports terminal for now.
- console = info.spec.Process.Terminal
- }
-
+func (l *Loader) createContainerProcess(root bool, cid string, info *containerInfo) (*kernel.ThreadGroup, *host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) {
// Create the FD map, which will set stdin, stdout, and stderr.
ctx := info.procArgs.NewContext(l.k)
- fdTable, ttyFile, ttyFileVFS2, err := createFDTable(ctx, console, info.stdioFDs)
+ fdTable, ttyFile, ttyFileVFS2, err := createFDTable(ctx, info.spec.Process.Terminal, info.stdioFDs)
if err != nil {
- return nil, fmt.Errorf("importing fds: %v", err)
+ return nil, nil, nil, fmt.Errorf("importing fds: %v", err)
}
// CreateProcess takes a reference on fdTable if successful. We won't need
// ours either way.
@@ -736,11 +748,11 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn
mntr := newContainerMounter(info.spec, info.goferFDs, l.k, l.mountHints)
if root {
if err := mntr.processHints(info.conf, info.procArgs.Credentials); err != nil {
- return nil, err
+ return nil, nil, nil, err
}
}
if err := setupContainerFS(ctx, info.conf, mntr, &info.procArgs); err != nil {
- return nil, err
+ return nil, nil, nil, err
}
// Add the HOME environment variable if it is not already set.
@@ -754,29 +766,25 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn
info.procArgs.Credentials.RealKUID, info.procArgs.Envv)
}
if err != nil {
- return nil, err
+ return nil, nil, nil, err
}
info.procArgs.Envv = envv
// Create and start the new process.
tg, _, err := l.k.CreateProcess(info.procArgs)
if err != nil {
- return nil, fmt.Errorf("creating process: %v", err)
+ return nil, nil, nil, fmt.Errorf("creating process: %v", err)
}
// CreateProcess takes a reference on FDTable if successful.
info.procArgs.FDTable.DecRef(ctx)
// Set the foreground process group on the TTY to the global init process
// group, since that is what we are about to start running.
- if root {
- switch {
- case ttyFileVFS2 != nil:
- ep.ttyVFS2 = ttyFileVFS2
- ttyFileVFS2.InitForegroundProcessGroup(tg.ProcessGroup())
- case ttyFile != nil:
- ep.tty = ttyFile
- ttyFile.InitForegroundProcessGroup(tg.ProcessGroup())
- }
+ switch {
+ case ttyFileVFS2 != nil:
+ ttyFileVFS2.InitForegroundProcessGroup(tg.ProcessGroup())
+ case ttyFile != nil:
+ ttyFile.InitForegroundProcessGroup(tg.ProcessGroup())
}
// Install seccomp filters with the new task if there are any.
@@ -784,7 +792,7 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn
if info.spec.Linux != nil && info.spec.Linux.Seccomp != nil {
program, err := seccomp.BuildProgram(info.spec.Linux.Seccomp)
if err != nil {
- return nil, fmt.Errorf("building seccomp program: %v", err)
+ return nil, nil, nil, fmt.Errorf("building seccomp program: %v", err)
}
if log.IsLogging(log.Debug) {
@@ -795,7 +803,7 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn
task := tg.Leader()
// NOTE: It seems Flags are ignored by runc so we ignore them too.
if err := task.AppendSyscallFilter(program, true); err != nil {
- return nil, fmt.Errorf("appending seccomp filters: %v", err)
+ return nil, nil, nil, fmt.Errorf("appending seccomp filters: %v", err)
}
}
} else {
@@ -804,7 +812,7 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn
}
}
- return tg, nil
+ return tg, ttyFile, ttyFileVFS2, nil
}
// startGoferMonitor runs a goroutine to monitor gofer's health. It polls on
@@ -1074,7 +1082,12 @@ func newRootNetworkNamespace(conf *config.Config, clock tcpip.Clock, uniqueID st
func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) {
netProtos := []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol, arp.NewProtocol}
- transProtos := []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4}
+ transProtos := []stack.TransportProtocolFactory{
+ tcp.NewProtocol,
+ udp.NewProtocol,
+ icmp.NewProtocol4,
+ icmp.NewProtocol6,
+ }
s := netstack.Stack{stack.New(stack.Options{
NetworkProtocols: netProtos,
TransportProtocols: transProtos,
diff --git a/runsc/boot/network.go b/runsc/boot/network.go
index f58b09942..3d3a813df 100644
--- a/runsc/boot/network.go
+++ b/runsc/boot/network.go
@@ -40,9 +40,9 @@ var (
// "::1/8" on "lo" interface.
DefaultLoopbackLink = LoopbackLink{
Name: "lo",
- Addresses: []net.IP{
- net.IP("\x7f\x00\x00\x01"),
- net.IPv6loopback,
+ Addresses: []IPWithPrefix{
+ {Address: net.IP("\x7f\x00\x00\x01"), PrefixLen: 8},
+ {Address: net.IPv6loopback, PrefixLen: 128},
},
Routes: []Route{
{
@@ -82,7 +82,7 @@ type DefaultRoute struct {
type FDBasedLink struct {
Name string
MTU int
- Addresses []net.IP
+ Addresses []IPWithPrefix
Routes []Route
GSOMaxSize uint32
SoftwareGSOEnabled bool
@@ -99,7 +99,7 @@ type FDBasedLink struct {
// LoopbackLink configures a loopback li nk.
type LoopbackLink struct {
Name string
- Addresses []net.IP
+ Addresses []IPWithPrefix
Routes []Route
}
@@ -117,6 +117,19 @@ type CreateLinksAndRoutesArgs struct {
Defaultv6Gateway DefaultRoute
}
+// IPWithPrefix is an address with its subnet prefix length.
+type IPWithPrefix struct {
+ // Address is a network address.
+ Address net.IP
+
+ // PrefixLen is the subnet prefix length.
+ PrefixLen int
+}
+
+func (ip IPWithPrefix) String() string {
+ return fmt.Sprintf("%s/%d", ip.Address, ip.PrefixLen)
+}
+
// Empty returns true if route hasn't been set.
func (r *Route) Empty() bool {
return r.Destination.IP == nil && r.Destination.Mask == nil && r.Gateway == nil
@@ -264,15 +277,19 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
// createNICWithAddrs creates a NIC in the network stack and adds the given
// addresses.
-func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkEndpoint, addrs []net.IP) error {
+func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkEndpoint, addrs []IPWithPrefix) error {
opts := stack.NICOptions{Name: name}
if err := n.Stack.CreateNICWithOptions(id, sniffer.New(ep), opts); err != nil {
return fmt.Errorf("CreateNICWithOptions(%d, _, %+v) failed: %v", id, opts, err)
}
for _, addr := range addrs {
- proto, tcpipAddr := ipToAddressAndProto(addr)
- if err := n.Stack.AddAddress(id, proto, tcpipAddr); err != nil {
+ proto, tcpipAddr := ipToAddressAndProto(addr.Address)
+ ap := tcpip.AddressWithPrefix{
+ Address: tcpipAddr,
+ PrefixLen: addr.PrefixLen,
+ }
+ if err := n.Stack.AddAddressWithPrefix(id, proto, ap); err != nil {
return fmt.Errorf("AddAddress(%v, %v, %v) failed: %v", id, proto, tcpipAddr, err)
}
}
diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go
index b157387ef..3fd28e516 100644
--- a/runsc/boot/vfs.go
+++ b/runsc/boot/vfs.go
@@ -250,36 +250,76 @@ func (c *containerMounter) configureOverlay(ctx context.Context, creds *auth.Cre
overlayOpts := *lowerOpts
overlayOpts.GetFilesystemOptions = vfs.GetFilesystemOptions{}
- // Next mount upper and lower. Upper is a tmpfs mount to keep all
- // modifications inside the sandbox.
- upper, err := c.k.VFS().MountDisconnected(ctx, creds, "" /* source */, tmpfs.Name, &upperOpts)
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create upper layer for overlay, opts: %+v: %v", upperOpts, err)
- }
- cu := cleanup.Make(func() { upper.DecRef(ctx) })
- defer cu.Clean()
-
// All writes go to the upper layer, be paranoid and make lower readonly.
lowerOpts.ReadOnly = true
lower, err := c.k.VFS().MountDisconnected(ctx, creds, "" /* source */, lowerFSName, lowerOpts)
if err != nil {
return nil, nil, err
}
- cu.Add(func() { lower.DecRef(ctx) })
+ cu := cleanup.Make(func() { lower.DecRef(ctx) })
+ defer cu.Clean()
- // Propagate the lower layer's root's owner, group, and mode to the upper
- // layer's root for consistency with VFS1.
- upperRootVD := vfs.MakeVirtualDentry(upper, upper.Root())
+ // Determine the lower layer's root's type.
lowerRootVD := vfs.MakeVirtualDentry(lower, lower.Root())
stat, err := c.k.VFS().StatAt(ctx, creds, &vfs.PathOperation{
Root: lowerRootVD,
Start: lowerRootVD,
}, &vfs.StatOptions{
- Mask: linux.STATX_UID | linux.STATX_GID | linux.STATX_MODE,
+ Mask: linux.STATX_UID | linux.STATX_GID | linux.STATX_MODE | linux.STATX_TYPE,
})
if err != nil {
- return nil, nil, err
+ return nil, nil, fmt.Errorf("failed to stat lower layer's root: %v", err)
+ }
+ if stat.Mask&linux.STATX_TYPE == 0 {
+ return nil, nil, fmt.Errorf("failed to get file type of lower layer's root")
+ }
+ rootType := stat.Mode & linux.S_IFMT
+ if rootType != linux.S_IFDIR && rootType != linux.S_IFREG {
+ return nil, nil, fmt.Errorf("lower layer's root has unsupported file type %v", rootType)
+ }
+
+ // Upper is a tmpfs mount to keep all modifications inside the sandbox.
+ upperOpts.GetFilesystemOptions.InternalData = tmpfs.FilesystemOpts{
+ RootFileType: uint16(rootType),
+ }
+ upper, err := c.k.VFS().MountDisconnected(ctx, creds, "" /* source */, tmpfs.Name, &upperOpts)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to create upper layer for overlay, opts: %+v: %v", upperOpts, err)
+ }
+ cu.Add(func() { upper.DecRef(ctx) })
+
+ // If the overlay mount consists of a regular file, copy up its contents
+ // from the lower layer, since in the overlay the otherwise-empty upper
+ // layer file will take precedence.
+ upperRootVD := vfs.MakeVirtualDentry(upper, upper.Root())
+ if rootType == linux.S_IFREG {
+ lowerFD, err := c.k.VFS().OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: lowerRootVD,
+ Start: lowerRootVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to open lower layer root for copying: %v", err)
+ }
+ defer lowerFD.DecRef(ctx)
+ upperFD, err := c.k.VFS().OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: upperRootVD,
+ Start: upperRootVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_WRONLY,
+ })
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to open upper layer root for copying: %v", err)
+ }
+ defer upperFD.DecRef(ctx)
+ if _, err := vfs.CopyRegularFileData(ctx, upperFD, lowerFD); err != nil {
+ return nil, nil, fmt.Errorf("failed to copy up overlay file: %v", err)
+ }
}
+
+ // Propagate the lower layer's root's owner, group, and mode to the upper
+ // layer's root for consistency with VFS1.
err = c.k.VFS().SetStatAt(ctx, creds, &vfs.PathOperation{
Root: upperRootVD,
Start: upperRootVD,
diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go
index 5bd0afc52..e5294de55 100644
--- a/runsc/cgroup/cgroup.go
+++ b/runsc/cgroup/cgroup.go
@@ -234,7 +234,7 @@ func loadPathsHelper(cgroup io.Reader) (map[string]string, error) {
type Cgroup struct {
Name string `json:"name"`
Parents map[string]string `json:"parents"`
- Own bool `json:"own"`
+ Own map[string]bool `json:"own"`
}
// New creates a new Cgroup instance if the spec includes a cgroup path.
@@ -251,9 +251,11 @@ func New(spec *specs.Spec) (*Cgroup, error) {
return nil, fmt.Errorf("finding current cgroups: %w", err)
}
}
+ own := make(map[string]bool)
return &Cgroup{
Name: spec.Linux.CgroupsPath,
Parents: parents,
+ Own: own,
}, nil
}
@@ -261,18 +263,8 @@ func New(spec *specs.Spec) (*Cgroup, error) {
// already exists, it means that the caller has already provided a
// pre-configured cgroups, and 'res' is ignored.
func (c *Cgroup) Install(res *specs.LinuxResources) error {
- if _, err := os.Stat(c.makePath("memory")); err == nil {
- // If cgroup has already been created; it has been setup by caller. Don't
- // make any changes to configuration, just join when sandbox/gofer starts.
- log.Debugf("Using pre-created cgroup %q", c.Name)
- return nil
- }
-
log.Debugf("Creating cgroup %q", c.Name)
- // Mark that cgroup resources are owned by me.
- c.Own = true
-
// The Cleanup object cleans up partially created cgroups when an error occurs.
// Errors occuring during cleanup itself are ignored.
clean := cleanup.Make(func() { _ = c.Uninstall() })
@@ -280,6 +272,16 @@ func (c *Cgroup) Install(res *specs.LinuxResources) error {
for key, cfg := range controllers {
path := c.makePath(key)
+ if _, err := os.Stat(path); err == nil {
+ // If cgroup has already been created; it has been setup by caller. Don't
+ // make any changes to configuration, just join when sandbox/gofer starts.
+ log.Debugf("Using pre-created cgroup %q", path)
+ continue
+ }
+
+ // Mark that cgroup resources are owned by me.
+ c.Own[key] = true
+
if err := os.MkdirAll(path, 0755); err != nil {
if cfg.optional && errors.Is(err, syscall.EROFS) {
log.Infof("Skipping cgroup %q", key)
@@ -298,12 +300,12 @@ func (c *Cgroup) Install(res *specs.LinuxResources) error {
// Uninstall removes the settings done in Install(). If cgroup path already
// existed when Install() was called, Uninstall is a noop.
func (c *Cgroup) Uninstall() error {
- if !c.Own {
- // cgroup is managed by caller, don't touch it.
- return nil
- }
log.Debugf("Deleting cgroup %q", c.Name)
for key := range controllers {
+ if !c.Own[key] {
+ // cgroup is managed by caller, don't touch it.
+ continue
+ }
path := c.makePath(key)
log.Debugf("Removing cgroup controller for key=%q path=%q", key, path)
diff --git a/runsc/cgroup/cgroup_test.go b/runsc/cgroup/cgroup_test.go
index 9794517a7..931144cf9 100644
--- a/runsc/cgroup/cgroup_test.go
+++ b/runsc/cgroup/cgroup_test.go
@@ -29,7 +29,10 @@ func TestUninstallEnoent(t *testing.T) {
c := Cgroup{
// set a non-existent name
Name: "runsc-test-uninstall-656e6f656e740a",
- Own: true,
+ }
+ c.Own = make(map[string]bool)
+ for key := range controllers {
+ c.Own[key] = true
}
if err := c.Uninstall(); err != nil {
t.Errorf("Uninstall() failed: %v", err)
diff --git a/runsc/cli/main.go b/runsc/cli/main.go
index bca015db5..6c3bf4d21 100644
--- a/runsc/cli/main.go
+++ b/runsc/cli/main.go
@@ -22,6 +22,7 @@ import (
"io/ioutil"
"os"
"os/signal"
+ "runtime"
"syscall"
"time"
@@ -82,6 +83,7 @@ func Main(version string) {
subcommands.Register(new(cmd.Spec), "")
subcommands.Register(new(cmd.State), "")
subcommands.Register(new(cmd.Start), "")
+ subcommands.Register(new(cmd.Symbolize), "")
subcommands.Register(new(cmd.Wait), "")
// Register internal commands with the internal group name. This causes
@@ -207,6 +209,8 @@ func Main(version string) {
log.Infof("***************************")
log.Infof("Args: %s", os.Args)
log.Infof("Version %s", version)
+ log.Infof("GOOS: %s", runtime.GOOS)
+ log.Infof("GOARCH: %s", runtime.GOARCH)
log.Infof("PID: %d", os.Getpid())
log.Infof("UID: %d, GID: %d", os.Getuid(), os.Getgid())
log.Infof("Configuration:")
diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD
index 2556f6d9e..19520d7ab 100644
--- a/runsc/cmd/BUILD
+++ b/runsc/cmd/BUILD
@@ -32,6 +32,7 @@ go_library(
"start.go",
"state.go",
"statefile.go",
+ "symbolize.go",
"syscalls.go",
"wait.go",
],
@@ -39,6 +40,7 @@ go_library(
"//runsc:__subpackages__",
],
deps = [
+ "//pkg/coverage",
"//pkg/log",
"//pkg/p9",
"//pkg/sentry/control",
diff --git a/runsc/cmd/do.go b/runsc/cmd/do.go
index 640de4c47..8a8d9f752 100644
--- a/runsc/cmd/do.go
+++ b/runsc/cmd/do.go
@@ -81,7 +81,7 @@ func (c *Do) SetFlags(f *flag.FlagSet) {
// Execute implements subcommands.Command.Execute.
func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
if len(f.Args()) == 0 {
- c.Usage()
+ f.Usage()
return subcommands.ExitUsageError
}
diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go
index 86c02a22a..eafd6285c 100644
--- a/runsc/cmd/exec.go
+++ b/runsc/cmd/exec.go
@@ -150,7 +150,7 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
}
func (ex *Exec) exec(c *container.Container, e *control.ExecArgs, waitStatus *syscall.WaitStatus) subcommands.ExitStatus {
- // Start the new process and get it pid.
+ // Start the new process and get its pid.
pid, err := c.Execute(e)
if err != nil {
return Errorf("executing processes for container: %v", err)
diff --git a/runsc/cmd/symbolize.go b/runsc/cmd/symbolize.go
new file mode 100644
index 000000000..fc0c69358
--- /dev/null
+++ b/runsc/cmd/symbolize.go
@@ -0,0 +1,91 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "bufio"
+ "context"
+ "os"
+ "strconv"
+ "strings"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/pkg/coverage"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+// Symbolize implements subcommands.Command for the "symbolize" command.
+type Symbolize struct {
+ dumpAll bool
+}
+
+// Name implements subcommands.Command.Name.
+func (*Symbolize) Name() string {
+ return "symbolize"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Symbolize) Synopsis() string {
+ return "Convert synthetic instruction pointers from kcov into positions in the runsc source code. Only used when Go coverage is enabled."
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Symbolize) Usage() string {
+ return `symbolize - converts synthetic instruction pointers into positions in the runsc source code.
+
+This command takes instruction pointers from stdin and converts them into their
+corresponding file names and line/column numbers in the runsc source code. The
+inputs are not interpreted as actual addresses, but as synthetic values that are
+exposed through /sys/kernel/debug/kcov. One can extract coverage information
+from kcov and translate those values into locations in the source code by
+running symbolize on the same runsc binary.
+`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (c *Symbolize) SetFlags(f *flag.FlagSet) {
+ f.BoolVar(&c.dumpAll, "all", false, "dump information on all coverage blocks along with their synthetic PCs")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (c *Symbolize) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if f.NArg() != 0 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+ if !coverage.KcovAvailable() {
+ return Errorf("symbolize can only be used when coverage is available.")
+ }
+ coverage.InitCoverageData()
+
+ if c.dumpAll {
+ coverage.WriteAllBlocks(os.Stdout)
+ return subcommands.ExitSuccess
+ }
+
+ scanner := bufio.NewScanner(os.Stdin)
+ for scanner.Scan() {
+ // Input is always base 16, but may or may not have a leading "0x".
+ str := strings.TrimPrefix(scanner.Text(), "0x")
+ pc, err := strconv.ParseUint(str, 16 /* base */, 64 /* bitSize */)
+ if err != nil {
+ return Errorf("Failed to symbolize \"%s\": %v", scanner.Text(), err)
+ }
+ if err := coverage.Symbolize(os.Stdout, pc); err != nil {
+ return Errorf("Failed to symbolize \"%s\": %v", scanner.Text(), err)
+ }
+ }
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/console/console.go b/runsc/console/console.go
index dbb88e117..b36028792 100644
--- a/runsc/console/console.go
+++ b/runsc/console/console.go
@@ -24,8 +24,8 @@ import (
"golang.org/x/sys/unix"
)
-// NewWithSocket creates pty master/replica pair, sends the master FD over the given
-// socket, and returns the replica.
+// NewWithSocket creates pty master/replica pair, sends the master FD over the
+// given socket, and returns the replica.
func NewWithSocket(socketPath string) (*os.File, error) {
// Create a new pty master and replica.
ptyMaster, ptyReplica, err := pty.Open()
diff --git a/runsc/container/BUILD b/runsc/container/BUILD
index c33755482..8793c8916 100644
--- a/runsc/container/BUILD
+++ b/runsc/container/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test", "more_shards")
package(licenses = ["notice"])
@@ -24,6 +24,7 @@ go_library(
"//runsc/boot",
"//runsc/cgroup",
"//runsc/config",
+ "//runsc/console",
"//runsc/sandbox",
"//runsc/specutils",
"@com_github_cenkalti_backoff//:go_default_library",
@@ -48,7 +49,7 @@ go_test(
"//test/cmd/test_app",
],
library = ":container",
- shard_count = 10,
+ shard_count = more_shards,
tags = [
"requires-kvm",
],
diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go
index 4228399b8..1b0fdebd6 100644
--- a/runsc/container/console_test.go
+++ b/runsc/container/console_test.go
@@ -18,6 +18,7 @@ import (
"bytes"
"fmt"
"io"
+ "math/rand"
"os"
"path/filepath"
"syscall"
@@ -27,7 +28,6 @@ import (
"github.com/kr/pty"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/sentry/control"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/test/testutil"
"gvisor.dev/gvisor/pkg/unet"
@@ -38,19 +38,22 @@ import (
// path is under 108 charactors (the unix socket path length limit),
// relativizing the path if necessary.
func socketPath(bundleDir string) (string, error) {
- path := filepath.Join(bundleDir, "socket")
+ num := rand.Intn(10000)
+ path := filepath.Join(bundleDir, fmt.Sprintf("socket-%4d", num))
+ const maxPathLen = 108
+ if len(path) <= maxPathLen {
+ return path, nil
+ }
+
+ // Path is too large, try to make it smaller.
cwd, err := os.Getwd()
if err != nil {
return "", fmt.Errorf("error getting cwd: %v", err)
}
- relPath, err := filepath.Rel(cwd, path)
+ path, err = filepath.Rel(cwd, path)
if err != nil {
return "", fmt.Errorf("error getting relative path for %q from cwd %q: %v", path, cwd, err)
}
- if len(path) > len(relPath) {
- path = relPath
- }
- const maxPathLen = 108
if len(path) > maxPathLen {
return "", fmt.Errorf("could not get socket path under length limit %d: %s", maxPathLen, path)
}
@@ -159,6 +162,82 @@ func TestConsoleSocket(t *testing.T) {
}
}
+// Test that an pty FD is sent over the console socket if one is provided.
+func TestMultiContainerConsoleSocket(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // Setup the containers.
+ sleep := []string{"sleep", "100"}
+ tru := []string{"true"}
+ testSpecs, ids := createSpecs(sleep, tru)
+ testSpecs[1].Process.Terminal = true
+
+ bundleDir, cleanup, err := testutil.SetupBundleDir(testSpecs[0])
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ args := Args{
+ ID: ids[0],
+ Spec: testSpecs[0],
+ BundleDir: bundleDir,
+ }
+ rootCont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer rootCont.Destroy()
+ if err := rootCont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ bundleDir, cleanup, err = testutil.SetupBundleDir(testSpecs[0])
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ sock, err := socketPath(bundleDir)
+ if err != nil {
+ t.Fatalf("error getting socket path: %v", err)
+ }
+ srv, cleanup := createConsoleSocket(t, sock)
+ defer cleanup()
+
+ // Create the container and pass the socket name.
+ args = Args{
+ ID: ids[1],
+ Spec: testSpecs[1],
+ BundleDir: bundleDir,
+ ConsoleSocket: sock,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // Make sure we get a console PTY.
+ ptyMaster, err := receiveConsolePTY(srv)
+ if err != nil {
+ t.Fatalf("error receiving console FD: %v", err)
+ }
+ ptyMaster.Close()
+ })
+ }
+}
+
// Test that job control signals work on a console created with "exec -ti".
func TestJobControlSignalExec(t *testing.T) {
spec := testutil.NewSpecWithArgs("/bin/sleep", "10000")
@@ -221,9 +300,9 @@ func TestJobControlSignalExec(t *testing.T) {
// Make sure all the processes are running.
expectedPL := []*control.Process{
// Root container process.
- {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}},
+ newProcessBuilder().Cmd("sleep").Process(),
// Bash from exec process.
- {PID: 2, Cmd: "bash", Threads: []kernel.ThreadID{2}},
+ newProcessBuilder().PID(2).Cmd("bash").Process(),
}
if err := waitForProcessList(c, expectedPL); err != nil {
t.Error(err)
@@ -233,7 +312,7 @@ func TestJobControlSignalExec(t *testing.T) {
ptyMaster.Write([]byte("sleep 100\n"))
// Wait for it to start. Sleep's PPID is bash's PID.
- expectedPL = append(expectedPL, &control.Process{PID: 3, PPID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{3}})
+ expectedPL = append(expectedPL, newProcessBuilder().PID(3).PPID(2).Cmd("sleep").Process())
if err := waitForProcessList(c, expectedPL); err != nil {
t.Error(err)
}
@@ -254,7 +333,7 @@ func TestJobControlSignalExec(t *testing.T) {
// Sleep is dead, but it may take more time for bash to notice and
// change the foreground process back to itself. We know it is done
// when bash writes "Terminated" to the pty.
- if err := testutil.WaitUntilRead(ptyMaster, "Terminated", nil, 5*time.Second); err != nil {
+ if err := testutil.WaitUntilRead(ptyMaster, "Terminated", 5*time.Second); err != nil {
t.Fatalf("bash did not take over pty: %v", err)
}
@@ -359,7 +438,7 @@ func TestJobControlSignalRootContainer(t *testing.T) {
// Wait for bash to start.
expectedPL := []*control.Process{
- {PID: 1, Cmd: "bash", Threads: []kernel.ThreadID{1}},
+ newProcessBuilder().PID(1).Cmd("bash").Process(),
}
if err := waitForProcessList(c, expectedPL); err != nil {
t.Fatalf("error waiting for processes: %v", err)
@@ -369,7 +448,7 @@ func TestJobControlSignalRootContainer(t *testing.T) {
ptyMaster.Write([]byte("sleep 100\n"))
// Wait for sleep to start.
- expectedPL = append(expectedPL, &control.Process{PID: 2, PPID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{2}})
+ expectedPL = append(expectedPL, newProcessBuilder().PID(2).PPID(1).Cmd("sleep").Process())
if err := waitForProcessList(c, expectedPL); err != nil {
t.Fatalf("error waiting for processes: %v", err)
}
@@ -393,7 +472,7 @@ func TestJobControlSignalRootContainer(t *testing.T) {
// Sleep is dead, but it may take more time for bash to notice and
// change the foreground process back to itself. We know it is done
// when bash writes "Terminated" to the pty.
- if err := testutil.WaitUntilRead(ptyBuf, "Terminated", nil, 5*time.Second); err != nil {
+ if err := testutil.WaitUntilRead(ptyBuf, "Terminated", 5*time.Second); err != nil {
t.Fatalf("bash did not take over pty: %v", err)
}
@@ -414,6 +493,104 @@ func TestJobControlSignalRootContainer(t *testing.T) {
}
}
+// Test that terminal works with root and sub-containers.
+func TestMultiContainerTerminal(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // Don't let bash execute from profile or rc files, otherwise our PID
+ // counts get messed up.
+ bash := []string{"/bin/bash", "--noprofile", "--norc"}
+ testSpecs, ids := createSpecs(bash, bash)
+
+ type termContainer struct {
+ container *Container
+ master *os.File
+ }
+ var containers []termContainer
+ for i, spec := range testSpecs {
+ bundleDir, cleanup, err := testutil.SetupBundleDir(spec)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ spec.Process.Terminal = true
+ sock, err := socketPath(bundleDir)
+ if err != nil {
+ t.Fatalf("error getting socket path: %v", err)
+ }
+ srv, cleanup := createConsoleSocket(t, sock)
+ defer cleanup()
+
+ // Create the container and pass the socket name.
+ args := Args{
+ ID: ids[i],
+ Spec: spec,
+ BundleDir: bundleDir,
+ ConsoleSocket: sock,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // Make sure we get a console PTY.
+ ptyMaster, err := receiveConsolePTY(srv)
+ if err != nil {
+ t.Fatalf("error receiving console FD: %v", err)
+ }
+ defer ptyMaster.Close()
+
+ containers = append(containers, termContainer{
+ container: cont,
+ master: ptyMaster,
+ })
+ }
+
+ for _, tc := range containers {
+ // Bash output as well as sandbox output will be written to the PTY
+ // file. Writes after a certain point will block unless we drain the
+ // PTY, so we must continually copy from it.
+ //
+ // We log the output to stderr for debugabilitly, and also to a buffer,
+ // since we wait on particular output from bash below. We use a custom
+ // blockingBuffer which is thread-safe and also blocks on Read calls,
+ // which makes this a suitable Reader for WaitUntilRead.
+ ptyBuf := newBlockingBuffer()
+ tee := io.TeeReader(tc.master, ptyBuf)
+ go io.Copy(os.Stderr, tee)
+
+ // Wait for bash to start.
+ expectedPL := []*control.Process{
+ newProcessBuilder().Cmd("bash").Process(),
+ }
+ if err := waitForProcessList(tc.container, expectedPL); err != nil {
+ t.Fatalf("error waiting for processes: %v", err)
+ }
+
+ // Execute echo command and check that it was executed correctly. Use
+ // a variable to ensure it's not matching against command echo.
+ tc.master.Write([]byte("echo foo-${PWD}-123\n"))
+ if err := testutil.WaitUntilRead(ptyBuf, "foo-/-123", 5*time.Second); err != nil {
+ t.Fatalf("echo didn't execute: %v", err)
+ }
+ }
+ })
+ }
+}
+
// blockingBuffer is a thread-safe buffer that blocks when reading if the
// buffer is empty. It implements io.ReadWriter.
type blockingBuffer struct {
diff --git a/runsc/container/container.go b/runsc/container/container.go
index 4aa139c88..418a27beb 100644
--- a/runsc/container/container.go
+++ b/runsc/container/container.go
@@ -38,6 +38,7 @@ import (
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/cgroup"
"gvisor.dev/gvisor/runsc/config"
+ "gvisor.dev/gvisor/runsc/console"
"gvisor.dev/gvisor/runsc/sandbox"
"gvisor.dev/gvisor/runsc/specutils"
)
@@ -79,6 +80,7 @@ func validateID(id string) error {
// - It calls 'runsc delete'. runc implementation kills --all SIGKILL once
// again just to be sure, waits, and then proceeds with remaining teardown.
//
+// Container is thread-unsafe.
type Container struct {
// ID is the container ID.
ID string `json:"id"`
@@ -397,7 +399,22 @@ func New(conf *config.Config, args Args) (*Container, error) {
return nil, err
}
c.Sandbox = sb.Sandbox
- if err := c.Sandbox.CreateContainer(c.ID); err != nil {
+
+ // If the console control socket file is provided, then create a new
+ // pty master/slave pair and send the TTY to the sandbox process.
+ var tty *os.File
+ if c.ConsoleSocket != "" {
+ // Create a new TTY pair and send the master on the provided socket.
+ var err error
+ tty, err = console.NewWithSocket(c.ConsoleSocket)
+ if err != nil {
+ return nil, fmt.Errorf("setting up console with socket %q: %w", c.ConsoleSocket, err)
+ }
+ // tty file is transferred to the sandbox, then it can be closed here.
+ defer tty.Close()
+ }
+
+ if err := c.Sandbox.CreateContainer(c.ID, tty); err != nil {
return nil, err
}
}
@@ -451,11 +468,16 @@ func (c *Container) Start(conf *config.Config) error {
// the start (and all their children processes).
if err := runInCgroup(c.Sandbox.Cgroup, func() error {
// Create the gofer process.
- ioFiles, mountsFile, err := c.createGoferProcess(c.Spec, conf, c.BundleDir, false)
+ goferFiles, mountsFile, err := c.createGoferProcess(c.Spec, conf, c.BundleDir, false)
if err != nil {
return err
}
- defer mountsFile.Close()
+ defer func() {
+ _ = mountsFile.Close()
+ for _, f := range goferFiles {
+ _ = f.Close()
+ }
+ }()
cleanMounts, err := specutils.ReadMounts(mountsFile)
if err != nil {
@@ -463,7 +485,14 @@ func (c *Container) Start(conf *config.Config) error {
}
c.Spec.Mounts = cleanMounts
- return c.Sandbox.StartContainer(c.Spec, conf, c.ID, ioFiles)
+ // Setup stdios if the container is not using terminal. Otherwise TTY was
+ // already setup in create.
+ var stdios []*os.File
+ if !c.Spec.Process.Terminal {
+ stdios = []*os.File{os.Stdin, os.Stdout, os.Stderr}
+ }
+
+ return c.Sandbox.StartContainer(c.Spec, conf, c.ID, stdios, goferFiles)
}); err != nil {
return err
}
diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go
index cadc63bf3..45d4e6e6e 100644
--- a/runsc/container/multi_container_test.go
+++ b/runsc/container/multi_container_test.go
@@ -301,54 +301,21 @@ func TestMultiContainerWait(t *testing.T) {
}
defer cleanup()
- // Check via ps that multiple processes are running.
- expectedPL := []*control.Process{
- newProcessBuilder().PID(2).PPID(0).Cmd("sleep").Process(),
- }
- if err := waitForProcessList(containers[1], expectedPL); err != nil {
- t.Errorf("failed to wait for sleep to start: %v", err)
- }
-
- // Wait on the short lived container from multiple goroutines.
- wg := sync.WaitGroup{}
- for i := 0; i < 3; i++ {
- wg.Add(1)
- go func(c *Container) {
- defer wg.Done()
- if ws, err := c.Wait(); err != nil {
- t.Errorf("failed to wait for process %s: %v", c.Spec.Process.Args, err)
- } else if es := ws.ExitStatus(); es != 0 {
- t.Errorf("process %s exited with non-zero status %d", c.Spec.Process.Args, es)
- }
- if _, err := c.Wait(); err != nil {
- t.Errorf("wait for stopped container %s shouldn't fail: %v", c.Spec.Process.Args, err)
- }
- }(containers[1])
+ // Check that we can wait for the sub-container.
+ c := containers[1]
+ if ws, err := c.Wait(); err != nil {
+ t.Errorf("failed to wait for process %s: %v", c.Spec.Process.Args, err)
+ } else if es := ws.ExitStatus(); es != 0 {
+ t.Errorf("process %s exited with non-zero status %d", c.Spec.Process.Args, es)
}
-
- // Also wait via PID.
- for i := 0; i < 3; i++ {
- wg.Add(1)
- go func(c *Container) {
- defer wg.Done()
- const pid = 2
- if ws, err := c.WaitPID(pid); err != nil {
- t.Errorf("failed to wait for PID %d: %v", pid, err)
- } else if es := ws.ExitStatus(); es != 0 {
- t.Errorf("PID %d exited with non-zero status %d", pid, es)
- }
- if _, err := c.WaitPID(pid); err == nil {
- t.Errorf("wait for stopped PID %d should fail", pid)
- }
- }(containers[1])
+ if _, err := c.Wait(); err != nil {
+ t.Errorf("wait for stopped container %s shouldn't fail: %v", c.Spec.Process.Args, err)
}
- wg.Wait()
-
// After Wait returns, ensure that the root container is running and
// the child has finished.
- expectedPL = []*control.Process{
- newProcessBuilder().Cmd("sleep").Process(),
+ expectedPL := []*control.Process{
+ newProcessBuilder().Cmd("sleep").PID(1).Process(),
}
if err := waitForProcessList(containers[0], expectedPL); err != nil {
t.Errorf("failed to wait for %q to start: %v", strings.Join(containers[0].Spec.Process.Args, " "), err)
diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD
index 96c57a426..c56e1d4d0 100644
--- a/runsc/fsgofer/BUILD
+++ b/runsc/fsgofer/BUILD
@@ -29,9 +29,12 @@ go_test(
srcs = ["fsgofer_test.go"],
library = ":fsgofer",
deps = [
+ "//pkg/fd",
"//pkg/log",
"//pkg/p9",
"//pkg/test/testutil",
+ "//runsc/specutils",
+ "@com_github_syndtr_gocapability//capability:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go
index 0b628c8ce..3d94ffeb4 100644
--- a/runsc/fsgofer/fsgofer.go
+++ b/runsc/fsgofer/fsgofer.go
@@ -49,6 +49,21 @@ const (
allowedOpenFlags = unix.O_TRUNC
)
+var (
+ // Remember the process uid/gid to skip chown calls when file owner/group
+ // doesn't need to be changed.
+ processUID = p9.UID(os.Getuid())
+ processGID = p9.GID(os.Getgid())
+)
+
+// join is equivalent to path.Join() but skips path.Clean() which is expensive.
+func join(parent, child string) string {
+ if child == "." || child == ".." {
+ panic(fmt.Sprintf("invalid child path %q", child))
+ }
+ return parent + "/" + child
+}
+
// Config sets configuration options for each attach point.
type Config struct {
// ROMount is set to true if this is a readonly mount.
@@ -115,7 +130,7 @@ func (a *attachPoint) Attach() (p9.File, error) {
return nil, fmt.Errorf("unable to stat %q: %v", a.prefix, err)
}
- lf, err := newLocalFile(a, f, a.prefix, readable, stat)
+ lf, err := newLocalFile(a, f, a.prefix, readable, &stat)
if err != nil {
return nil, fmt.Errorf("unable to create localFile %q: %v", a.prefix, err)
}
@@ -124,7 +139,7 @@ func (a *attachPoint) Attach() (p9.File, error) {
}
// makeQID returns a unique QID for the given stat buffer.
-func (a *attachPoint) makeQID(stat unix.Stat_t) p9.QID {
+func (a *attachPoint) makeQID(stat *unix.Stat_t) p9.QID {
a.deviceMu.Lock()
defer a.deviceMu.Unlock()
@@ -245,7 +260,7 @@ func reopenProcFd(f *fd.FD, mode int) (*fd.FD, error) {
}
func openAnyFileFromParent(parent *localFile, name string) (*fd.FD, string, bool, error) {
- pathDebug := path.Join(parent.hostPath, name)
+ pathDebug := join(parent.hostPath, name)
f, readable, err := openAnyFile(pathDebug, func(mode int) (*fd.FD, error) {
return fd.OpenAt(parent.file, name, openFlags|mode, 0)
})
@@ -297,8 +312,8 @@ func openAnyFile(pathDebug string, fn func(mode int) (*fd.FD, error)) (*fd.FD, b
return nil, false, extractErrno(err)
}
-func checkSupportedFileType(stat unix.Stat_t, permitSocket bool) error {
- switch stat.Mode & unix.S_IFMT {
+func checkSupportedFileType(mode uint32, permitSocket bool) error {
+ switch mode & unix.S_IFMT {
case unix.S_IFREG, unix.S_IFDIR, unix.S_IFLNK:
return nil
@@ -313,8 +328,8 @@ func checkSupportedFileType(stat unix.Stat_t, permitSocket bool) error {
}
}
-func newLocalFile(a *attachPoint, file *fd.FD, path string, readable bool, stat unix.Stat_t) (*localFile, error) {
- if err := checkSupportedFileType(stat, a.conf.HostUDS); err != nil {
+func newLocalFile(a *attachPoint, file *fd.FD, path string, readable bool, stat *unix.Stat_t) (*localFile, error) {
+ if err := checkSupportedFileType(stat.Mode, a.conf.HostUDS); err != nil {
return nil, err
}
@@ -442,8 +457,10 @@ func (l *localFile) Create(name string, p9Flags p9.OpenFlags, perm p9.FileMode,
})
defer cu.Clean()
- if err := fchown(child.FD(), uid, gid); err != nil {
- return nil, nil, p9.QID{}, 0, extractErrno(err)
+ if uid != processUID || gid != processGID {
+ if err := fchown(child.FD(), uid, gid); err != nil {
+ return nil, nil, p9.QID{}, 0, extractErrno(err)
+ }
}
stat, err := fstat(child.FD())
if err != nil {
@@ -452,11 +469,11 @@ func (l *localFile) Create(name string, p9Flags p9.OpenFlags, perm p9.FileMode,
c := &localFile{
attachPoint: l.attachPoint,
- hostPath: path.Join(l.hostPath, name),
+ hostPath: join(l.hostPath, name),
file: child,
mode: mode,
fileType: unix.S_IFREG,
- qid: l.attachPoint.makeQID(stat),
+ qid: l.attachPoint.makeQID(&stat),
}
cu.Release()
@@ -488,8 +505,10 @@ func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID)
}
defer f.Close()
- if err := fchown(f.FD(), uid, gid); err != nil {
- return p9.QID{}, extractErrno(err)
+ if uid != processUID || gid != processGID {
+ if err := fchown(f.FD(), uid, gid); err != nil {
+ return p9.QID{}, extractErrno(err)
+ }
}
stat, err := fstat(f.FD())
if err != nil {
@@ -497,7 +516,7 @@ func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID)
}
cu.Release()
- return l.attachPoint.makeQID(stat), nil
+ return l.attachPoint.makeQID(&stat), nil
}
// Walk implements p9.File.
@@ -512,7 +531,7 @@ func (l *localFile) WalkGetAttr(names []string) ([]p9.QID, p9.File, p9.AttrMask,
if err != nil {
return nil, nil, p9.AttrMask{}, p9.Attr{}, err
}
- mask, attr := l.fillAttr(stat)
+ mask, attr := l.fillAttr(&stat)
return qids, file, mask, attr, nil
}
@@ -538,13 +557,13 @@ func (l *localFile) walk(names []string) ([]p9.QID, p9.File, unix.Stat_t, error)
file: newFile,
mode: invalidMode,
fileType: l.fileType,
- qid: l.attachPoint.makeQID(stat),
+ qid: l.attachPoint.makeQID(&stat),
controlReadable: readable,
}
return []p9.QID{c.qid}, c, stat, nil
}
- var qids []p9.QID
+ qids := make([]p9.QID, 0, len(names))
var lastStat unix.Stat_t
last := l
for _, name := range names {
@@ -560,7 +579,7 @@ func (l *localFile) walk(names []string) ([]p9.QID, p9.File, unix.Stat_t, error)
_ = f.Close()
return nil, nil, unix.Stat_t{}, extractErrno(err)
}
- c, err := newLocalFile(last.attachPoint, f, path, readable, lastStat)
+ c, err := newLocalFile(last.attachPoint, f, path, readable, &lastStat)
if err != nil {
_ = f.Close()
return nil, nil, unix.Stat_t{}, extractErrno(err)
@@ -609,11 +628,11 @@ func (l *localFile) GetAttr(_ p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error)
if err != nil {
return p9.QID{}, p9.AttrMask{}, p9.Attr{}, extractErrno(err)
}
- mask, attr := l.fillAttr(stat)
+ mask, attr := l.fillAttr(&stat)
return l.qid, mask, attr, nil
}
-func (l *localFile) fillAttr(stat unix.Stat_t) (p9.AttrMask, p9.Attr) {
+func (l *localFile) fillAttr(stat *unix.Stat_t) (p9.AttrMask, p9.Attr) {
attr := p9.Attr{
Mode: p9.FileMode(stat.Mode),
UID: p9.UID(stat.Uid),
@@ -881,8 +900,10 @@ func (l *localFile) Symlink(target, newName string, uid p9.UID, gid p9.GID) (p9.
}
defer f.Close()
- if err := fchown(f.FD(), uid, gid); err != nil {
- return p9.QID{}, extractErrno(err)
+ if uid != processUID || gid != processGID {
+ if err := fchown(f.FD(), uid, gid); err != nil {
+ return p9.QID{}, extractErrno(err)
+ }
}
stat, err := fstat(f.FD())
if err != nil {
@@ -890,7 +911,7 @@ func (l *localFile) Symlink(target, newName string, uid p9.UID, gid p9.GID) (p9.
}
cu.Release()
- return l.attachPoint.makeQID(stat), nil
+ return l.attachPoint.makeQID(&stat), nil
}
// Link implements p9.File.
@@ -938,8 +959,10 @@ func (l *localFile) Mknod(name string, mode p9.FileMode, _ uint32, _ uint32, uid
}
defer child.Close()
- if err := fchown(child.FD(), uid, gid); err != nil {
- return p9.QID{}, extractErrno(err)
+ if uid != processUID || gid != processGID {
+ if err := fchown(child.FD(), uid, gid); err != nil {
+ return p9.QID{}, extractErrno(err)
+ }
}
stat, err := fstat(child.FD())
if err != nil {
@@ -947,7 +970,7 @@ func (l *localFile) Mknod(name string, mode p9.FileMode, _ uint32, _ uint32, uid
}
cu.Release()
- return l.attachPoint.makeQID(stat), nil
+ return l.attachPoint.makeQID(&stat), nil
}
// UnlinkAt implements p9.File.
@@ -1045,7 +1068,7 @@ func (l *localFile) readDirent(f int, offset uint64, count uint32, skip uint64)
log.Warningf("Readdir is skipping file with failed stat %q, err: %v", l.hostPath, err)
continue
}
- qid := l.attachPoint.makeQID(stat)
+ qid := l.attachPoint.makeQID(&stat)
offset++
dirents = append(dirents, p9.Dirent{
QID: qid,
@@ -1139,7 +1162,7 @@ func (l *localFile) isOpen() bool {
// Renamed implements p9.Renamed.
func (l *localFile) Renamed(newDir p9.File, newName string) {
- l.hostPath = path.Join(newDir.(*localFile).hostPath, newName)
+ l.hostPath = join(newDir.(*localFile).hostPath, newName)
}
// extractErrno tries to determine the errno.
diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go
index a84206686..c5daebe5e 100644
--- a/runsc/fsgofer/fsgofer_test.go
+++ b/runsc/fsgofer/fsgofer_test.go
@@ -23,10 +23,13 @@ import (
"path/filepath"
"testing"
+ "github.com/syndtr/gocapability/capability"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/runsc/specutils"
)
var allOpenFlags = []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite}
@@ -197,10 +200,13 @@ func setup(fileType uint32) (string, string, error) {
switch fileType {
case unix.S_IFREG:
name = "file"
- _, f, _, _, err := root.Create(name, p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid()))
+ fd, f, _, _, err := root.Create(name, p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid()))
if err != nil {
return "", "", fmt.Errorf("createFile(root, %q) failed, err: %v", "test", err)
}
+ if fd != nil {
+ fd.Close()
+ }
defer f.Close()
case unix.S_IFDIR:
name = "dir"
@@ -556,7 +562,28 @@ func TestROMountChecks(t *testing.T) {
func TestWalkNotFound(t *testing.T) {
runCustom(t, []uint32{unix.S_IFDIR}, allConfs, func(t *testing.T, s state) {
if _, _, err := s.file.Walk([]string{"nobody-here"}); err != unix.ENOENT {
- t.Errorf("%v: Walk(%q) should have failed, got: %v, expected: unix.ENOENT", s, "nobody-here", err)
+ t.Errorf("Walk(%q) should have failed, got: %v, expected: unix.ENOENT", "nobody-here", err)
+ }
+ if _, _, err := s.file.Walk([]string{"nobody", "here"}); err != unix.ENOENT {
+ t.Errorf("Walk(%q) should have failed, got: %v, expected: unix.ENOENT", "nobody/here", err)
+ }
+ if !s.conf.ROMount {
+ if _, err := s.file.Mkdir("dir", 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil {
+ t.Fatalf("MkDir(dir) failed, err: %v", err)
+ }
+ if _, _, err := s.file.Walk([]string{"dir", "nobody-here"}); err != unix.ENOENT {
+ t.Errorf("Walk(%q) should have failed, got: %v, expected: unix.ENOENT", "dir/nobody-here", err)
+ }
+ }
+ })
+}
+
+func TestWalkPanic(t *testing.T) {
+ runCustom(t, []uint32{unix.S_IFDIR}, allConfs, func(t *testing.T, s state) {
+ for _, name := range []string{".", ".."} {
+ assertPanic(t, func() {
+ s.file.Walk([]string{name})
+ })
}
})
}
@@ -574,6 +601,27 @@ func TestWalkDup(t *testing.T) {
})
}
+func TestWalkMultiple(t *testing.T) {
+ runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) {
+ var names []string
+ var parent p9.File = s.file
+ for i := 0; i < 5; i++ {
+ name := fmt.Sprintf("dir%d", i)
+ names = append(names, name)
+
+ if _, err := parent.Mkdir(name, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil {
+ t.Fatalf("MkDir(%q) failed, err: %v", name, err)
+ }
+
+ var err error
+ _, parent, err = s.file.Walk(names)
+ if err != nil {
+ t.Errorf("Walk(%q): %v", name, err)
+ }
+ }
+ })
+}
+
func TestReaddir(t *testing.T) {
runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) {
name := "dir"
@@ -819,3 +867,168 @@ func TestMknod(t *testing.T) {
}
})
}
+
+func BenchmarkWalkOne(b *testing.B) {
+ path, name, err := setup(unix.S_IFDIR)
+ if err != nil {
+ b.Fatalf("%v", err)
+ }
+ defer os.RemoveAll(path)
+
+ a, err := NewAttachPoint(path, Config{})
+ if err != nil {
+ b.Fatalf("NewAttachPoint failed: %v", err)
+ }
+ root, err := a.Attach()
+ if err != nil {
+ b.Fatalf("Attach failed, err: %v", err)
+ }
+ defer root.Close()
+
+ names := []string{name}
+ files := make([]p9.File, 0, 1000)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, file, err := root.Walk(names)
+ if err != nil {
+ b.Fatalf("Walk(%q): %v", name, err)
+ }
+ files = append(files, file)
+
+ // Avoid running out of FDs.
+ if len(files) == cap(files) {
+ b.StopTimer()
+ for _, file := range files {
+ file.Close()
+ }
+ files = files[:0]
+ b.StartTimer()
+ }
+ }
+
+ b.StopTimer()
+ for _, file := range files {
+ file.Close()
+ }
+}
+
+func BenchmarkCreate(b *testing.B) {
+ path, _, err := setup(unix.S_IFDIR)
+ if err != nil {
+ b.Fatalf("%v", err)
+ }
+ defer os.RemoveAll(path)
+
+ a, err := NewAttachPoint(path, Config{})
+ if err != nil {
+ b.Fatalf("NewAttachPoint failed: %v", err)
+ }
+ root, err := a.Attach()
+ if err != nil {
+ b.Fatalf("Attach failed, err: %v", err)
+ }
+ defer root.Close()
+
+ files := make([]p9.File, 0, 500)
+ fds := make([]*fd.FD, 0, 500)
+ uid := p9.UID(os.Getuid())
+ gid := p9.GID(os.Getgid())
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ name := fmt.Sprintf("same-%d", i)
+ fd, file, _, _, err := root.Create(name, p9.ReadOnly, 0777, uid, gid)
+ if err != nil {
+ b.Fatalf("Create(%q): %v", name, err)
+ }
+ files = append(files, file)
+ if fd != nil {
+ fds = append(fds, fd)
+ }
+
+ // Avoid running out of FDs.
+ if len(files) == cap(files) {
+ b.StopTimer()
+ for _, file := range files {
+ file.Close()
+ }
+ files = files[:0]
+ for _, fd := range fds {
+ fd.Close()
+ }
+ fds = fds[:0]
+ b.StartTimer()
+ }
+ }
+
+ b.StopTimer()
+ for _, file := range files {
+ file.Close()
+ }
+ for _, fd := range fds {
+ fd.Close()
+ }
+}
+
+func BenchmarkCreateDiffOwner(b *testing.B) {
+ if !specutils.HasCapabilities(capability.CAP_CHOWN) {
+ b.Skipf("Test requires CAP_CHOWN")
+ }
+
+ path, _, err := setup(unix.S_IFDIR)
+ if err != nil {
+ b.Fatalf("%v", err)
+ }
+ defer os.RemoveAll(path)
+
+ a, err := NewAttachPoint(path, Config{})
+ if err != nil {
+ b.Fatalf("NewAttachPoint failed: %v", err)
+ }
+ root, err := a.Attach()
+ if err != nil {
+ b.Fatalf("Attach failed, err: %v", err)
+ }
+ defer root.Close()
+
+ files := make([]p9.File, 0, 500)
+ fds := make([]*fd.FD, 0, 500)
+ gid := p9.GID(os.Getgid())
+ const nobody = 65534
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ name := fmt.Sprintf("diff-%d", i)
+ fd, file, _, _, err := root.Create(name, p9.ReadOnly, 0777, nobody, gid)
+ if err != nil {
+ b.Fatalf("Create(%q): %v", name, err)
+ }
+ files = append(files, file)
+ if fd != nil {
+ fds = append(fds, fd)
+ }
+
+ // Avoid running out of FDs.
+ if len(files) == cap(files) {
+ b.StopTimer()
+ for _, file := range files {
+ file.Close()
+ }
+ files = files[:0]
+ for _, fd := range fds {
+ fd.Close()
+ }
+ fds = fds[:0]
+ b.StartTimer()
+ }
+ }
+
+ b.StopTimer()
+ for _, file := range files {
+ file.Close()
+ }
+ for _, fd := range fds {
+ fd.Close()
+ }
+}
diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go
index 8f66dd1f8..9e429f7d5 100644
--- a/runsc/sandbox/network.go
+++ b/runsc/sandbox/network.go
@@ -127,7 +127,7 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG
// Get all interfaces in the namespace.
ifaces, err := net.Interfaces()
if err != nil {
- return fmt.Errorf("querying interfaces: %v", err)
+ return fmt.Errorf("querying interfaces: %w", err)
}
isRoot, err := isRootNS()
@@ -148,14 +148,14 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG
allAddrs, err := iface.Addrs()
if err != nil {
- return fmt.Errorf("fetching interface addresses for %q: %v", iface.Name, err)
+ return fmt.Errorf("fetching interface addresses for %q: %w", iface.Name, err)
}
// We build our own loopback device.
if iface.Flags&net.FlagLoopback != 0 {
link, err := loopbackLink(iface, allAddrs)
if err != nil {
- return fmt.Errorf("getting loopback link for iface %q: %v", iface.Name, err)
+ return fmt.Errorf("getting loopback link for iface %q: %w", iface.Name, err)
}
args.LoopbackLinks = append(args.LoopbackLinks, link)
continue
@@ -209,7 +209,7 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG
// Get the link for the interface.
ifaceLink, err := netlink.LinkByName(iface.Name)
if err != nil {
- return fmt.Errorf("getting link for interface %q: %v", iface.Name, err)
+ return fmt.Errorf("getting link for interface %q: %w", iface.Name, err)
}
link.LinkAddress = ifaceLink.Attrs().HardwareAddr
@@ -219,7 +219,7 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG
log.Debugf("Creating Channel %d", i)
socketEntry, err := createSocket(iface, ifaceLink, hardwareGSO)
if err != nil {
- return fmt.Errorf("failed to createSocket for %s : %v", iface.Name, err)
+ return fmt.Errorf("failed to createSocket for %s : %w", iface.Name, err)
}
if i == 0 {
link.GSOMaxSize = socketEntry.gsoMaxSize
@@ -241,11 +241,12 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG
// Collect the addresses for the interface, enable forwarding,
// and remove them from the host.
for _, addr := range ipAddrs {
- link.Addresses = append(link.Addresses, addr.IP)
+ prefix, _ := addr.Mask.Size()
+ link.Addresses = append(link.Addresses, boot.IPWithPrefix{Address: addr.IP, PrefixLen: prefix})
// Steal IP address from NIC.
if err := removeAddress(ifaceLink, addr.String()); err != nil {
- return fmt.Errorf("removing address %v from device %q: %v", iface.Name, addr, err)
+ return fmt.Errorf("removing address %v from device %q: %w", addr, iface.Name, err)
}
}
@@ -254,7 +255,7 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG
log.Debugf("Setting up network, config: %+v", args)
if err := conn.Call(boot.NetworkCreateLinksAndRoutes, &args, nil); err != nil {
- return fmt.Errorf("creating links and routes: %v", err)
+ return fmt.Errorf("creating links and routes: %w", err)
}
return nil
}
@@ -278,8 +279,6 @@ func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) (
ll := syscall.SockaddrLinklayer{
Protocol: protocol,
Ifindex: iface.Index,
- Hatype: 0, // No ARP type.
- Pkttype: syscall.PACKET_OTHERHOST,
}
if err := syscall.Bind(fd, &ll); err != nil {
return nil, fmt.Errorf("unable to bind to %q: %v", iface.Name, err)
@@ -339,9 +338,15 @@ func loopbackLink(iface net.Interface, addrs []net.Addr) (boot.LoopbackLink, err
if !ok {
return boot.LoopbackLink{}, fmt.Errorf("address is not IPNet: %+v", addr)
}
+
+ prefix, _ := ipNet.Mask.Size()
+ link.Addresses = append(link.Addresses, boot.IPWithPrefix{
+ Address: ipNet.IP,
+ PrefixLen: prefix,
+ })
+
dst := *ipNet
dst.IP = dst.IP.Mask(dst.Mask)
- link.Addresses = append(link.Addresses, ipNet.IP)
link.Routes = append(link.Routes, boot.Route{
Destination: dst,
})
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index 4a4110477..c84ebcd8a 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -173,7 +173,7 @@ func New(conf *config.Config, args *Args) (*Sandbox, error) {
}
// CreateContainer creates a non-root container inside the sandbox.
-func (s *Sandbox) CreateContainer(cid string) error {
+func (s *Sandbox) CreateContainer(cid string, tty *os.File) error {
log.Debugf("Create non-root container %q in sandbox %q, PID: %d", cid, s.ID, s.Pid)
sandboxConn, err := s.sandboxConnect()
if err != nil {
@@ -181,7 +181,16 @@ func (s *Sandbox) CreateContainer(cid string) error {
}
defer sandboxConn.Close()
- if err := sandboxConn.Call(boot.ContainerCreate, &cid, nil); err != nil {
+ var files []*os.File
+ if tty != nil {
+ files = []*os.File{tty}
+ }
+
+ args := boot.CreateArgs{
+ CID: cid,
+ FilePayload: urpc.FilePayload{Files: files},
+ }
+ if err := sandboxConn.Call(boot.ContainerCreate, &args, nil); err != nil {
return fmt.Errorf("creating non-root container %q: %v", cid, err)
}
return nil
@@ -211,11 +220,7 @@ func (s *Sandbox) StartRoot(spec *specs.Spec, conf *config.Config) error {
}
// StartContainer starts running a non-root container inside the sandbox.
-func (s *Sandbox) StartContainer(spec *specs.Spec, conf *config.Config, cid string, goferFiles []*os.File) error {
- for _, f := range goferFiles {
- defer f.Close()
- }
-
+func (s *Sandbox) StartContainer(spec *specs.Spec, conf *config.Config, cid string, stdios, goferFiles []*os.File) error {
log.Debugf("Start non-root container %q in sandbox %q, PID: %d", cid, s.ID, s.Pid)
sandboxConn, err := s.sandboxConnect()
if err != nil {
@@ -223,15 +228,18 @@ func (s *Sandbox) StartContainer(spec *specs.Spec, conf *config.Config, cid stri
}
defer sandboxConn.Close()
- // The payload must container stdin/stdout/stderr followed by gofer
- // files.
- files := append([]*os.File{os.Stdin, os.Stdout, os.Stderr}, goferFiles...)
+ // The payload must contain stdin/stdout/stderr (which may be empty if using
+ // TTY) followed by gofer files.
+ payload := urpc.FilePayload{}
+ payload.Files = append(payload.Files, stdios...)
+ payload.Files = append(payload.Files, goferFiles...)
+
// Start running the container.
args := boot.StartArgs{
Spec: spec,
Conf: conf,
CID: cid,
- FilePayload: urpc.FilePayload{Files: files},
+ FilePayload: payload,
}
if err := sandboxConn.Call(boot.ContainerStart, &args, nil); err != nil {
return fmt.Errorf("starting non-root container %v: %v", spec.Process.Args, err)
diff --git a/test/benchmarks/BUILD b/test/benchmarks/BUILD
new file mode 100644
index 000000000..faf310676
--- /dev/null
+++ b/test/benchmarks/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "bzl_library")
+
+package(licenses = ["notice"])
+
+bzl_library(
+ name = "defs_bzl",
+ srcs = ["defs.bzl"],
+ visibility = [
+ "//:sandbox",
+ ],
+)
diff --git a/test/benchmarks/base/BUILD b/test/benchmarks/base/BUILD
index b4b55317b..697ab5837 100644
--- a/test/benchmarks/base/BUILD
+++ b/test/benchmarks/base/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library")
+load("//test/benchmarks:defs.bzl", "benchmark_test")
package(licenses = ["notice"])
@@ -14,7 +15,7 @@ go_library(
],
)
-go_test(
+benchmark_test(
name = "startup_test",
size = "enormous",
srcs = ["startup_test.go"],
@@ -26,7 +27,7 @@ go_test(
],
)
-go_test(
+benchmark_test(
name = "size_test",
size = "enormous",
srcs = ["size_test.go"],
@@ -39,7 +40,7 @@ go_test(
],
)
-go_test(
+benchmark_test(
name = "sysbench_test",
size = "enormous",
srcs = ["sysbench_test.go"],
diff --git a/test/benchmarks/database/BUILD b/test/benchmarks/database/BUILD
index 93b380e8a..0b1743603 100644
--- a/test/benchmarks/database/BUILD
+++ b/test/benchmarks/database/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library")
+load("//test/benchmarks:defs.bzl", "benchmark_test")
package(licenses = ["notice"])
@@ -6,19 +7,13 @@ go_library(
name = "database",
testonly = 1,
srcs = ["database.go"],
- deps = ["//test/benchmarks/harness"],
)
-go_test(
- name = "database_test",
+benchmark_test(
+ name = "redis_test",
size = "enormous",
srcs = ["redis_test.go"],
library = ":database",
- tags = [
- # Requires docker and runsc to be configured before test runs.
- "manual",
- "local",
- ],
visibility = ["//:sandbox"],
deps = [
"//pkg/test/dockerutil",
diff --git a/test/benchmarks/database/database.go b/test/benchmarks/database/database.go
index 9eeb59f9a..c15ca661c 100644
--- a/test/benchmarks/database/database.go
+++ b/test/benchmarks/database/database.go
@@ -14,18 +14,3 @@
// Package database holds benchmarks around database applications.
package database
-
-import (
- "os"
- "testing"
-
- "gvisor.dev/gvisor/test/benchmarks/harness"
-)
-
-var h harness.Harness
-
-// TestMain is the main method for package database.
-func TestMain(m *testing.M) {
- h.Init()
- os.Exit(m.Run())
-}
diff --git a/test/benchmarks/database/redis_test.go b/test/benchmarks/database/redis_test.go
index 02e67154e..f8075a04b 100644
--- a/test/benchmarks/database/redis_test.go
+++ b/test/benchmarks/database/redis_test.go
@@ -16,6 +16,7 @@ package database
import (
"context"
+ "os"
"testing"
"time"
@@ -24,6 +25,8 @@ import (
"gvisor.dev/gvisor/test/benchmarks/tools"
)
+var h harness.Harness
+
// All possible operations from redis. Note: "ping" will
// run both PING_INLINE and PING_BUILD.
var operations []string = []string{
@@ -111,21 +114,23 @@ func BenchmarkRedis(b *testing.B) {
// Reset profiles and timer to begin the measurement.
server.RestartProfiles()
b.ResetTimer()
- for i := 0; i < b.N; i++ {
- client := clientMachine.GetNativeContainer(ctx, b)
- defer client.CleanUp(ctx)
- out, err := client.Run(ctx, dockerutil.RunOpts{
- Image: "benchmarks/redis",
- }, redis.MakeCmd(ip, serverPort)...)
- if err != nil {
- b.Fatalf("redis-benchmark failed with: %v", err)
- }
-
- // Stop time while we parse results.
- b.StopTimer()
- redis.Report(b, out)
- b.StartTimer()
+ client := clientMachine.GetNativeContainer(ctx, b)
+ defer client.CleanUp(ctx)
+ out, err := client.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/redis",
+ }, redis.MakeCmd(ip, serverPort, b.N /*requests*/)...)
+ if err != nil {
+ b.Fatalf("redis-benchmark failed with: %v", err)
}
+
+ // Stop time while we parse results.
+ b.StopTimer()
+ redis.Report(b, out)
})
}
}
+
+func TestMain(m *testing.M) {
+ h.Init()
+ os.Exit(m.Run())
+}
diff --git a/test/benchmarks/defs.bzl b/test/benchmarks/defs.bzl
new file mode 100644
index 000000000..ef44b46e3
--- /dev/null
+++ b/test/benchmarks/defs.bzl
@@ -0,0 +1,14 @@
+"""Defines a rule for benchmark test targets."""
+
+load("//tools:defs.bzl", "go_test")
+
+def benchmark_test(name, tags = [], **kwargs):
+ go_test(
+ name,
+ tags = [
+ # Requires docker and runsc to be configured before the test runs.
+ "local",
+ "manual",
+ ],
+ **kwargs
+ )
diff --git a/test/benchmarks/fs/BUILD b/test/benchmarks/fs/BUILD
index 021fae38d..b4f967441 100644
--- a/test/benchmarks/fs/BUILD
+++ b/test/benchmarks/fs/BUILD
@@ -1,8 +1,8 @@
-load("//tools:defs.bzl", "go_test")
+load("//test/benchmarks:defs.bzl", "benchmark_test")
package(licenses = ["notice"])
-go_test(
+benchmark_test(
name = "bazel_test",
size = "enormous",
srcs = ["bazel_test.go"],
@@ -14,7 +14,7 @@ go_test(
],
)
-go_test(
+benchmark_test(
name = "fio_test",
size = "enormous",
srcs = ["fio_test.go"],
diff --git a/test/benchmarks/fs/bazel_test.go b/test/benchmarks/fs/bazel_test.go
index 53ed3f9f2..3fb4da9d1 100644
--- a/test/benchmarks/fs/bazel_test.go
+++ b/test/benchmarks/fs/bazel_test.go
@@ -61,10 +61,10 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) {
for _, bm := range benchmarks {
pageCache := tools.Parameter{
Name: "page_cache",
- Value: "clean",
+ Value: "dirty",
}
if bm.clearCache {
- pageCache.Value = "dirty"
+ pageCache.Value = "clean"
}
filesystem := tools.Parameter{
@@ -129,12 +129,14 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) {
if !strings.Contains(got, want) {
b.Fatalf("string %s not in: %s", want, got)
}
- // Clean bazel in case we use b.N.
- _, err = container.Exec(ctx, dockerutil.ExecOpts{
- WorkDir: prefix + workdir,
- }, "bazel", "clean")
- if err != nil {
- b.Fatalf("build failed with: %v", err)
+
+ // Clean bazel in the case we are doing another run.
+ if i < b.N-1 {
+ if _, err = container.Exec(ctx, dockerutil.ExecOpts{
+ WorkDir: prefix + workdir,
+ }, "bazel", "clean"); err != nil {
+ b.Fatalf("build failed with: %v", err)
+ }
}
b.StartTimer()
}
diff --git a/test/benchmarks/harness/harness.go b/test/benchmarks/harness/harness.go
index 5c9d0e01e..4c6e724aa 100644
--- a/test/benchmarks/harness/harness.go
+++ b/test/benchmarks/harness/harness.go
@@ -39,7 +39,7 @@ func (h *Harness) Init() error {
flag.PrintDefaults()
}
flag.Parse()
- if flag.NFlag() == 0 || *help {
+ if *help {
flag.Usage()
os.Exit(0)
}
diff --git a/test/benchmarks/media/BUILD b/test/benchmarks/media/BUILD
index bb242d385..380783f0b 100644
--- a/test/benchmarks/media/BUILD
+++ b/test/benchmarks/media/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library")
+load("//test/benchmarks:defs.bzl", "benchmark_test")
package(licenses = ["notice"])
@@ -6,12 +7,11 @@ go_library(
name = "media",
testonly = 1,
srcs = ["media.go"],
- deps = ["//test/benchmarks/harness"],
)
-go_test(
- name = "media_test",
- size = "large",
+benchmark_test(
+ name = "ffmpeg_test",
+ size = "enormous",
srcs = ["ffmpeg_test.go"],
library = ":media",
visibility = ["//:sandbox"],
diff --git a/test/benchmarks/media/ffmpeg_test.go b/test/benchmarks/media/ffmpeg_test.go
index 7822dfad7..a462ec2a6 100644
--- a/test/benchmarks/media/ffmpeg_test.go
+++ b/test/benchmarks/media/ffmpeg_test.go
@@ -15,6 +15,7 @@ package media
import (
"context"
+ "os"
"strings"
"testing"
@@ -22,6 +23,8 @@ import (
"gvisor.dev/gvisor/test/benchmarks/harness"
)
+var h harness.Harness
+
// BenchmarkFfmpeg runs ffmpeg in a container and records runtime.
// BenchmarkFfmpeg should run as root to drop caches.
func BenchmarkFfmpeg(b *testing.B) {
@@ -32,13 +35,13 @@ func BenchmarkFfmpeg(b *testing.B) {
defer machine.CleanUp()
ctx := context.Background()
- container := machine.GetContainer(ctx, b)
- defer container.CleanUp(ctx)
cmd := strings.Split("ffmpeg -i video.mp4 -c:v libx264 -preset veryslow output.mp4", " ")
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
+ container := machine.GetContainer(ctx, b)
+ defer container.CleanUp(ctx)
if err := harness.DropCaches(machine); err != nil {
b.Skipf("failed to drop caches: %v. You probably need root.", err)
}
@@ -51,3 +54,8 @@ func BenchmarkFfmpeg(b *testing.B) {
}
}
}
+
+func TestMain(m *testing.M) {
+ h.Init()
+ os.Exit(m.Run())
+}
diff --git a/test/benchmarks/media/media.go b/test/benchmarks/media/media.go
index c7b35b758..ed7b24651 100644
--- a/test/benchmarks/media/media.go
+++ b/test/benchmarks/media/media.go
@@ -14,18 +14,3 @@
// Package media holds benchmarks around media processing applications.
package media
-
-import (
- "os"
- "testing"
-
- "gvisor.dev/gvisor/test/benchmarks/harness"
-)
-
-var h harness.Harness
-
-// TestMain is the main method for package media.
-func TestMain(m *testing.M) {
- h.Init()
- os.Exit(m.Run())
-}
diff --git a/test/benchmarks/ml/BUILD b/test/benchmarks/ml/BUILD
index 970f52706..285ec35d9 100644
--- a/test/benchmarks/ml/BUILD
+++ b/test/benchmarks/ml/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library")
+load("//test/benchmarks:defs.bzl", "benchmark_test")
package(licenses = ["notice"])
@@ -6,12 +7,11 @@ go_library(
name = "ml",
testonly = 1,
srcs = ["ml.go"],
- deps = ["//test/benchmarks/harness"],
)
-go_test(
- name = "ml_test",
- size = "large",
+benchmark_test(
+ name = "tensorflow_test",
+ size = "enormous",
srcs = ["tensorflow_test.go"],
library = ":ml",
visibility = ["//:sandbox"],
diff --git a/test/benchmarks/ml/ml.go b/test/benchmarks/ml/ml.go
index 13282d7bb..d5fc5b7da 100644
--- a/test/benchmarks/ml/ml.go
+++ b/test/benchmarks/ml/ml.go
@@ -14,18 +14,3 @@
// Package ml holds benchmarks around machine learning performance.
package ml
-
-import (
- "os"
- "testing"
-
- "gvisor.dev/gvisor/test/benchmarks/harness"
-)
-
-var h harness.Harness
-
-// TestMain is the main method for package ml.
-func TestMain(m *testing.M) {
- h.Init()
- os.Exit(m.Run())
-}
diff --git a/test/benchmarks/ml/tensorflow_test.go b/test/benchmarks/ml/tensorflow_test.go
index f7746897d..a55329d82 100644
--- a/test/benchmarks/ml/tensorflow_test.go
+++ b/test/benchmarks/ml/tensorflow_test.go
@@ -15,12 +15,15 @@ package ml
import (
"context"
+ "os"
"testing"
"gvisor.dev/gvisor/pkg/test/dockerutil"
"gvisor.dev/gvisor/test/benchmarks/harness"
)
+var h harness.Harness
+
// BenchmarkTensorflow runs workloads from a TensorFlow tutorial.
// See: https://github.com/aymericdamien/TensorFlow-Examples
func BenchmarkTensorflow(b *testing.B) {
@@ -44,12 +47,12 @@ func BenchmarkTensorflow(b *testing.B) {
for name, workload := range workloads {
b.Run(name, func(b *testing.B) {
ctx := context.Background()
- container := machine.GetContainer(ctx, b)
- defer container.CleanUp(ctx)
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
+ container := machine.GetContainer(ctx, b)
+ defer container.CleanUp(ctx)
if err := harness.DropCaches(machine); err != nil {
b.Skipf("failed to drop caches: %v. You probably need root.", err)
}
@@ -67,3 +70,8 @@ func BenchmarkTensorflow(b *testing.B) {
}
}
+
+func TestMain(m *testing.M) {
+ h.Init()
+ os.Exit(m.Run())
+}
diff --git a/test/benchmarks/network/BUILD b/test/benchmarks/network/BUILD
index 472b5c387..2741570f5 100644
--- a/test/benchmarks/network/BUILD
+++ b/test/benchmarks/network/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library")
+load("//test/benchmarks:defs.bzl", "benchmark_test")
package(licenses = ["notice"])
@@ -7,7 +8,6 @@ go_library(
testonly = 1,
srcs = [
"network.go",
- "static_server.go",
],
deps = [
"//pkg/test/dockerutil",
@@ -16,22 +16,74 @@ go_library(
],
)
-go_test(
- name = "network_test",
- size = "large",
+benchmark_test(
+ name = "iperf_test",
+ size = "enormous",
srcs = [
- "httpd_test.go",
"iperf_test.go",
- "nginx_test.go",
+ ],
+ library = ":network",
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//pkg/test/testutil",
+ "//test/benchmarks/harness",
+ "//test/benchmarks/tools",
+ ],
+)
+
+benchmark_test(
+ name = "node_test",
+ size = "enormous",
+ srcs = [
"node_test.go",
+ ],
+ library = ":network",
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//test/benchmarks/harness",
+ "//test/benchmarks/tools",
+ ],
+)
+
+benchmark_test(
+ name = "ruby_test",
+ size = "enormous",
+ srcs = [
"ruby_test.go",
],
library = ":network",
- tags = [
- # Requires docker and runsc to be configured before test runs.
- "manual",
- "local",
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//test/benchmarks/harness",
+ "//test/benchmarks/tools",
+ ],
+)
+
+benchmark_test(
+ name = "nginx_test",
+ size = "enormous",
+ srcs = [
+ "nginx_test.go",
+ ],
+ library = ":network",
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//test/benchmarks/harness",
+ "//test/benchmarks/tools",
],
+)
+
+benchmark_test(
+ name = "httpd_test",
+ size = "enormous",
+ srcs = [
+ "httpd_test.go",
+ ],
+ library = ":network",
visibility = ["//:sandbox"],
deps = [
"//pkg/test/dockerutil",
diff --git a/test/benchmarks/network/httpd_test.go b/test/benchmarks/network/httpd_test.go
index 8d7d5f750..b07274662 100644
--- a/test/benchmarks/network/httpd_test.go
+++ b/test/benchmarks/network/httpd_test.go
@@ -14,13 +14,17 @@
package network
import (
+ "os"
"strconv"
"testing"
"gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/benchmarks/harness"
"gvisor.dev/gvisor/test/benchmarks/tools"
)
+var h harness.Harness
+
// see Dockerfile '//images/benchmarks/httpd'.
var httpdDocs = map[string]string{
"notfound": "notfound",
@@ -43,6 +47,22 @@ func BenchmarkReverseHttpd(b *testing.B) {
benchmarkHttpdDocSize(b, true /* reverse */)
}
+// BenchmarkContinuousHttpd runs specific benchmarks for continous jobs.
+// The runtime under test is the server serving a runc client.
+func BenchmarkContinuousHttpd(b *testing.B) {
+ sizes := []string{"10Kb", "100Kb", "1Mb"}
+ threads := []int{1, 25, 100, 1000}
+ benchmarkHttpdContinuous(b, threads, sizes, false /*reverse*/)
+}
+
+// BenchmarkContinuousHttpdReverse runs specific benchmarks for continous jobs.
+// The runtime under test is the client downloading from a runc server.
+func BenchmarkContinuousHttpdReverse(b *testing.B) {
+ sizes := []string{"10Kb", "100Kb", "1Mb"}
+ threads := []int{1, 25, 100, 1000}
+ benchmarkHttpdContinuous(b, threads, sizes, true /*reverse*/)
+}
+
// benchmarkHttpdDocSize iterates through all doc sizes, running subbenchmarks
// for each size.
func benchmarkHttpdDocSize(b *testing.B, reverse bool) {
@@ -62,9 +82,51 @@ func benchmarkHttpdDocSize(b *testing.B, reverse bool) {
if err != nil {
b.Fatalf("Failed to parse parameters: %v", err)
}
+ requests := b.N
+ if requests < c {
+ b.Logf("b.N is %d must be greater than threads %d. Consider running with --test.benchtime=Nx where N >= %d", b.N, c, c)
+ requests = c
+ }
+ b.Run(name, func(b *testing.B) {
+ hey := &tools.Hey{
+ Requests: requests,
+ Concurrency: c,
+ Doc: filename,
+ }
+ runHttpd(b, hey, reverse)
+ })
+ }
+ }
+}
+
+// benchmarkHttpdContinuous iterates through given sizes and concurrencies.
+func benchmarkHttpdContinuous(b *testing.B, concurrency []int, sizes []string, reverse bool) {
+ for _, size := range sizes {
+ filename := httpdDocs[size]
+ for _, c := range concurrency {
+ fsize := tools.Parameter{
+ Name: "filesize",
+ Value: size,
+ }
+
+ threads := tools.Parameter{
+ Name: "concurrency",
+ Value: strconv.Itoa(c),
+ }
+
+ name, err := tools.ParametersToName(fsize, threads)
+ if err != nil {
+ b.Fatalf("Failed to parse parameters: %v", err)
+ }
+
+ requests := b.N
+ if requests < c {
+ b.Logf("b.N is %d must be greater than threads %d. Consider running with --test.benchtime=Nx where N >= %d", b.N, c, c)
+ requests = c
+ }
b.Run(name, func(b *testing.B) {
hey := &tools.Hey{
- Requests: c * b.N,
+ Requests: requests,
Concurrency: c,
Doc: filename,
}
@@ -91,5 +153,10 @@ func runHttpd(b *testing.B, hey *tools.Hey, reverse bool) {
},
}
httpdCmd := []string{"sh", "-c", "mkdir -p /tmp/html; cp -r /local/* /tmp/html/.; apache2 -X"}
- runStaticServer(b, httpdRunOpts, httpdCmd, port, hey, reverse)
+ runStaticServer(b, h, httpdRunOpts, httpdCmd, port, hey, reverse)
+}
+
+func TestMain(m *testing.M) {
+ h.Init()
+ os.Exit(m.Run())
}
diff --git a/test/benchmarks/network/iperf_test.go b/test/benchmarks/network/iperf_test.go
index b8ab7dfb8..9d64db943 100644
--- a/test/benchmarks/network/iperf_test.go
+++ b/test/benchmarks/network/iperf_test.go
@@ -15,6 +15,7 @@ package network
import (
"context"
+ "os"
"testing"
"gvisor.dev/gvisor/pkg/test/dockerutil"
@@ -23,9 +24,11 @@ import (
"gvisor.dev/gvisor/test/benchmarks/tools"
)
+var h harness.Harness
+
func BenchmarkIperf(b *testing.B) {
iperf := tools.Iperf{
- Time: 10, // time in seconds to run client.
+ Time: b.N, // time in seconds to run client.
}
clientMachine, err := h.GetMachine()
@@ -97,17 +100,19 @@ func BenchmarkIperf(b *testing.B) {
// Restart the server profiles. If the server isn't being profiled
// this does nothing.
server.RestartProfiles()
- for i := 0; i < b.N; i++ {
- out, err := client.Run(ctx, dockerutil.RunOpts{
- Image: "benchmarks/iperf",
- }, iperf.MakeCmd(ip, servingPort)...)
- if err != nil {
- b.Fatalf("failed to run client: %v", err)
- }
- b.StopTimer()
- iperf.Report(b, out)
- b.StartTimer()
+ out, err := client.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/iperf",
+ }, iperf.MakeCmd(ip, servingPort)...)
+ if err != nil {
+ b.Fatalf("failed to run client: %v", err)
}
+ b.StopTimer()
+ iperf.Report(b, out)
})
}
}
+
+func TestMain(m *testing.M) {
+ h.Init()
+ os.Exit(m.Run())
+}
diff --git a/test/benchmarks/network/network.go b/test/benchmarks/network/network.go
index ce17ddb94..b18bc2b3c 100644
--- a/test/benchmarks/network/network.go
+++ b/test/benchmarks/network/network.go
@@ -16,16 +16,73 @@
package network
import (
- "os"
+ "context"
"testing"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
"gvisor.dev/gvisor/test/benchmarks/harness"
+ "gvisor.dev/gvisor/test/benchmarks/tools"
)
-var h harness.Harness
+// runStaticServer runs static serving workloads (httpd, nginx).
+func runStaticServer(b *testing.B, h harness.Harness, serverOpts dockerutil.RunOpts, serverCmd []string, port int, hey *tools.Hey, reverse bool) {
+ ctx := context.Background()
-// TestMain is the main method for package network.
-func TestMain(m *testing.M) {
- h.Init()
- os.Exit(m.Run())
+ // Get two machines: a client and server.
+ clientMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer clientMachine.CleanUp()
+
+ serverMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer serverMachine.CleanUp()
+
+ // Make the containers. 'reverse=true' specifies that the client should use the
+ // runtime under test.
+ var client, server *dockerutil.Container
+ if reverse {
+ client = clientMachine.GetContainer(ctx, b)
+ server = serverMachine.GetNativeContainer(ctx, b)
+ } else {
+ client = clientMachine.GetNativeContainer(ctx, b)
+ server = serverMachine.GetContainer(ctx, b)
+ }
+ defer client.CleanUp(ctx)
+ defer server.CleanUp(ctx)
+
+ // Start the server.
+ if err := server.Spawn(ctx, serverOpts, serverCmd...); err != nil {
+ b.Fatalf("failed to start server: %v", err)
+ }
+
+ // Get its IP.
+ ip, err := serverMachine.IPAddress()
+ if err != nil {
+ b.Fatalf("failed to find server ip: %v", err)
+ }
+
+ // Get the published port.
+ servingPort, err := server.FindPort(ctx, port)
+ if err != nil {
+ b.Fatalf("failed to find server port %d: %v", port, err)
+ }
+
+ // Make sure the server is serving.
+ harness.WaitUntilServing(ctx, clientMachine, ip, servingPort)
+ b.ResetTimer()
+ server.RestartProfiles()
+ out, err := client.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/hey",
+ }, hey.MakeCmd(ip, servingPort)...)
+ if err != nil {
+ b.Fatalf("run failed with: %v", err)
+ }
+
+ b.StopTimer()
+ hey.Report(b, out)
+ b.StartTimer()
}
diff --git a/test/benchmarks/network/nginx_test.go b/test/benchmarks/network/nginx_test.go
index 08565d0b2..87449612a 100644
--- a/test/benchmarks/network/nginx_test.go
+++ b/test/benchmarks/network/nginx_test.go
@@ -14,13 +14,17 @@
package network
import (
+ "os"
"strconv"
"testing"
"gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/benchmarks/harness"
"gvisor.dev/gvisor/test/benchmarks/tools"
)
+var h harness.Harness
+
// see Dockerfile '//images/benchmarks/nginx'.
var nginxDocs = map[string]string{
"notfound": "notfound",
@@ -44,6 +48,22 @@ func BenchmarkReverseNginxDocSize(b *testing.B) {
benchmarkNginxDocSize(b, true /* reverse */, true /* tmpfs */)
}
+// BenchmarkContinuousNginx runs specific benchmarks for continous jobs.
+// The runtime under test is the sever serving a runc client.
+func BenchmarkContinuousNginx(b *testing.B) {
+ sizes := []string{"10Kb", "100Kb", "1Mb"}
+ threads := []int{1, 25, 100, 1000}
+ benchmarkNginxContinuous(b, threads, sizes, false /*reverse*/)
+}
+
+// BenchmarkContinuousNginxReverse runs specific benchmarks for continous jobs.
+// The runtime under test is the client downloading from a runc server.
+func BenchmarkContinuousNginxReverse(b *testing.B) {
+ sizes := []string{"10Kb", "100Kb", "1Mb"}
+ threads := []int{1, 25, 100, 1000}
+ benchmarkNginxContinuous(b, threads, sizes, true /*reverse*/)
+}
+
// benchmarkNginxDocSize iterates through all doc sizes, running subbenchmarks
// for each size.
func benchmarkNginxDocSize(b *testing.B, reverse, tmpfs bool) {
@@ -72,9 +92,14 @@ func benchmarkNginxDocSize(b *testing.B, reverse, tmpfs bool) {
b.Fatalf("Failed to parse parameters: %v", err)
}
+ requests := b.N
+ if requests < c {
+ b.Logf("b.N is %d must be greater than threads %d. Consider running with --test.benchtime=Nx where N >= %d", b.N, c, c)
+ requests = c
+ }
b.Run(name, func(b *testing.B) {
hey := &tools.Hey{
- Requests: c * b.N,
+ Requests: requests,
Concurrency: c,
Doc: filename,
}
@@ -84,6 +109,47 @@ func benchmarkNginxDocSize(b *testing.B, reverse, tmpfs bool) {
}
}
+// benchmarkNginxContinuous iterates through given sizes and concurrencies on a tmpfs mount.
+func benchmarkNginxContinuous(b *testing.B, concurrency []int, sizes []string, reverse bool) {
+ for _, size := range sizes {
+ filename := nginxDocs[size]
+ for _, c := range concurrency {
+ fsize := tools.Parameter{
+ Name: "filesize",
+ Value: size,
+ }
+
+ threads := tools.Parameter{
+ Name: "concurrency",
+ Value: strconv.Itoa(c),
+ }
+
+ fs := tools.Parameter{
+ Name: "filesystem",
+ Value: "tmpfs",
+ }
+
+ name, err := tools.ParametersToName(fsize, threads, fs)
+ if err != nil {
+ b.Fatalf("Failed to parse parameters: %v", err)
+ }
+ requests := b.N
+ if requests < c {
+ b.Logf("b.N is %d must be greater than threads %d. Consider running with --test.benchtime=Nx where N >= %d", b.N, c, c)
+ requests = c
+ }
+ b.Run(name, func(b *testing.B) {
+ hey := &tools.Hey{
+ Requests: requests,
+ Concurrency: c,
+ Doc: filename,
+ }
+ runNginx(b, hey, reverse, true /*tmpfs*/)
+ })
+ }
+ }
+}
+
// runNginx configures the static serving methods to run httpd.
func runNginx(b *testing.B, hey *tools.Hey, reverse, tmpfs bool) {
// nginx runs on port 80.
@@ -99,5 +165,10 @@ func runNginx(b *testing.B, hey *tools.Hey, reverse, tmpfs bool) {
}
// Command copies nginxDocs to tmpfs serving directory and runs nginx.
- runStaticServer(b, nginxRunOpts, nginxCmd, port, hey, reverse)
+ runStaticServer(b, h, nginxRunOpts, nginxCmd, port, hey, reverse)
+}
+
+func TestMain(m *testing.M) {
+ h.Init()
+ os.Exit(m.Run())
}
diff --git a/test/benchmarks/network/node_test.go b/test/benchmarks/network/node_test.go
index 254538899..3e837a9e4 100644
--- a/test/benchmarks/network/node_test.go
+++ b/test/benchmarks/network/node_test.go
@@ -15,6 +15,7 @@ package network
import (
"context"
+ "os"
"strconv"
"testing"
"time"
@@ -24,6 +25,8 @@ import (
"gvisor.dev/gvisor/test/benchmarks/tools"
)
+var h harness.Harness
+
// BenchmarkNode runs requests using 'hey' against a Node server run on
// 'runtime'. The server responds to requests by grabbing some data in a
// redis instance and returns the data in its reponse. The test loops through
@@ -39,9 +42,14 @@ func BenchmarkNode(b *testing.B) {
if err != nil {
b.Fatalf("Failed to parse parameters: %v", err)
}
+ requests := b.N
+ if requests < c {
+ b.Logf("b.N is %d must be greater than threads %d. Consider running with --test.benchtime=Nx where N >= %d", b.N, c, c)
+ requests = c
+ }
b.Run(name, func(b *testing.B) {
hey := &tools.Hey{
- Requests: b.N * c, // Requests b.N requests per thread.
+ Requests: requests,
Concurrency: c,
}
runNode(b, hey)
@@ -131,5 +139,9 @@ func runNode(b *testing.B, hey *tools.Hey) {
// Stop the timer to parse the data and report stats.
b.StopTimer()
hey.Report(b, out)
- b.StartTimer()
+}
+
+func TestMain(m *testing.M) {
+ h.Init()
+ os.Exit(m.Run())
}
diff --git a/test/benchmarks/network/ruby_test.go b/test/benchmarks/network/ruby_test.go
index 0174ff3f3..c89672873 100644
--- a/test/benchmarks/network/ruby_test.go
+++ b/test/benchmarks/network/ruby_test.go
@@ -16,6 +16,7 @@ package network
import (
"context"
"fmt"
+ "os"
"strconv"
"testing"
"time"
@@ -25,6 +26,8 @@ import (
"gvisor.dev/gvisor/test/benchmarks/tools"
)
+var h harness.Harness
+
// BenchmarkRuby runs requests using 'hey' against a ruby application server.
// On start, ruby app generates some random data and pushes it to a redis
// instance. On a request, the app grabs for random entries from the redis
@@ -40,9 +43,14 @@ func BenchmarkRuby(b *testing.B) {
if err != nil {
b.Fatalf("Failed to parse parameters: %v", err)
}
+ requests := b.N
+ if requests < c {
+ b.Logf("b.N is %d must be greater than threads %d. Consider running with --test.benchtime=Nx where N >= %d", b.N, c, c)
+ requests = c
+ }
b.Run(name, func(b *testing.B) {
hey := &tools.Hey{
- Requests: b.N * c, // b.N requests per thread.
+ Requests: requests,
Concurrency: c,
}
runRuby(b, hey)
@@ -52,7 +60,6 @@ func BenchmarkRuby(b *testing.B) {
// runRuby runs the test for a given # of requests and concurrency.
func runRuby(b *testing.B, hey *tools.Hey) {
- b.Helper()
// The machine to hold Redis and the Ruby Server.
serverMachine, err := h.GetMachine()
if err != nil {
@@ -141,3 +148,8 @@ func runRuby(b *testing.B, hey *tools.Hey) {
hey.Report(b, out)
b.StartTimer()
}
+
+func TestMain(m *testing.M) {
+ h.Init()
+ os.Exit(m.Run())
+}
diff --git a/test/benchmarks/network/static_server.go b/test/benchmarks/network/static_server.go
deleted file mode 100644
index e747a1395..000000000
--- a/test/benchmarks/network/static_server.go
+++ /dev/null
@@ -1,87 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package network
-
-import (
- "context"
- "testing"
-
- "gvisor.dev/gvisor/pkg/test/dockerutil"
- "gvisor.dev/gvisor/test/benchmarks/harness"
- "gvisor.dev/gvisor/test/benchmarks/tools"
-)
-
-// runStaticServer runs static serving workloads (httpd, nginx).
-func runStaticServer(b *testing.B, serverOpts dockerutil.RunOpts, serverCmd []string, port int, hey *tools.Hey, reverse bool) {
- ctx := context.Background()
-
- // Get two machines: a client and server.
- clientMachine, err := h.GetMachine()
- if err != nil {
- b.Fatalf("failed to get machine: %v", err)
- }
- defer clientMachine.CleanUp()
-
- serverMachine, err := h.GetMachine()
- if err != nil {
- b.Fatalf("failed to get machine: %v", err)
- }
- defer serverMachine.CleanUp()
-
- // Make the containers. 'reverse=true' specifies that the client should use the
- // runtime under test.
- var client, server *dockerutil.Container
- if reverse {
- client = clientMachine.GetContainer(ctx, b)
- server = serverMachine.GetNativeContainer(ctx, b)
- } else {
- client = clientMachine.GetNativeContainer(ctx, b)
- server = serverMachine.GetContainer(ctx, b)
- }
- defer client.CleanUp(ctx)
- defer server.CleanUp(ctx)
-
- // Start the server.
- if err := server.Spawn(ctx, serverOpts, serverCmd...); err != nil {
- b.Fatalf("failed to start server: %v", err)
- }
-
- // Get its IP.
- ip, err := serverMachine.IPAddress()
- if err != nil {
- b.Fatalf("failed to find server ip: %v", err)
- }
-
- // Get the published port.
- servingPort, err := server.FindPort(ctx, port)
- if err != nil {
- b.Fatalf("failed to find server port %d: %v", port, err)
- }
-
- // Make sure the server is serving.
- harness.WaitUntilServing(ctx, clientMachine, ip, servingPort)
- b.ResetTimer()
- server.RestartProfiles()
- out, err := client.Run(ctx, dockerutil.RunOpts{
- Image: "benchmarks/hey",
- }, hey.MakeCmd(ip, servingPort)...)
- if err != nil {
- b.Fatalf("run failed with: %v", err)
- }
-
- b.StopTimer()
- hey.Report(b, out)
- b.StartTimer()
-}
diff --git a/test/benchmarks/tools/iperf.go b/test/benchmarks/tools/iperf.go
index 5c4e7125b..891d32704 100644
--- a/test/benchmarks/tools/iperf.go
+++ b/test/benchmarks/tools/iperf.go
@@ -31,7 +31,7 @@ type Iperf struct {
// MakeCmd returns a iperf client command.
func (i *Iperf) MakeCmd(ip net.IP, port int) []string {
// iperf report in Kb realtime
- return strings.Split(fmt.Sprintf("iperf -f K --realtime --time %d -c %s -p %d", i.Time, ip, port), " ")
+ return strings.Split(fmt.Sprintf("iperf -f K --realtime --time %d --client %s --port %d", i.Time, ip, port), " ")
}
// Report parses output from iperf client and reports metrics.
diff --git a/test/benchmarks/tools/redis.go b/test/benchmarks/tools/redis.go
index e35886437..a42e3456e 100644
--- a/test/benchmarks/tools/redis.go
+++ b/test/benchmarks/tools/redis.go
@@ -29,17 +29,17 @@ type Redis struct {
}
// MakeCmd returns a redis-benchmark client command.
-func (r *Redis) MakeCmd(ip net.IP, port int) []string {
+func (r *Redis) MakeCmd(ip net.IP, port, requests int) []string {
// There is no -t PING_BULK for redis-benchmark, so adjust the command in that case.
// Note that "ping" will run both PING_INLINE and PING_BULK.
if r.Operation == "PING_BULK" {
return strings.Split(
- fmt.Sprintf("redis-benchmark --csv -t ping -h %s -p %d", ip, port), " ")
+ fmt.Sprintf("redis-benchmark --csv -t ping -h %s -p %d -n %d", ip, port, requests), " ")
}
// runs redis-benchmark -t operation for 100K requests against server.
return strings.Split(
- fmt.Sprintf("redis-benchmark --csv -t %s -h %s -p %d", r.Operation, ip, port), " ")
+ fmt.Sprintf("redis-benchmark --csv -t %s -h %s -p %d -n %d", r.Operation, ip, port, requests), " ")
}
// Report parses output from redis-benchmark client and reports metrics.
diff --git a/test/cmd/test_app/fds.go b/test/cmd/test_app/fds.go
index a7658eefd..d4354f0d3 100644
--- a/test/cmd/test_app/fds.go
+++ b/test/cmd/test_app/fds.go
@@ -16,6 +16,7 @@ package main
import (
"context"
+ "io"
"io/ioutil"
"log"
"os"
@@ -168,8 +169,8 @@ func (fdr *fdReceiver) Execute(ctx context.Context, f *flag.FlagSet, args ...int
file := os.NewFile(uintptr(fd), "received file")
defer file.Close()
- if _, err := file.Seek(0, os.SEEK_SET); err != nil {
- log.Fatalf("Seek(0, 0) failed: %v", err)
+ if _, err := file.Seek(0, io.SeekStart); err != nil {
+ log.Fatalf("Error from seek(0, 0): %v", err)
}
got, err := ioutil.ReadAll(file)
diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go
index 8425abecb..03bdfa889 100644
--- a/test/e2e/integration_test.go
+++ b/test/e2e/integration_test.go
@@ -494,6 +494,55 @@ func TestLink(t *testing.T) {
}
}
+// This test ensures we can run ping without errors.
+func TestPing4Loopback(t *testing.T) {
+ if testutil.IsRunningWithHostNet() {
+ // TODO(gvisor.dev/issue/5011): support ICMP sockets in hostnet and enable
+ // this test.
+ t.Skip("hostnet only supports TCP/UDP sockets, so ping is not supported.")
+ }
+
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ if got, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/ping4test",
+ }, "/root/ping4.sh"); err != nil {
+ t.Fatalf("docker run failed: %s", err)
+ } else if got != "" {
+ t.Errorf("test failed:\n%s", got)
+ }
+}
+
+// This test ensures we can enable ipv6 on loopback and run ping6 without
+// errors.
+func TestPing6Loopback(t *testing.T) {
+ if testutil.IsRunningWithHostNet() {
+ // TODO(gvisor.dev/issue/5011): support ICMP sockets in hostnet and enable
+ // this test.
+ t.Skip("hostnet only supports TCP/UDP sockets, so ping6 is not supported.")
+ }
+
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ if got, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/ping6test",
+ // The CAP_NET_ADMIN capability is required to use the `ip` utility, which
+ // we use to enable ipv6 on loopback.
+ //
+ // By default, ipv6 loopback is not enabled by runsc, because docker does
+ // not assign an ipv6 address to the test container.
+ CapAdd: []string{"NET_ADMIN"},
+ }, "/root/ping6.sh"); err != nil {
+ t.Fatalf("docker run failed: %s", err)
+ } else if got != "" {
+ t.Errorf("test failed:\n%s", got)
+ }
+}
+
func TestMain(m *testing.M) {
dockerutil.EnsureSupportedDockerVersion()
flag.Parse()
diff --git a/test/fuse/BUILD b/test/fuse/BUILD
index 8e31fdd41..74500ec84 100644
--- a/test/fuse/BUILD
+++ b/test/fuse/BUILD
@@ -71,3 +71,8 @@ syscall_test(
fuse = "True",
test = "//test/fuse/linux:setstat_test",
)
+
+syscall_test(
+ fuse = "True",
+ test = "//test/fuse/linux:mount_test",
+)
diff --git a/test/fuse/linux/BUILD b/test/fuse/linux/BUILD
index 7673252ec..d1fb178e8 100644
--- a/test/fuse/linux/BUILD
+++ b/test/fuse/linux/BUILD
@@ -228,3 +228,15 @@ cc_binary(
"//test/util:test_util",
],
)
+
+cc_binary(
+ name = "mount_test",
+ testonly = 1,
+ srcs = ["mount_test.cc"],
+ deps = [
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
diff --git a/test/fuse/linux/mount_test.cc b/test/fuse/linux/mount_test.cc
new file mode 100644
index 000000000..a5c2fbb01
--- /dev/null
+++ b/test/fuse/linux/mount_test.cc
@@ -0,0 +1,41 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <sys/mount.h>
+
+#include "gtest/gtest.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(FuseMount, FDNotParsable) {
+ int devfd;
+ EXPECT_THAT(devfd = open("/dev/fuse", O_RDWR), SyscallSucceeds());
+ std::string mount_opts = "fd=thiscantbeparsed";
+ TempPath mount_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(mount("fuse", mount_dir.path().c_str(), "fuse",
+ MS_NODEV | MS_NOSUID, mount_opts.c_str()),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go
index d3e5efd4f..f4af45e96 100644
--- a/test/iptables/filter_output.go
+++ b/test/iptables/filter_output.go
@@ -248,7 +248,7 @@ func (FilterOutputOwnerFail) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (FilterOutputOwnerFail) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "-j", "ACCEPT"); err == nil {
- return fmt.Errorf("Invalid argument")
+ return fmt.Errorf("invalid argument")
}
return nil
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
index 834f7615f..4733146c0 100644
--- a/test/iptables/iptables_test.go
+++ b/test/iptables/iptables_test.go
@@ -89,6 +89,10 @@ func iptablesTest(t *testing.T, test TestCase, ipv6 bool) {
// Get the container IP.
ip, err := d.FindIP(ctx, ipv6)
if err != nil {
+ // If ipv6 is not configured, don't fail.
+ if ipv6 && err == dockerutil.ErrNoIP {
+ t.Skipf("No ipv6 address is available.")
+ }
t.Fatalf("failed to get container IP: %v", err)
}
diff --git a/test/packetdrill/BUILD b/test/packetdrill/BUILD
index 49642f282..5d95516ee 100644
--- a/test/packetdrill/BUILD
+++ b/test/packetdrill/BUILD
@@ -38,6 +38,15 @@ packetdrill_test(
scripts = ["tcp_defer_accept_timeout.pkt"],
)
+test_suite(
+ name = "all_tests",
+ tags = [
+ "manual",
+ "packetdrill",
+ ],
+ tests = existing_rules(),
+)
+
bzl_library(
name = "defs_bzl",
srcs = ["defs.bzl"],
diff --git a/test/packetdrill/defs.bzl b/test/packetdrill/defs.bzl
index fc28ce9ba..a6cbcc376 100644
--- a/test/packetdrill/defs.bzl
+++ b/test/packetdrill/defs.bzl
@@ -15,7 +15,7 @@ def _packetdrill_test_impl(ctx):
# Make sure that everything is readable here.
"find . -type f -exec chmod a+rx {} \\;",
"find . -type d -exec chmod a+rx {} \\;",
- "%s %s --init_script %s $@ -- %s\n" % (
+ "%s %s --init_script %s \"$@\" -- %s\n" % (
test_runner.short_path,
" ".join(ctx.attr.flags),
ctx.files._init_script[0].short_path,
@@ -80,9 +80,7 @@ def packetdrill_netstack_test(name, **kwargs):
kwargs["tags"] = PACKETDRILL_TAGS
_packetdrill_test(
name = name,
- # This is the default runtime unless
- # "--test_arg=--runtime=OTHER_RUNTIME" is used to override the value.
- flags = ["--dut_platform", "netstack", "--runtime", "runsc-d"],
+ flags = ["--dut_platform", "netstack"],
**kwargs
)
diff --git a/test/packetdrill/packetdrill_test.sh b/test/packetdrill/packetdrill_test.sh
index 922547d65..d25cad83a 100755
--- a/test/packetdrill/packetdrill_test.sh
+++ b/test/packetdrill/packetdrill_test.sh
@@ -29,7 +29,7 @@ function failure() {
}
trap 'failure ${LINENO} "$BASH_COMMAND"' ERR
-declare -r LONGOPTS="dut_platform:,init_script:,runtime:"
+declare -r LONGOPTS="dut_platform:,init_script:,runtime:,partition:,total_partitions:"
# Don't use declare below so that the error from getopt will end the script.
PARSED=$(getopt --options "" --longoptions=$LONGOPTS --name "$0" -- "$@")
@@ -48,12 +48,17 @@ while true; do
shift 2
;;
--runtime)
- # Not readonly because there might be multiple --runtime arguments and we
- # want to use just the last one. Only used if --dut_platform is
- # "netstack".
declare RUNTIME="$2"
shift 2
;;
+ --partition)
+ # Ignored.
+ shift 2
+ ;;
+ --total_partitions)
+ # Ignored.
+ shift 2
+ ;;
--)
shift
break
diff --git a/test/packetimpact/runner/BUILD b/test/packetimpact/runner/BUILD
index 605dd4972..888c44343 100644
--- a/test/packetimpact/runner/BUILD
+++ b/test/packetimpact/runner/BUILD
@@ -32,6 +32,7 @@ go_library(
deps = [
"//pkg/test/dockerutil",
"//test/packetimpact/netdevs",
+ "//test/packetimpact/testbench",
"@com_github_docker_docker//api/types/mount:go_default_library",
],
)
diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl
index 1038e3c8d..c6c95546a 100644
--- a/test/packetimpact/runner/defs.bzl
+++ b/test/packetimpact/runner/defs.bzl
@@ -12,10 +12,11 @@ def _packetimpact_test_impl(ctx):
# current user, and no other users will be mapped in that namespace.
# Make sure that everything is readable here.
"find . -type f -or -type d -exec chmod a+rx {} \\;",
- "%s %s --testbench_binary %s $@\n" % (
+ "%s %s --testbench_binary %s --num_duts %d $@\n" % (
test_runner.short_path,
" ".join(ctx.attr.flags),
ctx.files.testbench_binary[0].short_path,
+ ctx.attr.num_duts,
),
])
ctx.actions.write(bench, bench_content, is_executable = True)
@@ -51,6 +52,10 @@ _packetimpact_test = rule(
mandatory = False,
default = [],
),
+ "num_duts": attr.int(
+ mandatory = False,
+ default = 1,
+ ),
},
test = True,
implementation = _packetimpact_test_impl,
@@ -110,24 +115,27 @@ def packetimpact_netstack_test(
**kwargs
)
-def packetimpact_go_test(name, expect_native_failure = False, expect_netstack_failure = False):
+def packetimpact_go_test(name, expect_native_failure = False, expect_netstack_failure = False, num_duts = 1):
"""Add packetimpact tests written in go.
Args:
name: name of the test
expect_native_failure: the test must fail natively
expect_netstack_failure: the test must fail for Netstack
+ num_duts: how many DUTs are needed for the test
"""
testbench_binary = name + "_test"
packetimpact_native_test(
name = name,
expect_failure = expect_native_failure,
testbench_binary = testbench_binary,
+ num_duts = num_duts,
)
packetimpact_netstack_test(
name = name,
expect_failure = expect_netstack_failure,
testbench_binary = testbench_binary,
+ num_duts = num_duts,
)
def packetimpact_testbench(name, size = "small", pure = True, **kwargs):
@@ -153,7 +161,7 @@ def packetimpact_testbench(name, size = "small", pure = True, **kwargs):
PacketimpactTestInfo = provider(
doc = "Provide information for packetimpact tests",
- fields = ["name", "expect_netstack_failure"],
+ fields = ["name", "expect_netstack_failure", "num_duts"],
)
ALL_TESTS = [
@@ -216,6 +224,9 @@ ALL_TESTS = [
name = "tcp_user_timeout",
),
PacketimpactTestInfo(
+ name = "tcp_zero_receive_window",
+ ),
+ PacketimpactTestInfo(
name = "tcp_queue_receive_in_syn_sent",
),
PacketimpactTestInfo(
@@ -243,13 +254,9 @@ ALL_TESTS = [
),
PacketimpactTestInfo(
name = "icmpv6_param_problem",
- # TODO(b/153485026): Fix netstack then remove the line below.
- expect_netstack_failure = True,
),
PacketimpactTestInfo(
name = "ipv6_unknown_options_action",
- # TODO(b/159928940): Fix netstack then remove the line below.
- expect_netstack_failure = True,
),
PacketimpactTestInfo(
name = "ipv4_fragment_reassembly",
@@ -259,6 +266,7 @@ ALL_TESTS = [
),
PacketimpactTestInfo(
name = "ipv6_fragment_icmp_error",
+ num_duts = 3,
),
PacketimpactTestInfo(
name = "udp_send_recv_dgram",
diff --git a/test/packetimpact/runner/dut.go b/test/packetimpact/runner/dut.go
index 59bb68eb1..3e26c73cb 100644
--- a/test/packetimpact/runner/dut.go
+++ b/test/packetimpact/runner/dut.go
@@ -17,6 +17,7 @@ package runner
import (
"context"
+ "encoding/json"
"flag"
"fmt"
"io/ioutil"
@@ -34,6 +35,7 @@ import (
"github.com/docker/docker/api/types/mount"
"gvisor.dev/gvisor/pkg/test/dockerutil"
"gvisor.dev/gvisor/test/packetimpact/netdevs"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
// stringList implements flag.Value.
@@ -56,9 +58,10 @@ var (
tshark = false
extraTestArgs = stringList{}
expectFailure = false
+ numDUTs = 1
- // DutAddr is the IP addres for DUT.
- DutAddr = net.IPv4(0, 0, 0, 10)
+ // DUTAddr is the IP addres for DUT.
+ DUTAddr = net.IPv4(0, 0, 0, 10)
testbenchAddr = net.IPv4(0, 0, 0, 20)
)
@@ -71,10 +74,15 @@ func RegisterFlags(fs *flag.FlagSet) {
fs.BoolVar(&tshark, "tshark", false, "use more verbose tshark in logs instead of tcpdump")
fs.Var(&extraTestArgs, "extra_test_arg", "extra arguments to pass to the testbench")
fs.BoolVar(&expectFailure, "expect_failure", false, "expect that the test will fail when run")
+ fs.IntVar(&numDUTs, "num_duts", numDUTs, "the number of duts to create")
}
-// CtrlPort is the port that posix_server listens on.
-const CtrlPort = "40000"
+const (
+ // CtrlPort is the port that posix_server listens on.
+ CtrlPort uint16 = 40000
+ // testOutputDir is the directory in each container that holds test output.
+ testOutputDir = "/tmp/testoutput"
+)
// logger implements testutil.Logger.
//
@@ -95,16 +103,21 @@ func (l logger) Logf(format string, args ...interface{}) {
}
}
-// TestWithDUT runs a packetimpact test with the given information.
-func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Container) DUT, containerAddr net.IP) {
- if testbenchBinary == "" {
- t.Fatal("--testbench_binary is missing")
- }
- dockerutil.EnsureSupportedDockerVersion()
+// dutInfo encapsulates all the essential information to set up testbench
+// container.
+type dutInfo struct {
+ dut DUT
+ ctrlNet, testNet *dockerutil.Network
+ netInfo *testbench.DUTTestNet
+}
- // Create the networks needed for the test. One control network is needed for
- // the gRPC control packets and one test network on which to transmit the test
- // packets.
+// setUpDUT will set up one DUT and return information for setting up the
+// container for testbench.
+func setUpDUT(ctx context.Context, t *testing.T, id int, mkDevice func(*dockerutil.Container) DUT) (dutInfo, error) {
+ // Create the networks needed for the test. One control network is needed
+ // for the gRPC control packets and one test network on which to transmit
+ // the test packets.
+ var info dutInfo
ctrlNet := dockerutil.NewNetwork(ctx, logger("ctrlNet"))
testNet := dockerutil.NewNetwork(ctx, logger("testNet"))
for _, dn := range []*dockerutil.Network{ctrlNet, testNet} {
@@ -113,8 +126,8 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co
t.Log("creating docker network:", err)
const wait = 100 * time.Millisecond
t.Logf("sleeping %s and will try creating docker network again", wait)
- // This can fail if another docker network claimed the same IP so we'll
- // just try again.
+ // This can fail if another docker network claimed the same IP so we
+ // will just try again.
time.Sleep(wait)
continue
}
@@ -128,115 +141,204 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co
})
// Sanity check.
if inspect, err := dn.Inspect(ctx); err != nil {
- t.Fatalf("failed to inspect network %s: %v", dn.Name, err)
+ return dutInfo{}, fmt.Errorf("failed to inspect network %s: %w", dn.Name, err)
} else if inspect.Name != dn.Name {
- t.Fatalf("name mismatch for network want: %s got: %s", dn.Name, inspect.Name)
+ return dutInfo{}, fmt.Errorf("name mismatch for network want: %s got: %s", dn.Name, inspect.Name)
}
}
-
- tmpDir, err := ioutil.TempDir("", "container-output")
- if err != nil {
- t.Fatal("creating temp dir:", err)
- }
- t.Cleanup(func() {
- if err := exec.Command("/bin/cp", "-r", tmpDir, os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR")).Run(); err != nil {
- t.Errorf("unable to copy container output files: %s", err)
- }
- if err := os.RemoveAll(tmpDir); err != nil {
- t.Errorf("failed to remove tmpDir %s: %s", tmpDir, err)
- }
- })
-
- const testOutputDir = "/tmp/testoutput"
+ info.ctrlNet = ctrlNet
+ info.testNet = testNet
// Create the Docker container for the DUT.
- var dut *dockerutil.Container
+ var dut DUT
if native {
- dut = dockerutil.MakeNativeContainer(ctx, logger("dut"))
+ dut = mkDevice(dockerutil.MakeNativeContainer(ctx, logger(fmt.Sprintf("dut-%d", id))))
} else {
- dut = dockerutil.MakeContainer(ctx, logger("dut"))
+ dut = mkDevice(dockerutil.MakeContainer(ctx, logger(fmt.Sprintf("dut-%d", id))))
}
- t.Cleanup(func() {
- dut.CleanUp(ctx)
- })
+ info.dut = dut
runOpts := dockerutil.RunOpts{
Image: "packetimpact",
CapAdd: []string{"NET_ADMIN"},
- Mounts: []mount.Mount{{
- Type: mount.TypeBind,
- Source: tmpDir,
- Target: testOutputDir,
- ReadOnly: false,
- }},
+ }
+ if _, err := MountTempDirectory(t, &runOpts, "dut-output", testOutputDir); err != nil {
+ return dutInfo{}, err
}
- device := mkDevice(dut)
- remoteIPv6, remoteMAC, dutDeviceID, dutTestNetDev := device.Prepare(ctx, t, runOpts, ctrlNet, testNet, containerAddr)
+ ipv4PrefixLength, _ := testNet.Subnet.Mask.Size()
+ remoteIPv6, remoteMAC, dutDeviceID, dutTestNetDev, err := dut.Prepare(ctx, t, runOpts, ctrlNet, testNet)
+ if err != nil {
+ return dutInfo{}, err
+ }
+ info.netInfo = &testbench.DUTTestNet{
+ RemoteMAC: remoteMAC,
+ RemoteIPv4: AddressInSubnet(DUTAddr, *testNet.Subnet),
+ RemoteIPv6: remoteIPv6,
+ RemoteDevID: dutDeviceID,
+ RemoteDevName: dutTestNetDev,
+ LocalIPv4: AddressInSubnet(testbenchAddr, *testNet.Subnet),
+ IPv4PrefixLength: ipv4PrefixLength,
+ POSIXServerIP: AddressInSubnet(DUTAddr, *ctrlNet.Subnet),
+ POSIXServerPort: CtrlPort,
+ }
+ return info, nil
+}
- // Create the Docker container for the testbench.
- testbench := dockerutil.MakeNativeContainer(ctx, logger("testbench"))
+// TestWithDUT runs a packetimpact test with the given information.
+func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Container) DUT) {
+ if testbenchBinary == "" {
+ t.Fatal("--testbench_binary is missing")
+ }
+ dockerutil.EnsureSupportedDockerVersion()
- tbb := path.Base(testbenchBinary)
- containerTestbenchBinary := filepath.Join("/packetimpact", tbb)
- testbench.CopyFiles(&runOpts, "/packetimpact", filepath.Join("test/packetimpact/tests", tbb))
-
- // snifferNetDev is a network device on the test orchestrator that we will
- // run sniffer (tcpdump or tshark) on and inject traffic to, not to be
- // confused with the device on the DUT.
- const snifferNetDev = "eth2"
- // Run tcpdump in the test bench unbuffered, without DNS resolution, just on
- // the interface with the test packets.
- snifferArgs := []string{
- "tcpdump",
- "-S", "-vvv", "-U", "-n",
- "-i", snifferNetDev,
- "-w", testOutputDir + "/dump.pcap",
+ dutInfoChan := make(chan dutInfo, numDUTs)
+ errChan := make(chan error, numDUTs)
+ var dockerNetworks []*dockerutil.Network
+ var dutTestNets []*testbench.DUTTestNet
+ var duts []DUT
+
+ setUpCtx, cancelSetup := context.WithCancel(ctx)
+ t.Cleanup(cancelSetup)
+ for i := 0; i < numDUTs; i++ {
+ go func(i int) {
+ info, err := setUpDUT(setUpCtx, t, i, mkDevice)
+ if err != nil {
+ errChan <- err
+ } else {
+ dutInfoChan <- info
+ }
+ }(i)
}
- snifferRegex := "tcpdump: listening.*\n"
- if tshark {
- // Run tshark in the test bench unbuffered, without DNS resolution, just on
- // the interface with the test packets.
- snifferArgs = []string{
- "tshark", "-V", "-l", "-n", "-i", snifferNetDev,
- "-o", "tcp.check_checksum:TRUE",
- "-o", "udp.check_checksum:TRUE",
+ for i := 0; i < numDUTs; i++ {
+ select {
+ case info := <-dutInfoChan:
+ dockerNetworks = append(dockerNetworks, info.ctrlNet, info.testNet)
+ dutTestNets = append(dutTestNets, info.netInfo)
+ duts = append(duts, info.dut)
+ case err := <-errChan:
+ t.Fatal(err)
}
- snifferRegex = "Capturing on.*\n"
}
+ // Create the Docker container for the testbench.
+ testbenchContainer := dockerutil.MakeNativeContainer(ctx, logger("testbench"))
+
+ runOpts := dockerutil.RunOpts{
+ Image: "packetimpact",
+ CapAdd: []string{"NET_ADMIN"},
+ }
+ if _, err := MountTempDirectory(t, &runOpts, "testbench-output", testOutputDir); err != nil {
+ t.Fatal(err)
+ }
+ tbb := path.Base(testbenchBinary)
+ containerTestbenchBinary := filepath.Join("/packetimpact", tbb)
+ testbenchContainer.CopyFiles(&runOpts, "/packetimpact", filepath.Join("test/packetimpact/tests", tbb))
+
if err := StartContainer(
ctx,
runOpts,
- testbench,
+ testbenchContainer,
testbenchAddr,
- []*dockerutil.Network{ctrlNet, testNet},
- snifferArgs...,
+ dockerNetworks,
+ "tail", "-f", "/dev/null",
); err != nil {
- t.Fatalf("failed to start docker container for testbench sniffer: %s", err)
+ t.Fatalf("cannot start testbench container: %s", err)
}
- // Kill so that it will flush output.
- t.Cleanup(func() {
- time.Sleep(1 * time.Second)
- testbench.Exec(ctx, dockerutil.ExecOpts{}, "killall", snifferArgs[0])
- })
- if _, err := testbench.WaitForOutput(ctx, snifferRegex, 60*time.Second); err != nil {
- t.Fatalf("sniffer on %s never listened: %s", dut.Name, err)
+ for i := range dutTestNets {
+ name, info, err := deviceByIP(ctx, testbenchContainer, dutTestNets[i].LocalIPv4)
+ if err != nil {
+ t.Fatalf("failed to get the device name associated with %s: %s", dutTestNets[i].LocalIPv4, err)
+ }
+ dutTestNets[i].LocalDevName = name
+ dutTestNets[i].LocalDevID = info.ID
+ dutTestNets[i].LocalMAC = info.MAC
+ localIPv6, err := getOrAssignIPv6Addr(ctx, testbenchContainer, name)
+ if err != nil {
+ t.Fatalf("failed to get IPV6 address on %s: %s", testbenchContainer.Name, err)
+ }
+ dutTestNets[i].LocalIPv6 = localIPv6
+ }
+ dutTestNetsBytes, err := json.Marshal(dutTestNets)
+ if err != nil {
+ t.Fatalf("failed to marshal %v into json: %s", dutTestNets, err)
}
- // When the Linux kernel receives a SYN-ACK for a SYN it didn't send, it
- // will respond with an RST. In most packetimpact tests, the SYN is sent
- // by the raw socket and the kernel knows nothing about the connection, this
- // behavior will break lots of TCP related packetimpact tests. To prevent
- // this, we can install the following iptables rules. The raw socket that
- // packetimpact tests use will still be able to see everything.
- for _, bin := range []string{"iptables", "ip6tables"} {
- if logs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, bin, "-A", "INPUT", "-i", snifferNetDev, "-p", "tcp", "-j", "DROP"); err != nil {
- t.Fatalf("unable to Exec %s on container %s: %s, logs from testbench:\n%s", bin, testbench.Name, err, logs)
+ baseSnifferArgs := []string{
+ "tcpdump",
+ "-vvv",
+ "--absolute-tcp-sequence-numbers",
+ "--packet-buffered",
+ // Disable DNS resolution.
+ "-n",
+ // run tcpdump as root since the output directory is owned by root. From
+ // `man tcpdump`:
+ //
+ // -Z user
+ // --relinquish-privileges=user
+ // If tcpdump is running as root, after opening the capture device
+ // or input savefile, change the user ID to user and the group ID to
+ // the primary group of user.
+ // This behavior is enabled by default (-Z tcpdump), and can be
+ // disabled by -Z root.
+ "-Z", "root",
+ }
+ if tshark {
+ baseSnifferArgs = []string{
+ "tshark",
+ "-V",
+ "-o", "tcp.check_checksum:TRUE",
+ "-o", "udp.check_checksum:TRUE",
+ // Disable buffering.
+ "-l",
+ // Disable DNS resolution.
+ "-n",
+ }
+ }
+ for _, n := range dutTestNets {
+ snifferArgs := append(baseSnifferArgs, "-i", n.LocalDevName)
+ if !tshark {
+ snifferArgs = append(
+ snifferArgs,
+ "-w",
+ filepath.Join(testOutputDir, fmt.Sprintf("%s.pcap", n.LocalDevName)),
+ )
+ }
+ p, err := testbenchContainer.ExecProcess(ctx, dockerutil.ExecOpts{}, snifferArgs...)
+ if err != nil {
+ t.Fatalf("failed to start exec a sniffer on %s: %s", n.LocalDevName, err)
+ }
+ t.Cleanup(func() {
+ if snifferOut, err := p.Logs(); err != nil {
+ t.Errorf("sniffer logs failed: %s\n%s", err, snifferOut)
+ } else {
+ t.Logf("sniffer logs:\n%s", snifferOut)
+ }
+ })
+ // When the Linux kernel receives a SYN-ACK for a SYN it didn't send, it
+ // will respond with an RST. In most packetimpact tests, the SYN is sent
+ // by the raw socket, the kernel knows nothing about the connection, this
+ // behavior will break lots of TCP related packetimpact tests. To prevent
+ // this, we can install the following iptables rules. The raw socket that
+ // packetimpact tests use will still be able to see everything.
+ for _, bin := range []string{"iptables", "ip6tables"} {
+ if logs, err := testbenchContainer.Exec(ctx, dockerutil.ExecOpts{}, bin, "-A", "INPUT", "-i", n.LocalDevName, "-p", "tcp", "-j", "DROP"); err != nil {
+ t.Fatalf("unable to Exec %s on container %s: %s, logs from testbench:\n%s", bin, testbenchContainer.Name, err, logs)
+ }
}
}
+ t.Cleanup(func() {
+ // Wait 1 second before killing tcpdump to give it time to flush
+ // any packets. On linux tests killing it immediately can
+ // sometimes result in partial pcaps.
+ time.Sleep(1 * time.Second)
+ if logs, err := testbenchContainer.Exec(ctx, dockerutil.ExecOpts{}, "killall", baseSnifferArgs[0]); err != nil {
+ t.Errorf("failed to kill all sniffers: %s, logs: %s", err, logs)
+ }
+ })
+
// FIXME(b/156449515): Some piece of the system has a race. The old
// bash script version had a sleep, so we have one too. The race should
// be fixed and this sleep removed.
@@ -248,31 +350,29 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co
testArgs := []string{containerTestbenchBinary}
testArgs = append(testArgs, extraTestArgs...)
testArgs = append(testArgs,
- "--posix_server_ip", AddressInSubnet(DutAddr, *ctrlNet.Subnet).String(),
- "--posix_server_port", CtrlPort,
- "--remote_ipv4", AddressInSubnet(DutAddr, *testNet.Subnet).String(),
- "--local_ipv4", AddressInSubnet(testbenchAddr, *testNet.Subnet).String(),
- "--remote_ipv6", remoteIPv6.String(),
- "--remote_mac", remoteMAC.String(),
- "--remote_interface_id", fmt.Sprintf("%d", dutDeviceID),
- "--local_device", snifferNetDev,
- "--remote_device", dutTestNetDev,
fmt.Sprintf("--native=%t", native),
+ "--dut_test_nets_json", string(dutTestNetsBytes),
)
- testbenchLogs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, testArgs...)
+ testbenchLogs, err := testbenchContainer.Exec(ctx, dockerutil.ExecOpts{}, testArgs...)
if (err != nil) != expectFailure {
var dutLogs string
- if logs, err := device.Logs(ctx); err != nil {
- dutLogs = fmt.Sprintf("failed to fetch DUT logs: %s", err)
- } else {
- dutLogs = logs
+ for i, dut := range duts {
+ logs, err := dut.Logs(ctx)
+ if err != nil {
+ logs = fmt.Sprintf("failed to fetch DUT logs: %s", err)
+ }
+ dutLogs = fmt.Sprintf(`%s====== Begin of DUT-%d Logs ======
+
+%s
+
+====== End of DUT-%d Logs ======
+
+`, dutLogs, i, logs, i)
}
t.Errorf(`test error: %v, expect failure: %t
-%s
-
-====== Begin of Testbench Logs ======
+%s====== Begin of Testbench Logs ======
%s
@@ -285,7 +385,9 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co
type DUT interface {
// Prepare prepares the dut, starts posix_server and returns the IPv6, MAC
// address, the interface ID, and the interface name for the testNet on DUT.
- Prepare(ctx context.Context, t *testing.T, runOpts dockerutil.RunOpts, ctrlNet, testNet *dockerutil.Network, containerAddr net.IP) (net.IP, net.HardwareAddr, uint32, string)
+ // The t parameter is supposed to be used for t.Cleanup. Don't use it for
+ // t.Fatal/FailNow functions.
+ Prepare(ctx context.Context, t *testing.T, runOpts dockerutil.RunOpts, ctrlNet, testNet *dockerutil.Network) (net.IP, net.HardwareAddr, uint32, string, error)
// Logs retrieves the logs from the dut.
Logs(ctx context.Context) (string, error)
}
@@ -303,7 +405,7 @@ func NewDockerDUT(c *dockerutil.Container) DUT {
}
// Prepare implements DUT.Prepare.
-func (dut *DockerDUT) Prepare(ctx context.Context, t *testing.T, runOpts dockerutil.RunOpts, ctrlNet, testNet *dockerutil.Network, containerAddr net.IP) (net.IP, net.HardwareAddr, uint32, string) {
+func (dut *DockerDUT) Prepare(ctx context.Context, _ *testing.T, runOpts dockerutil.RunOpts, ctrlNet, testNet *dockerutil.Network) (net.IP, net.HardwareAddr, uint32, string, error) {
const containerPosixServerBinary = "/packetimpact/posix_server"
dut.c.CopyFiles(&runOpts, "/packetimpact", "test/packetimpact/dut/posix_server")
@@ -311,45 +413,31 @@ func (dut *DockerDUT) Prepare(ctx context.Context, t *testing.T, runOpts dockeru
ctx,
runOpts,
dut.c,
- containerAddr,
+ DUTAddr,
[]*dockerutil.Network{ctrlNet, testNet},
containerPosixServerBinary,
"--ip=0.0.0.0",
- "--port="+CtrlPort,
+ fmt.Sprintf("--port=%d", CtrlPort),
); err != nil {
- t.Fatalf("failed to start docker container for DUT: %s", err)
+ return nil, nil, 0, "", fmt.Errorf("failed to start docker container for DUT: %w", err)
}
if _, err := dut.c.WaitForOutput(ctx, "Server listening.*\n", 60*time.Second); err != nil {
- t.Fatalf("%s on container %s never listened: %s", containerPosixServerBinary, dut.c.Name, err)
+ return nil, nil, 0, "", fmt.Errorf("%s on container %s never listened: %s", containerPosixServerBinary, dut.c.Name, err)
}
- dutTestDevice, dutDeviceInfo, err := deviceByIP(ctx, dut.c, AddressInSubnet(containerAddr, *testNet.Subnet))
+ dutTestDevice, dutDeviceInfo, err := deviceByIP(ctx, dut.c, AddressInSubnet(DUTAddr, *testNet.Subnet))
if err != nil {
- t.Fatal(err)
+ return nil, nil, 0, "", err
}
- remoteMAC := dutDeviceInfo.MAC
- remoteIPv6 := dutDeviceInfo.IPv6Addr
- // Netstack as DUT doesn't assign IPv6 addresses automatically so do it if
- // needed.
- if remoteIPv6 == nil {
- if _, err := dut.c.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "add", netdevs.MACToIP(remoteMAC).String(), "scope", "link", "dev", dutTestDevice); err != nil {
- t.Fatalf("unable to ip addr add on container %s: %s", dut.c.Name, err)
- }
- // Now try again, to make sure that it worked.
- _, dutDeviceInfo, err = deviceByIP(ctx, dut.c, AddressInSubnet(containerAddr, *testNet.Subnet))
- if err != nil {
- t.Fatal(err)
- }
- remoteIPv6 = dutDeviceInfo.IPv6Addr
- if remoteIPv6 == nil {
- t.Fatalf("unable to set IPv6 address on container %s", dut.c.Name)
- }
+ remoteIPv6, err := getOrAssignIPv6Addr(ctx, dut.c, dutTestDevice)
+ if err != nil {
+ return nil, nil, 0, "", fmt.Errorf("failed to get IPv6 address on %s: %s", dut.c.Name, err)
}
const testNetDev = "eth2"
- return remoteIPv6, dutDeviceInfo.MAC, dutDeviceInfo.ID, testNetDev
+ return remoteIPv6, dutDeviceInfo.MAC, dutDeviceInfo.ID, testNetDev, nil
}
// Logs implements DUT.Logs.
@@ -358,11 +446,7 @@ func (dut *DockerDUT) Logs(ctx context.Context) (string, error) {
if err != nil {
return "", err
}
- return fmt.Sprintf(`====== Begin of DUT Logs ======
-
-%s
-
-====== End of DUT Logs ======`, logs), nil
+ return logs, nil
}
// AddNetworks connects docker network with the container and assigns the specific IP.
@@ -378,25 +462,35 @@ func AddNetworks(ctx context.Context, d *dockerutil.Container, addr net.IP, netw
}
// AddressInSubnet combines the subnet provided with the address and returns a
-// new address. The return address bits come from the subnet where the mask is 1
-// and from the ip address where the mask is 0.
+// new address. The return address bits come from the subnet where the mask is
+// 1 and from the ip address where the mask is 0.
func AddressInSubnet(addr net.IP, subnet net.IPNet) net.IP {
- var octets []byte
+ var octets net.IP
for i := 0; i < 4; i++ {
octets = append(octets, (subnet.IP.To4()[i]&subnet.Mask[i])+(addr.To4()[i]&(^subnet.Mask[i])))
}
- return net.IP(octets)
+ return octets
}
-// deviceByIP finds a deviceInfo and device name from an IP address.
-func deviceByIP(ctx context.Context, d *dockerutil.Container, ip net.IP) (string, netdevs.DeviceInfo, error) {
+// devicesInfo will run "ip addr show" on the container and parse the output
+// to a map[string]netdevs.DeviceInfo.
+func devicesInfo(ctx context.Context, d *dockerutil.Container) (map[string]netdevs.DeviceInfo, error) {
out, err := d.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "show")
if err != nil {
- return "", netdevs.DeviceInfo{}, fmt.Errorf("listing devices on %s container: %w\n%s", d.Name, err, out)
+ return map[string]netdevs.DeviceInfo{}, fmt.Errorf("listing devices on %s container: %w\n%s", d.Name, err, out)
}
devs, err := netdevs.ParseDevices(out)
if err != nil {
- return "", netdevs.DeviceInfo{}, fmt.Errorf("parsing devices from %s container: %w\n%s", d.Name, err, out)
+ return map[string]netdevs.DeviceInfo{}, fmt.Errorf("parsing devices from %s container: %w\n%s", d.Name, err, out)
+ }
+ return devs, nil
+}
+
+// deviceByIP finds a deviceInfo and device name from an IP address.
+func deviceByIP(ctx context.Context, d *dockerutil.Container, ip net.IP) (string, netdevs.DeviceInfo, error) {
+ devs, err := devicesInfo(ctx, d)
+ if err != nil {
+ return "", netdevs.DeviceInfo{}, err
}
testDevice, deviceInfo, err := netdevs.FindDeviceByIP(ip, devs)
if err != nil {
@@ -405,6 +499,36 @@ func deviceByIP(ctx context.Context, d *dockerutil.Container, ip net.IP) (string
return testDevice, deviceInfo, nil
}
+// getOrAssignIPv6Addr will try to get the IPv6 address for the interface; if an
+// address was not assigned, a link-local address based on MAC will be assigned
+// to that interface.
+func getOrAssignIPv6Addr(ctx context.Context, d *dockerutil.Container, iface string) (net.IP, error) {
+ devs, err := devicesInfo(ctx, d)
+ if err != nil {
+ return net.IP{}, err
+ }
+ info := devs[iface]
+ if info.IPv6Addr != nil {
+ return info.IPv6Addr, nil
+ }
+ if info.MAC == nil {
+ return nil, fmt.Errorf("unable to find MAC address of %s", iface)
+ }
+ if logs, err := d.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "add", netdevs.MACToIP(info.MAC).String(), "scope", "link", "dev", iface); err != nil {
+ return net.IP{}, fmt.Errorf("unable to ip addr add on container %s: %w, logs: %s", d.Name, err, logs)
+ }
+ // Now try again, to make sure that it worked.
+ devs, err = devicesInfo(ctx, d)
+ if err != nil {
+ return net.IP{}, err
+ }
+ info = devs[iface]
+ if info.IPv6Addr == nil {
+ return net.IP{}, fmt.Errorf("unable to set IPv6 address on container %s", d.Name)
+ }
+ return info.IPv6Addr, nil
+}
+
// createDockerNetwork makes a randomly-named network that will start with the
// namePrefix. The network will be a random /24 subnet.
func createDockerNetwork(ctx context.Context, n *dockerutil.Network) error {
@@ -440,3 +564,30 @@ func StartContainer(ctx context.Context, runOpts dockerutil.RunOpts, c *dockerut
}
return nil
}
+
+// MountTempDirectory creates a temporary directory on host with the template
+// and then mounts it into the container under the name provided. The temporary
+// directory name is returned. Content in that directory will be copied to
+// TEST_UNDECLARED_OUTPUTS_DIR in cleanup phase.
+func MountTempDirectory(t *testing.T, runOpts *dockerutil.RunOpts, hostDirTemplate, containerDir string) (string, error) {
+ t.Helper()
+ tmpDir, err := ioutil.TempDir("", hostDirTemplate)
+ if err != nil {
+ return "", fmt.Errorf("failed to create a temp dir: %w", err)
+ }
+ t.Cleanup(func() {
+ if err := exec.Command("/bin/cp", "-r", tmpDir, os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR")).Run(); err != nil {
+ t.Errorf("unable to copy container output files: %s", err)
+ }
+ if err := os.RemoveAll(tmpDir); err != nil {
+ t.Errorf("failed to remove tmpDir %s: %s", tmpDir, err)
+ }
+ })
+ runOpts.Mounts = append(runOpts.Mounts, mount.Mount{
+ Type: mount.TypeBind,
+ Source: tmpDir,
+ Target: containerDir,
+ ReadOnly: false,
+ })
+ return tmpDir, nil
+}
diff --git a/test/packetimpact/runner/packetimpact_test.go b/test/packetimpact/runner/packetimpact_test.go
index c598bfc29..46334b7ab 100644
--- a/test/packetimpact/runner/packetimpact_test.go
+++ b/test/packetimpact/runner/packetimpact_test.go
@@ -28,5 +28,5 @@ func init() {
}
func TestOne(t *testing.T) {
- runner.TestWithDUT(context.Background(), t, runner.NewDockerDUT, runner.DutAddr)
+ runner.TestWithDUT(context.Background(), t, runner.NewDockerDUT)
}
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD
index 5a0ee1367..983c2c030 100644
--- a/test/packetimpact/testbench/BUILD
+++ b/test/packetimpact/testbench/BUILD
@@ -21,7 +21,6 @@ go_library(
"//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
"//pkg/usermem",
- "//test/packetimpact/netdevs",
"//test/packetimpact/proto:posix_server_go_proto",
"@com_github_google_go_cmp//cmp:go_default_library",
"@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index 919b4fd25..576577310 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -17,7 +17,6 @@ package testbench
import (
"fmt"
"math/rand"
- "net"
"testing"
"time"
@@ -42,7 +41,7 @@ func portFromSockaddr(sa unix.Sockaddr) (uint16, error) {
// pickPort makes a new socket and returns the socket FD and port. The domain
// should be AF_INET or AF_INET6. The caller must close the FD when done with
// the port if there is no error.
-func pickPort(domain, typ int) (fd int, port uint16, err error) {
+func (n *DUTTestNet) pickPort(domain, typ int) (fd int, port uint16, err error) {
fd, err = unix.Socket(domain, typ, 0)
if err != nil {
return -1, 0, fmt.Errorf("creating socket: %w", err)
@@ -58,11 +57,11 @@ func pickPort(domain, typ int) (fd int, port uint16, err error) {
switch domain {
case unix.AF_INET:
var sa4 unix.SockaddrInet4
- copy(sa4.Addr[:], net.ParseIP(LocalIPv4).To4())
+ copy(sa4.Addr[:], n.LocalIPv4)
sa = &sa4
case unix.AF_INET6:
- sa6 := unix.SockaddrInet6{ZoneId: uint32(LocalInterfaceID)}
- copy(sa6.Addr[:], net.ParseIP(LocalIPv6).To16())
+ sa6 := unix.SockaddrInet6{ZoneId: n.LocalDevID}
+ copy(sa6.Addr[:], n.LocalIPv6)
sa = &sa6
default:
return -1, 0, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain)
@@ -117,19 +116,12 @@ type etherState struct {
var _ layerState = (*etherState)(nil)
// newEtherState creates a new etherState.
-func newEtherState(out, in Ether) (*etherState, error) {
- lMAC, err := tcpip.ParseMACAddress(LocalMAC)
- if err != nil {
- return nil, fmt.Errorf("parsing local MAC: %q: %w", LocalMAC, err)
- }
-
- rMAC, err := tcpip.ParseMACAddress(RemoteMAC)
- if err != nil {
- return nil, fmt.Errorf("parsing remote MAC: %q: %w", RemoteMAC, err)
- }
+func (n *DUTTestNet) newEtherState(out, in Ether) (*etherState, error) {
+ lmac := tcpip.LinkAddress(n.LocalMAC)
+ rmac := tcpip.LinkAddress(n.RemoteMAC)
s := etherState{
- out: Ether{SrcAddr: &lMAC, DstAddr: &rMAC},
- in: Ether{SrcAddr: &rMAC, DstAddr: &lMAC},
+ out: Ether{SrcAddr: &lmac, DstAddr: &rmac},
+ in: Ether{SrcAddr: &rmac, DstAddr: &lmac},
}
if err := s.out.merge(&out); err != nil {
return nil, err
@@ -169,9 +161,9 @@ type ipv4State struct {
var _ layerState = (*ipv4State)(nil)
// newIPv4State creates a new ipv4State.
-func newIPv4State(out, in IPv4) (*ipv4State, error) {
- lIP := tcpip.Address(net.ParseIP(LocalIPv4).To4())
- rIP := tcpip.Address(net.ParseIP(RemoteIPv4).To4())
+func (n *DUTTestNet) newIPv4State(out, in IPv4) (*ipv4State, error) {
+ lIP := tcpip.Address(n.LocalIPv4)
+ rIP := tcpip.Address(n.RemoteIPv4)
s := ipv4State{
out: IPv4{SrcAddr: &lIP, DstAddr: &rIP},
in: IPv4{SrcAddr: &rIP, DstAddr: &lIP},
@@ -214,9 +206,9 @@ type ipv6State struct {
var _ layerState = (*ipv6State)(nil)
// newIPv6State creates a new ipv6State.
-func newIPv6State(out, in IPv6) (*ipv6State, error) {
- lIP := tcpip.Address(net.ParseIP(LocalIPv6).To16())
- rIP := tcpip.Address(net.ParseIP(RemoteIPv6).To16())
+func (n *DUTTestNet) newIPv6State(out, in IPv6) (*ipv6State, error) {
+ lIP := tcpip.Address(n.LocalIPv6)
+ rIP := tcpip.Address(n.RemoteIPv6)
s := ipv6State{
out: IPv6{SrcAddr: &lIP, DstAddr: &rIP},
in: IPv6{SrcAddr: &rIP, DstAddr: &lIP},
@@ -272,8 +264,8 @@ func SeqNumValue(v seqnum.Value) *seqnum.Value {
}
// newTCPState creates a new TCPState.
-func newTCPState(domain int, out, in TCP) (*tcpState, error) {
- portPickerFD, localPort, err := pickPort(domain, unix.SOCK_STREAM)
+func (n *DUTTestNet) newTCPState(domain int, out, in TCP) (*tcpState, error) {
+ portPickerFD, localPort, err := n.pickPort(domain, unix.SOCK_STREAM)
if err != nil {
return nil, err
}
@@ -314,11 +306,11 @@ func (s *tcpState) incoming(received Layer) Layer {
if s.remoteSeqNum != nil {
newIn.SeqNum = Uint32(uint32(*s.remoteSeqNum))
}
- if s.localSeqNum != nil && (*tcpReceived.Flags&header.TCPFlagAck) != 0 {
+ if seq, flags := s.localSeqNum, tcpReceived.Flags; seq != nil && flags != nil && *flags&header.TCPFlagAck != 0 {
// The caller didn't specify an AckNum so we'll expect the calculated one,
// but only if the ACK flag is set because the AckNum is not valid in a
// header if ACK is not set.
- newIn.AckNum = Uint32(uint32(*s.localSeqNum))
+ newIn.AckNum = Uint32(uint32(*seq))
}
return &newIn
}
@@ -376,8 +368,8 @@ type udpState struct {
var _ layerState = (*udpState)(nil)
// newUDPState creates a new udpState.
-func newUDPState(domain int, out, in UDP) (*udpState, error) {
- portPickerFD, localPort, err := pickPort(domain, unix.SOCK_DGRAM)
+func (n *DUTTestNet) newUDPState(domain int, out, in UDP) (*udpState, error) {
+ portPickerFD, localPort, err := n.pickPort(domain, unix.SOCK_DGRAM)
if err != nil {
return nil, fmt.Errorf("picking port: %w", err)
}
@@ -606,14 +598,14 @@ func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Du
var errs error
for {
var gotLayers Layers
- if timeout = time.Until(deadline); timeout > 0 {
+ if timeout := time.Until(deadline); timeout > 0 {
gotLayers = conn.recvFrame(t, timeout)
}
if gotLayers == nil {
if errs == nil {
- return nil, fmt.Errorf("got no frames matching %v during %s", layers, timeout)
+ return nil, fmt.Errorf("got no frames matching %s during %s", layers, timeout)
}
- return nil, fmt.Errorf("got frames %w want %v during %s", errs, layers, timeout)
+ return nil, fmt.Errorf("got frames:\n%w want %s during %s", errs, layers, timeout)
}
if conn.match(layers, gotLayers) {
for i, s := range conn.layerStates {
@@ -623,7 +615,12 @@ func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Du
}
return gotLayers, nil
}
- errs = multierr.Combine(errs, &layersError{got: gotLayers, want: conn.incoming(gotLayers)})
+ want := conn.incoming(layers)
+ if err := want.merge(layers); err != nil {
+ errs = multierr.Combine(errs, err)
+ } else {
+ errs = multierr.Combine(errs, &layersError{got: gotLayers, want: want})
+ }
}
}
@@ -639,26 +636,26 @@ func (conn *Connection) Drain(t *testing.T) {
type TCPIPv4 Connection
// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults.
-func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
+func (n *DUTTestNet) NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
t.Helper()
- etherState, err := newEtherState(Ether{}, Ether{})
+ etherState, err := n.newEtherState(Ether{}, Ether{})
if err != nil {
t.Fatalf("can't make etherState: %s", err)
}
- ipv4State, err := newIPv4State(IPv4{}, IPv4{})
+ ipv4State, err := n.newIPv4State(IPv4{}, IPv4{})
if err != nil {
t.Fatalf("can't make ipv4State: %s", err)
}
- tcpState, err := newTCPState(unix.AF_INET, outgoingTCP, incomingTCP)
+ tcpState, err := n.newTCPState(unix.AF_INET, outgoingTCP, incomingTCP)
if err != nil {
t.Fatalf("can't make tcpState: %s", err)
}
- injector, err := NewInjector(t)
+ injector, err := n.NewInjector(t)
if err != nil {
t.Fatalf("can't make injector: %s", err)
}
- sniffer, err := NewSniffer(t)
+ sniffer, err := n.NewSniffer(t)
if err != nil {
t.Fatalf("can't make sniffer: %s", err)
}
@@ -841,23 +838,23 @@ func (conn *TCPIPv4) Drain(t *testing.T) {
type IPv4Conn Connection
// NewIPv4Conn creates a new IPv4Conn connection with reasonable defaults.
-func NewIPv4Conn(t *testing.T, outgoingIPv4, incomingIPv4 IPv4) IPv4Conn {
+func (n *DUTTestNet) NewIPv4Conn(t *testing.T, outgoingIPv4, incomingIPv4 IPv4) IPv4Conn {
t.Helper()
- etherState, err := newEtherState(Ether{}, Ether{})
+ etherState, err := n.newEtherState(Ether{}, Ether{})
if err != nil {
t.Fatalf("can't make EtherState: %s", err)
}
- ipv4State, err := newIPv4State(outgoingIPv4, incomingIPv4)
+ ipv4State, err := n.newIPv4State(outgoingIPv4, incomingIPv4)
if err != nil {
t.Fatalf("can't make IPv4State: %s", err)
}
- injector, err := NewInjector(t)
+ injector, err := n.NewInjector(t)
if err != nil {
t.Fatalf("can't make injector: %s", err)
}
- sniffer, err := NewSniffer(t)
+ sniffer, err := n.NewSniffer(t)
if err != nil {
t.Fatalf("can't make sniffer: %s", err)
}
@@ -896,23 +893,23 @@ func (c *IPv4Conn) ExpectFrame(t *testing.T, frame Layers, timeout time.Duration
type IPv6Conn Connection
// NewIPv6Conn creates a new IPv6Conn connection with reasonable defaults.
-func NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn {
+func (n *DUTTestNet) NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn {
t.Helper()
- etherState, err := newEtherState(Ether{}, Ether{})
+ etherState, err := n.newEtherState(Ether{}, Ether{})
if err != nil {
t.Fatalf("can't make EtherState: %s", err)
}
- ipv6State, err := newIPv6State(outgoingIPv6, incomingIPv6)
+ ipv6State, err := n.newIPv6State(outgoingIPv6, incomingIPv6)
if err != nil {
t.Fatalf("can't make IPv6State: %s", err)
}
- injector, err := NewInjector(t)
+ injector, err := n.NewInjector(t)
if err != nil {
t.Fatalf("can't make injector: %s", err)
}
- sniffer, err := NewSniffer(t)
+ sniffer, err := n.NewSniffer(t)
if err != nil {
t.Fatalf("can't make sniffer: %s", err)
}
@@ -951,26 +948,26 @@ func (conn *IPv6Conn) ExpectFrame(t *testing.T, frame Layers, timeout time.Durat
type UDPIPv4 Connection
// NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults.
-func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
+func (n *DUTTestNet) NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
t.Helper()
- etherState, err := newEtherState(Ether{}, Ether{})
+ etherState, err := n.newEtherState(Ether{}, Ether{})
if err != nil {
t.Fatalf("can't make etherState: %s", err)
}
- ipv4State, err := newIPv4State(IPv4{}, IPv4{})
+ ipv4State, err := n.newIPv4State(IPv4{}, IPv4{})
if err != nil {
t.Fatalf("can't make ipv4State: %s", err)
}
- udpState, err := newUDPState(unix.AF_INET, outgoingUDP, incomingUDP)
+ udpState, err := n.newUDPState(unix.AF_INET, outgoingUDP, incomingUDP)
if err != nil {
t.Fatalf("can't make udpState: %s", err)
}
- injector, err := NewInjector(t)
+ injector, err := n.NewInjector(t)
if err != nil {
t.Fatalf("can't make injector: %s", err)
}
- sniffer, err := NewSniffer(t)
+ sniffer, err := n.NewSniffer(t)
if err != nil {
t.Fatalf("can't make sniffer: %s", err)
}
@@ -1075,26 +1072,26 @@ func (conn *UDPIPv4) Drain(t *testing.T) {
type UDPIPv6 Connection
// NewUDPIPv6 creates a new UDPIPv6 connection with reasonable defaults.
-func NewUDPIPv6(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv6 {
+func (n *DUTTestNet) NewUDPIPv6(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv6 {
t.Helper()
- etherState, err := newEtherState(Ether{}, Ether{})
+ etherState, err := n.newEtherState(Ether{}, Ether{})
if err != nil {
t.Fatalf("can't make etherState: %s", err)
}
- ipv6State, err := newIPv6State(IPv6{}, IPv6{})
+ ipv6State, err := n.newIPv6State(IPv6{}, IPv6{})
if err != nil {
t.Fatalf("can't make IPv6State: %s", err)
}
- udpState, err := newUDPState(unix.AF_INET6, outgoingUDP, incomingUDP)
+ udpState, err := n.newUDPState(unix.AF_INET6, outgoingUDP, incomingUDP)
if err != nil {
t.Fatalf("can't make udpState: %s", err)
}
- injector, err := NewInjector(t)
+ injector, err := n.NewInjector(t)
if err != nil {
t.Fatalf("can't make injector: %s", err)
}
- sniffer, err := NewSniffer(t)
+ sniffer, err := n.NewSniffer(t)
if err != nil {
t.Fatalf("can't make sniffer: %s", err)
}
@@ -1126,14 +1123,14 @@ func (conn *UDPIPv6) ipv6State(t *testing.T) *ipv6State {
}
// LocalAddr gets the local socket address of this connection.
-func (conn *UDPIPv6) LocalAddr(t *testing.T) *unix.SockaddrInet6 {
+func (conn *UDPIPv6) LocalAddr(t *testing.T, zoneID uint32) *unix.SockaddrInet6 {
t.Helper()
sa := &unix.SockaddrInet6{
Port: int(*conn.udpState(t).out.SrcPort),
// Local address is in perspective to the remote host, so it's scoped to the
// ID of the remote interface.
- ZoneId: uint32(RemoteInterfaceID),
+ ZoneId: zoneID,
}
copy(sa.Addr[:], *conn.ipv6State(t).out.SrcAddr)
return sa
@@ -1203,24 +1200,24 @@ func (conn *UDPIPv6) Drain(t *testing.T) {
type TCPIPv6 Connection
// NewTCPIPv6 creates a new TCPIPv6 connection with reasonable defaults.
-func NewTCPIPv6(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv6 {
- etherState, err := newEtherState(Ether{}, Ether{})
+func (n *DUTTestNet) NewTCPIPv6(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv6 {
+ etherState, err := n.newEtherState(Ether{}, Ether{})
if err != nil {
t.Fatalf("can't make etherState: %s", err)
}
- ipv6State, err := newIPv6State(IPv6{}, IPv6{})
+ ipv6State, err := n.newIPv6State(IPv6{}, IPv6{})
if err != nil {
t.Fatalf("can't make ipv6State: %s", err)
}
- tcpState, err := newTCPState(unix.AF_INET6, outgoingTCP, incomingTCP)
+ tcpState, err := n.newTCPState(unix.AF_INET6, outgoingTCP, incomingTCP)
if err != nil {
t.Fatalf("can't make tcpState: %s", err)
}
- injector, err := NewInjector(t)
+ injector, err := n.NewInjector(t)
if err != nil {
t.Fatalf("can't make injector: %s", err)
}
- sniffer, err := NewSniffer(t)
+ sniffer, err := n.NewSniffer(t)
if err != nil {
t.Fatalf("can't make sniffer: %s", err)
}
diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go
index 6165ab293..66a0255b8 100644
--- a/test/packetimpact/testbench/dut.go
+++ b/test/packetimpact/testbench/dut.go
@@ -17,9 +17,8 @@ package testbench
import (
"context"
"encoding/binary"
- "flag"
+ "fmt"
"net"
- "strconv"
"syscall"
"testing"
"time"
@@ -35,18 +34,26 @@ import (
type DUT struct {
conn *grpc.ClientConn
posixServer POSIXClient
+ Net *DUTTestNet
}
// NewDUT creates a new connection with the DUT over gRPC.
func NewDUT(t *testing.T) DUT {
t.Helper()
+ n := GetDUTTestNet()
+ dut := n.ConnectToDUT(t)
+ t.Cleanup(func() {
+ dut.TearDownConnection()
+ dut.Net.Release()
+ })
+ return dut
+}
- flag.Parse()
- if err := genPseudoFlags(); err != nil {
- t.Fatal("generating psuedo flags:", err)
- }
+// ConnectToDUT connects to DUT through gRPC.
+func (n *DUTTestNet) ConnectToDUT(t *testing.T) DUT {
+ t.Helper()
- posixServerAddress := POSIXServerIP + ":" + strconv.Itoa(POSIXServerPort)
+ posixServerAddress := net.JoinHostPort(n.POSIXServerIP.String(), fmt.Sprintf("%d", n.POSIXServerPort))
conn, err := grpc.Dial(posixServerAddress, grpc.WithInsecure(), grpc.WithKeepaliveParams(keepalive.ClientParameters{Timeout: RPCKeepalive}))
if err != nil {
t.Fatalf("failed to grpc.Dial(%s): %s", posixServerAddress, err)
@@ -55,11 +62,12 @@ func NewDUT(t *testing.T) DUT {
return DUT{
conn: conn,
posixServer: posixServer,
+ Net: n,
}
}
-// TearDown closes the underlying connection.
-func (dut *DUT) TearDown() {
+// TearDownConnection closes the underlying connection.
+func (dut *DUT) TearDownConnection() {
dut.conn.Close()
}
@@ -132,7 +140,7 @@ func (dut *DUT) CreateBoundSocket(t *testing.T, typ, proto int32, addr net.IP) (
fd = dut.Socket(t, unix.AF_INET6, typ, proto)
sa := unix.SockaddrInet6{}
copy(sa.Addr[:], addr.To16())
- sa.ZoneId = uint32(RemoteInterfaceID)
+ sa.ZoneId = dut.Net.RemoteDevID
dut.Bind(t, fd, &sa)
} else {
t.Fatalf("invalid IP address: %s", addr)
@@ -154,7 +162,7 @@ func (dut *DUT) CreateBoundSocket(t *testing.T, typ, proto int32, addr net.IP) (
func (dut *DUT) CreateListener(t *testing.T, typ, proto, backlog int32) (int32, uint16) {
t.Helper()
- fd, remotePort := dut.CreateBoundSocket(t, typ, proto, net.ParseIP(RemoteIPv4))
+ fd, remotePort := dut.CreateBoundSocket(t, typ, proto, dut.Net.RemoteIPv4)
dut.Listen(t, fd, backlog)
return fd, remotePort
}
@@ -717,9 +725,9 @@ func (dut *DUT) SetSockLingerOption(t *testing.T, sockfd int32, timeout time.Dur
dut.SetSockOpt(t, sockfd, unix.SOL_SOCKET, unix.SO_LINGER, buf)
}
-// Shutdown calls shutdown on the DUT and causes a fatal test failure if it doesn't
-// succeed. If more control over the timeout or error handling is needed, use
-// ShutdownWithErrno.
+// Shutdown calls shutdown on the DUT and causes a fatal test failure if it
+// doesn't succeed. If more control over the timeout or error handling is
+// needed, use ShutdownWithErrno.
func (dut *DUT) Shutdown(t *testing.T, fd, how int32) error {
t.Helper()
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
index 7401a1991..19e6b8d7d 100644
--- a/test/packetimpact/testbench/layers.go
+++ b/test/packetimpact/testbench/layers.go
@@ -298,14 +298,12 @@ func (l *IPv4) ToBytes() ([]byte, error) {
// An IPv4 header is variable length depending on the size of the Options.
hdrLen := header.IPv4MinimumSize
if l.Options != nil {
- hdrLen += l.Options.SizeWithPadding()
+ if len(*l.Options)%4 != 0 {
+ return nil, fmt.Errorf("invalid header options '%x (len=%d)'; must be 32 bit aligned", *l.Options, len(*l.Options))
+ }
+ hdrLen += len(*l.Options)
if hdrLen > header.IPv4MaximumHeaderSize {
- // While ToBytes can be called on packets that were received as well
- // as packets locally generated, it is physically impossible for a
- // received packet to overflow this value so any such failure must
- // be the result of a local programming error and not remotely
- // triggered. A panic is therefore appropriate.
- panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", len(*l.Options), header.IPv4MaximumOptionsSize))
+ return nil, fmt.Errorf("IPv4 Options %d bytes, Max %d", len(*l.Options), header.IPv4MaximumOptionsSize)
}
}
b := make([]byte, hdrLen)
@@ -323,10 +321,6 @@ func (l *IPv4) ToBytes() ([]byte, error) {
DstAddr: tcpip.Address(""),
Options: nil,
}
- // Leave an empty options slice as nil.
- if hdrLen > header.IPv4MinimumSize {
- fields.Options = *l.Options
- }
if l.TOS != nil {
fields.TOS = *l.TOS
}
@@ -373,18 +367,31 @@ func (l *IPv4) ToBytes() ([]byte, error) {
if l.DstAddr != nil {
fields.DstAddr = *l.DstAddr
}
- if l.Checksum != nil {
- fields.Checksum = *l.Checksum
- }
+
h.Encode(fields)
- if l.Checksum == nil {
- h.SetChecksum(^h.CalculateChecksum())
+
+ // Put raw option bytes from test definition in header. Options as raw bytes
+ // allows us to serialize malformed options, which is not possible with
+ // the provided serialization functions.
+ if l.Options != nil {
+ h.SetHeaderLength(h.HeaderLength() + uint8(len(*l.Options)))
+ if got, want := copy(h.Options(), *l.Options), len(*l.Options); got != want {
+ return nil, fmt.Errorf("failed to copy option bytes into header, got %d want %d", got, want)
+ }
}
+
// Encode cannot set this incorrectly so we need to overwrite what it wrote
// in order to test handling of a bad IHL value.
if l.IHL != nil {
h.SetHeaderLength(*l.IHL)
}
+
+ if l.Checksum == nil {
+ h.SetChecksum(^h.CalculateChecksum())
+ } else {
+ h.SetChecksum(*l.Checksum)
+ }
+
return h, nil
}
@@ -498,13 +505,13 @@ func (l *IPv6) ToBytes() ([]byte, error) {
}
}
if l.NextHeader != nil {
- fields.NextHeader = *l.NextHeader
+ fields.TransportProtocol = tcpip.TransportProtocolNumber(*l.NextHeader)
} else {
nh, err := nextHeaderByLayer(l.next())
if err != nil {
return nil, err
}
- fields.NextHeader = nh
+ fields.TransportProtocol = tcpip.TransportProtocolNumber(nh)
}
if l.HopLimit != nil {
fields.HopLimit = *l.HopLimit
@@ -830,7 +837,9 @@ func (l *ICMPv6) ToBytes() ([]byte, error) {
if l.Code != nil {
h.SetCode(*l.Code)
}
- copy(h.NDPPayload(), l.Payload)
+ if n := copy(h.MessageBody(), l.Payload); n != len(l.Payload) {
+ panic(fmt.Sprintf("copied %d bytes, expected to copy %d bytes", n, len(l.Payload)))
+ }
if l.Checksum != nil {
h.SetChecksum(*l.Checksum)
} else {
@@ -876,7 +885,7 @@ func parseICMPv6(b []byte) (Layer, layerParser) {
Type: ICMPv6Type(h.Type()),
Code: ICMPv6Code(h.Code()),
Checksum: Uint16(h.Checksum()),
- Payload: h.NDPPayload(),
+ Payload: h.MessageBody(),
}
return &icmpv6, nil
}
diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go
index 193bb2dc8..1ac96626a 100644
--- a/test/packetimpact/testbench/rawsockets.go
+++ b/test/packetimpact/testbench/rawsockets.go
@@ -38,13 +38,27 @@ func htons(x uint16) uint16 {
}
// NewSniffer creates a Sniffer connected to *device.
-func NewSniffer(t *testing.T) (Sniffer, error) {
+func (n *DUTTestNet) NewSniffer(t *testing.T) (Sniffer, error) {
t.Helper()
+ ifInfo, err := net.InterfaceByName(n.LocalDevName)
+ if err != nil {
+ return Sniffer{}, err
+ }
+
+ var haddr [8]byte
+ copy(haddr[:], ifInfo.HardwareAddr)
+ sa := unix.SockaddrLinklayer{
+ Protocol: htons(unix.ETH_P_ALL),
+ Ifindex: ifInfo.Index,
+ }
snifferFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL)))
if err != nil {
return Sniffer{}, err
}
+ if err := unix.Bind(snifferFd, &sa); err != nil {
+ return Sniffer{}, err
+ }
if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, 1); err != nil {
t.Fatalf("can't set sockopt SO_RCVBUFFORCE to 1: %s", err)
}
@@ -60,7 +74,8 @@ func NewSniffer(t *testing.T) (Sniffer, error) {
// packet too large for the buffer arrives, the test will get a fatal error.
const maxReadSize int = 65536
-// Recv tries to read one frame until the timeout is up.
+// Recv tries to read one frame until the timeout is up. If the timeout given
+// is 0, then no read attempt will be made.
func (s *Sniffer) Recv(t *testing.T, timeout time.Duration) []byte {
t.Helper()
@@ -73,9 +88,13 @@ func (s *Sniffer) Recv(t *testing.T, timeout time.Duration) []byte {
whole, frac := math.Modf(timeout.Seconds())
tv := unix.Timeval{
Sec: int64(whole),
- Usec: int64(frac * float64(time.Microsecond/time.Second)),
+ Usec: int64(frac * float64(time.Second/time.Microsecond)),
+ }
+ // The following should never happen, but having this guard here is better
+ // than blocking indefinitely in the future.
+ if tv.Sec == 0 && tv.Usec == 0 {
+ t.Fatal("setting SO_RCVTIMEO to 0 means blocking indefinitely")
}
-
if err := unix.SetsockoptTimeval(s.fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err)
}
@@ -136,10 +155,10 @@ type Injector struct {
}
// NewInjector creates a new injector on *device.
-func NewInjector(t *testing.T) (Injector, error) {
+func (n *DUTTestNet) NewInjector(t *testing.T) (Injector, error) {
t.Helper()
- ifInfo, err := net.InterfaceByName(LocalDevice)
+ ifInfo, err := net.InterfaceByName(n.LocalDevName)
if err != nil {
return Injector{}, err
}
@@ -147,7 +166,7 @@ func NewInjector(t *testing.T) (Injector, error) {
var haddr [8]byte
copy(haddr[:], ifInfo.HardwareAddr)
sa := unix.SockaddrLinklayer{
- Protocol: unix.ETH_P_IP,
+ Protocol: htons(unix.ETH_P_IP),
Ifindex: ifInfo.Index,
Halen: uint8(len(ifInfo.HardwareAddr)),
Addr: haddr,
diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go
index c1db95d8c..891897d55 100644
--- a/test/packetimpact/testbench/testbench.go
+++ b/test/packetimpact/testbench/testbench.go
@@ -17,108 +17,105 @@
package testbench
import (
+ "encoding/json"
"flag"
"fmt"
"math/rand"
"net"
- "os/exec"
"testing"
"time"
-
- "gvisor.dev/gvisor/test/packetimpact/netdevs"
)
var (
// Native indicates that the test is being run natively.
Native = false
- // LocalDevice is the device that testbench uses to inject traffic.
- LocalDevice = ""
- // RemoteDevice is the device name on the DUT, individual tests can
- // use the name to construct tests.
- RemoteDevice = ""
+ // RPCKeepalive is the gRPC keepalive.
+ RPCKeepalive = 10 * time.Second
+ // RPCTimeout is the gRPC timeout.
+ RPCTimeout = 100 * time.Millisecond
+
+ // dutTestNetsJSON is the json string that describes all the test networks to
+ // duts available to use.
+ dutTestNetsJSON string
+ // dutTestNets is the pool among which the testbench can choose a DUT to work
+ // with.
+ dutTestNets chan *DUTTestNet
+)
+// DUTTestNet describes the test network setup on dut and how the testbench
+// should connect with an existing DUT.
+type DUTTestNet struct {
+ // LocalMAC is the local MAC address on the test network.
+ LocalMAC net.HardwareAddr
+ // RemoteMAC is the DUT's MAC address on the test network.
+ RemoteMAC net.HardwareAddr
// LocalIPv4 is the local IPv4 address on the test network.
- LocalIPv4 = ""
+ LocalIPv4 net.IP
// RemoteIPv4 is the DUT's IPv4 address on the test network.
- RemoteIPv4 = ""
+ RemoteIPv4 net.IP
// IPv4PrefixLength is the network prefix length of the IPv4 test network.
- IPv4PrefixLength = 0
-
+ IPv4PrefixLength int
// LocalIPv6 is the local IPv6 address on the test network.
- LocalIPv6 = ""
+ LocalIPv6 net.IP
// RemoteIPv6 is the DUT's IPv6 address on the test network.
- RemoteIPv6 = ""
+ RemoteIPv6 net.IP
+ // LocalDevID is the ID of the local interface on the test network.
+ LocalDevID uint32
+ // RemoteDevID is the ID of the remote interface on the test network.
+ RemoteDevID uint32
+ // LocalDevName is the device that testbench uses to inject traffic.
+ LocalDevName string
+ // RemoteDevName is the device name on the DUT, individual tests can
+ // use the name to construct tests.
+ RemoteDevName string
- // LocalInterfaceID is the ID of the local interface on the test network.
- LocalInterfaceID uint32
- // RemoteInterfaceID is the ID of the remote interface on the test network.
- //
- // Not using uint32 because package flag does not support uint32.
- RemoteInterfaceID uint64
-
- // LocalMAC is the local MAC address on the test network.
- LocalMAC = ""
- // RemoteMAC is the DUT's MAC address on the test network.
- RemoteMAC = ""
+ // The following two fields on actually on the control network instead
+ // of the test network, including them for convenience.
// POSIXServerIP is the POSIX server's IP address on the control network.
- POSIXServerIP = ""
+ POSIXServerIP net.IP
// POSIXServerPort is the UDP port the POSIX server is bound to on the
// control network.
- POSIXServerPort = 40000
-
- // RPCKeepalive is the gRPC keepalive.
- RPCKeepalive = 10 * time.Second
- // RPCTimeout is the gRPC timeout.
- RPCTimeout = 100 * time.Millisecond
-)
+ POSIXServerPort uint16
+}
-// RegisterFlags defines flags and associates them with the package-level
+// registerFlags defines flags and associates them with the package-level
// exported variables above. It should be called by tests in their init
// functions.
-func RegisterFlags(fs *flag.FlagSet) {
- fs.StringVar(&POSIXServerIP, "posix_server_ip", POSIXServerIP, "ip address to listen to for UDP commands")
- fs.IntVar(&POSIXServerPort, "posix_server_port", POSIXServerPort, "port to listen to for UDP commands")
+func registerFlags(fs *flag.FlagSet) {
+ fs.BoolVar(&Native, "native", Native, "whether the test is running natively")
fs.DurationVar(&RPCTimeout, "rpc_timeout", RPCTimeout, "gRPC timeout")
fs.DurationVar(&RPCKeepalive, "rpc_keepalive", RPCKeepalive, "gRPC keepalive")
- fs.StringVar(&LocalIPv4, "local_ipv4", LocalIPv4, "local IPv4 address for test packets")
- fs.StringVar(&RemoteIPv4, "remote_ipv4", RemoteIPv4, "remote IPv4 address for test packets")
- fs.StringVar(&RemoteIPv6, "remote_ipv6", RemoteIPv6, "remote IPv6 address for test packets")
- fs.StringVar(&RemoteMAC, "remote_mac", RemoteMAC, "remote mac address for test packets")
- fs.StringVar(&LocalDevice, "local_device", LocalDevice, "local device to inject traffic")
- fs.StringVar(&RemoteDevice, "remote_device", RemoteDevice, "remote device on the DUT")
- fs.BoolVar(&Native, "native", Native, "whether the test is running natively")
- fs.Uint64Var(&RemoteInterfaceID, "remote_interface_id", RemoteInterfaceID, "remote interface ID for test packets")
+ fs.StringVar(&dutTestNetsJSON, "dut_test_nets_json", dutTestNetsJSON, "path to the dut test nets json file")
}
-// genPseudoFlags populates flag-like global config based on real flags.
-//
-// genPseudoFlags must only be called after flag.Parse.
-func genPseudoFlags() error {
- out, err := exec.Command("ip", "addr", "show").CombinedOutput()
- if err != nil {
- return fmt.Errorf("listing devices: %q: %w", string(out), err)
- }
- devs, err := netdevs.ParseDevices(string(out))
- if err != nil {
- return fmt.Errorf("parsing devices: %w", err)
+// Initialize initializes the testbench, it parse the flags and sets up the
+// pool of test networks for testbench's later use.
+func Initialize(fs *flag.FlagSet) {
+ registerFlags(fs)
+ flag.Parse()
+ if err := loadDUTTestNets(); err != nil {
+ panic(err)
}
+}
- _, deviceInfo, err := netdevs.FindDeviceByIP(net.ParseIP(LocalIPv4), devs)
- if err != nil {
- return fmt.Errorf("can't find deviceInfo: %w", err)
+// loadDUTTestNets loads available DUT test networks from the json file, it
+// must be called after flag.Parse().
+func loadDUTTestNets() error {
+ var parsedTestNets []DUTTestNet
+ if err := json.Unmarshal([]byte(dutTestNetsJSON), &parsedTestNets); err != nil {
+ return fmt.Errorf("failed to unmarshal JSON: %w", err)
}
-
- LocalMAC = deviceInfo.MAC.String()
- LocalIPv6 = deviceInfo.IPv6Addr.String()
- LocalInterfaceID = deviceInfo.ID
-
- if deviceInfo.IPv4Net != nil {
- IPv4PrefixLength, _ = deviceInfo.IPv4Net.Mask.Size()
- } else {
- IPv4PrefixLength, _ = net.ParseIP(LocalIPv4).DefaultMask().Size()
+ if got, want := len(parsedTestNets), 1; got < want {
+ return fmt.Errorf("got %d DUTs, the test requires at least %d DUTs", got, want)
+ }
+ // Using a buffered channel as semaphore
+ dutTestNets = make(chan *DUTTestNet, len(parsedTestNets))
+ for i := range parsedTestNets {
+ parsedTestNets[i].LocalIPv4 = parsedTestNets[i].LocalIPv4.To4()
+ parsedTestNets[i].RemoteIPv4 = parsedTestNets[i].RemoteIPv4.To4()
+ dutTestNets <- &parsedTestNets[i]
}
-
return nil
}
@@ -132,3 +129,15 @@ func GenerateRandomPayload(t *testing.T, n int) []byte {
}
return buf
}
+
+// GetDUTTestNet gets a usable DUTTestNet, the function will block until any
+// becomes available.
+func GetDUTTestNet() *DUTTestNet {
+ return <-dutTestNets
+}
+
+// Release releases the DUTTestNet back to the pool so that some other test
+// can use.
+func (n *DUTTestNet) Release() {
+ dutTestNets <- n
+}
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
index 33bd070c1..b1b3c578b 100644
--- a/test/packetimpact/tests/BUILD
+++ b/test/packetimpact/tests/BUILD
@@ -366,9 +366,29 @@ packetimpact_testbench(
],
)
+packetimpact_testbench(
+ name = "tcp_zero_receive_window",
+ srcs = ["tcp_zero_receive_window_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
validate_all_tests()
[packetimpact_go_test(
name = t.name,
expect_netstack_failure = hasattr(t, "expect_netstack_failure"),
+ num_duts = t.num_duts if hasattr(t, "num_duts") else 1,
) for t in ALL_TESTS]
+
+test_suite(
+ name = "all_tests",
+ tags = [
+ "manual",
+ "packetimpact",
+ ],
+ tests = existing_rules(),
+)
diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go
index a61054c2c..11f0fcd1e 100644
--- a/test/packetimpact/tests/fin_wait2_timeout_test.go
+++ b/test/packetimpact/tests/fin_wait2_timeout_test.go
@@ -25,7 +25,7 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
func TestFinWait2Timeout(t *testing.T) {
@@ -38,10 +38,9 @@ func TestFinWait2Timeout(t *testing.T) {
} {
t.Run(tt.description, func(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(t, listenFd)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
conn.Connect(t)
diff --git a/test/packetimpact/tests/icmpv6_param_problem_test.go b/test/packetimpact/tests/icmpv6_param_problem_test.go
index 2d59d552d..40d7a491d 100644
--- a/test/packetimpact/tests/icmpv6_param_problem_test.go
+++ b/test/packetimpact/tests/icmpv6_param_problem_test.go
@@ -25,15 +25,14 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
// TestICMPv6ParamProblemTest sends a packet with a bad next header. The DUT
// should respond with an ICMPv6 Parameter Problem message.
func TestICMPv6ParamProblemTest(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
- conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
+ conn := dut.Net.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
defer conn.Close(t)
ipv6 := testbench.IPv6{
// 254 is reserved and used for experimentation and testing. This should
diff --git a/test/packetimpact/tests/ipv4_fragment_reassembly_test.go b/test/packetimpact/tests/ipv4_fragment_reassembly_test.go
index 40f899065..d2203082d 100644
--- a/test/packetimpact/tests/ipv4_fragment_reassembly_test.go
+++ b/test/packetimpact/tests/ipv4_fragment_reassembly_test.go
@@ -27,17 +27,17 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
type fragmentInfo struct {
offset uint16
size uint16
more uint8
+ id uint16
}
func TestIPv4FragmentReassembly(t *testing.T) {
- const fragmentID = 42
icmpv4ProtoNum := uint8(header.ICMPv4ProtocolNumber)
tests := []struct {
@@ -45,32 +45,78 @@ func TestIPv4FragmentReassembly(t *testing.T) {
ipPayloadLen int
fragments []fragmentInfo
expectReply bool
+ skip bool
+ skipReason string
}{
{
description: "basic reassembly",
- ipPayloadLen: 2000,
+ ipPayloadLen: 3000,
fragments: []fragmentInfo{
- {offset: 0, size: 1000, more: header.IPv4FlagMoreFragments},
- {offset: 1000, size: 1000, more: 0},
+ {offset: 0, size: 1000, id: 5, more: header.IPv4FlagMoreFragments},
+ {offset: 1000, size: 1000, id: 5, more: header.IPv4FlagMoreFragments},
+ {offset: 2000, size: 1000, id: 5, more: 0},
},
expectReply: true,
},
{
description: "out of order fragments",
- ipPayloadLen: 2000,
+ ipPayloadLen: 3000,
fragments: []fragmentInfo{
- {offset: 1000, size: 1000, more: 0},
- {offset: 0, size: 1000, more: header.IPv4FlagMoreFragments},
+ {offset: 2000, size: 1000, id: 6, more: 0},
+ {offset: 0, size: 1000, id: 6, more: header.IPv4FlagMoreFragments},
+ {offset: 1000, size: 1000, id: 6, more: header.IPv4FlagMoreFragments},
},
expectReply: true,
},
+ {
+ description: "duplicated fragments",
+ ipPayloadLen: 3000,
+ fragments: []fragmentInfo{
+ {offset: 0, size: 1000, id: 7, more: header.IPv4FlagMoreFragments},
+ {offset: 1000, size: 1000, id: 7, more: header.IPv4FlagMoreFragments},
+ {offset: 1000, size: 1000, id: 7, more: header.IPv4FlagMoreFragments},
+ {offset: 2000, size: 1000, id: 7, more: 0},
+ },
+ expectReply: true,
+ skip: true,
+ skipReason: "gvisor.dev/issues/4971",
+ },
+ {
+ description: "fragment subset",
+ ipPayloadLen: 3000,
+ fragments: []fragmentInfo{
+ {offset: 0, size: 1000, id: 8, more: header.IPv4FlagMoreFragments},
+ {offset: 1000, size: 1000, id: 8, more: header.IPv4FlagMoreFragments},
+ {offset: 512, size: 256, id: 8, more: header.IPv4FlagMoreFragments},
+ {offset: 2000, size: 1000, id: 8, more: 0},
+ },
+ expectReply: true,
+ skip: true,
+ skipReason: "gvisor.dev/issues/4971",
+ },
+ {
+ description: "fragment overlap",
+ ipPayloadLen: 3000,
+ fragments: []fragmentInfo{
+ {offset: 0, size: 1000, id: 9, more: header.IPv4FlagMoreFragments},
+ {offset: 1512, size: 1000, id: 9, more: header.IPv4FlagMoreFragments},
+ {offset: 1000, size: 1000, id: 9, more: header.IPv4FlagMoreFragments},
+ {offset: 2000, size: 1000, id: 9, more: 0},
+ },
+ expectReply: false,
+ skip: true,
+ skipReason: "gvisor.dev/issues/4971",
+ },
}
for _, test := range tests {
+ if test.skip {
+ t.Skip("%s test skipped: %s", test.description, test.skipReason)
+ continue
+ }
t.Run(test.description, func(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
- conn := testbench.NewIPv4Conn(t, testbench.IPv4{}, testbench.IPv4{})
+ conn := dut.Net.NewIPv4Conn(t, testbench.IPv4{}, testbench.IPv4{})
defer conn.Close(t)
data := make([]byte, test.ipPayloadLen)
@@ -96,7 +142,7 @@ func TestIPv4FragmentReassembly(t *testing.T) {
Protocol: &icmpv4ProtoNum,
FragmentOffset: testbench.Uint16(fragment.offset),
Flags: testbench.Uint8(fragment.more),
- ID: testbench.Uint16(fragmentID),
+ ID: testbench.Uint16(fragment.id),
},
&testbench.Payload{
Bytes: data[fragment.offset:][:fragment.size],
@@ -115,7 +161,7 @@ func TestIPv4FragmentReassembly(t *testing.T) {
}, time.Second)
if err != nil {
// Either an unexpected frame was received, or none at all.
- if bytesReceived < test.ipPayloadLen {
+ if test.expectReply && bytesReceived < test.ipPayloadLen {
t.Fatalf("received %d bytes out of %d, then conn.ExpectFrame(_, _, time.Second) failed with %s", bytesReceived, test.ipPayloadLen, err)
}
break
diff --git a/test/packetimpact/tests/ipv4_id_uniqueness_test.go b/test/packetimpact/tests/ipv4_id_uniqueness_test.go
index 7f7a768d3..a63b41366 100644
--- a/test/packetimpact/tests/ipv4_id_uniqueness_test.go
+++ b/test/packetimpact/tests/ipv4_id_uniqueness_test.go
@@ -28,7 +28,7 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
func recvTCPSegment(t *testing.T, conn *testbench.TCPIPv4, expect *testbench.TCP, expectPayload *testbench.Payload) (uint16, error) {
@@ -67,12 +67,10 @@ func TestIPv4RetransmitIdentificationUniqueness(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
-
listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(t, listenFD)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
conn.Connect(t)
diff --git a/test/packetimpact/tests/ipv6_fragment_icmp_error_test.go b/test/packetimpact/tests/ipv6_fragment_icmp_error_test.go
index e058fb0d8..a37867e85 100644
--- a/test/packetimpact/tests/ipv6_fragment_icmp_error_test.go
+++ b/test/packetimpact/tests/ipv6_fragment_icmp_error_test.go
@@ -16,7 +16,6 @@ package ipv6_fragment_icmp_error_test
import (
"flag"
- "net"
"testing"
"time"
@@ -35,10 +34,10 @@ const (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
-func fragmentedICMPEchoRequest(t *testing.T, conn *testbench.Connection, firstPayloadLength uint16, payload []byte, secondFragmentOffset uint16) ([]testbench.Layers, [][]byte) {
+func fragmentedICMPEchoRequest(t *testing.T, n *testbench.DUTTestNet, conn *testbench.Connection, firstPayloadLength uint16, payload []byte, secondFragmentOffset uint16) ([]testbench.Layers, [][]byte) {
t.Helper()
icmpv6Header := header.ICMPv6(make([]byte, header.ICMPv6EchoMinimumSize))
@@ -48,8 +47,8 @@ func fragmentedICMPEchoRequest(t *testing.T, conn *testbench.Connection, firstPa
icmpv6Header.SetSequence(0)
cksum := header.ICMPv6Checksum(
icmpv6Header,
- tcpip.Address(net.ParseIP(testbench.LocalIPv6).To16()),
- tcpip.Address(net.ParseIP(testbench.RemoteIPv6).To16()),
+ tcpip.Address(n.LocalIPv6),
+ tcpip.Address(n.RemoteIPv6),
buffer.NewVectorisedView(len(payload), []buffer.View{payload}),
)
icmpv6Header.SetChecksum(cksum)
@@ -120,13 +119,13 @@ func TestIPv6ICMPEchoRequestFragmentReassembly(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
dut := testbench.NewDUT(t)
- defer dut.TearDown()
- ipv6Conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
+ ipv6Conn := dut.Net.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
conn := (*testbench.Connection)(&ipv6Conn)
defer ipv6Conn.Close(t)
- fragments, _ := fragmentedICMPEchoRequest(t, conn, test.firstPayloadLength, test.payload, test.secondFragmentOffset)
+ fragments, _ := fragmentedICMPEchoRequest(t, dut.Net, conn, test.firstPayloadLength, test.payload, test.secondFragmentOffset)
for _, i := range test.sendFrameOrder {
conn.SendFrame(t, fragments[i-1])
@@ -222,13 +221,13 @@ func TestIPv6FragmentReassemblyTimeout(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
dut := testbench.NewDUT(t)
- defer dut.TearDown()
- ipv6Conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
+ ipv6Conn := dut.Net.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
conn := (*testbench.Connection)(&ipv6Conn)
defer ipv6Conn.Close(t)
- fragments, ipv6Bytes := fragmentedICMPEchoRequest(t, conn, test.firstPayloadLength, test.payload, test.secondFragmentOffset)
+ fragments, ipv6Bytes := fragmentedICMPEchoRequest(t, dut.Net, conn, test.firstPayloadLength, test.payload, test.secondFragmentOffset)
for _, i := range test.sendFrameOrder {
conn.SendFrame(t, fragments[i-1])
@@ -318,13 +317,13 @@ func TestIPv6FragmentParamProblem(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
dut := testbench.NewDUT(t)
- defer dut.TearDown()
- ipv6Conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
+ ipv6Conn := dut.Net.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
conn := (*testbench.Connection)(&ipv6Conn)
defer ipv6Conn.Close(t)
- fragments, ipv6Bytes := fragmentedICMPEchoRequest(t, conn, test.firstPayloadLength, test.payload, test.secondFragmentOffset)
+ fragments, ipv6Bytes := fragmentedICMPEchoRequest(t, dut.Net, conn, test.firstPayloadLength, test.payload, test.secondFragmentOffset)
for _, i := range test.sendFrameOrder {
conn.SendFrame(t, fragments[i-1])
diff --git a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go
index eb56a53f7..dd98ee7a1 100644
--- a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go
+++ b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go
@@ -17,7 +17,6 @@ package ipv6_fragment_reassembly_test
import (
"flag"
"math/rand"
- "net"
"testing"
"time"
@@ -29,17 +28,17 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
type fragmentInfo struct {
offset uint16
size uint16
more bool
+ id uint32
}
func TestIPv6FragmentReassembly(t *testing.T) {
- const fragmentID = 42
icmpv6ProtoNum := header.IPv6ExtensionHeaderIdentifier(header.ICMPv6ProtocolNumber)
tests := []struct {
@@ -50,10 +49,11 @@ func TestIPv6FragmentReassembly(t *testing.T) {
}{
{
description: "basic reassembly",
- ipPayloadLen: 1500,
+ ipPayloadLen: 3000,
fragments: []fragmentInfo{
- {offset: 0, size: 760, more: true},
- {offset: 760, size: 740, more: false},
+ {offset: 0, size: 1000, id: 100, more: true},
+ {offset: 1000, size: 1000, id: 100, more: true},
+ {offset: 2000, size: 1000, id: 100, more: false},
},
expectReply: true,
},
@@ -61,23 +61,55 @@ func TestIPv6FragmentReassembly(t *testing.T) {
description: "out of order fragments",
ipPayloadLen: 3000,
fragments: []fragmentInfo{
- {offset: 0, size: 1024, more: true},
- {offset: 2048, size: 952, more: false},
- {offset: 1024, size: 1024, more: true},
+ {offset: 0, size: 1000, id: 101, more: true},
+ {offset: 2000, size: 1000, id: 101, more: false},
+ {offset: 1000, size: 1000, id: 101, more: true},
+ },
+ expectReply: true,
+ },
+ {
+ description: "duplicated fragments",
+ ipPayloadLen: 3000,
+ fragments: []fragmentInfo{
+ {offset: 0, size: 1000, id: 102, more: true},
+ {offset: 1000, size: 1000, id: 102, more: true},
+ {offset: 1000, size: 1000, id: 102, more: true},
+ {offset: 2000, size: 1000, id: 102, more: false},
+ },
+ expectReply: true,
+ },
+ {
+ description: "fragment subset",
+ ipPayloadLen: 3000,
+ fragments: []fragmentInfo{
+ {offset: 0, size: 1000, id: 103, more: true},
+ {offset: 1000, size: 1000, id: 103, more: true},
+ {offset: 512, size: 256, id: 103, more: true},
+ {offset: 2000, size: 1000, id: 103, more: false},
},
expectReply: true,
},
+ {
+ description: "fragment overlap",
+ ipPayloadLen: 3000,
+ fragments: []fragmentInfo{
+ {offset: 0, size: 1000, id: 104, more: true},
+ {offset: 1512, size: 1000, id: 104, more: true},
+ {offset: 1000, size: 1000, id: 104, more: true},
+ {offset: 2000, size: 1000, id: 104, more: false},
+ },
+ expectReply: false,
+ },
}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
- conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
+ conn := dut.Net.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
defer conn.Close(t)
- lIP := tcpip.Address(net.ParseIP(testbench.LocalIPv6).To16())
- rIP := tcpip.Address(net.ParseIP(testbench.RemoteIPv6).To16())
+ lIP := tcpip.Address(dut.Net.LocalIPv6)
+ rIP := tcpip.Address(dut.Net.RemoteIPv6)
data := make([]byte, test.ipPayloadLen)
icmp := header.ICMPv6(data[:header.ICMPv6HeaderSize])
@@ -103,7 +135,7 @@ func TestIPv6FragmentReassembly(t *testing.T) {
NextHeader: &icmpv6ProtoNum,
FragmentOffset: testbench.Uint16(fragment.offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit),
MoreFragments: testbench.Bool(fragment.more),
- Identification: testbench.Uint32(fragmentID),
+ Identification: testbench.Uint32(fragment.id),
},
&testbench.Payload{
Bytes: data[fragment.offset:][:fragment.size],
@@ -120,7 +152,7 @@ func TestIPv6FragmentReassembly(t *testing.T) {
}, time.Second)
if err != nil {
// Either an unexpected frame was received, or none at all.
- if bytesReceived < test.ipPayloadLen {
+ if test.expectReply && bytesReceived < test.ipPayloadLen {
t.Fatalf("received %d bytes out of %d, then conn.ExpectFrame(_, _, time.Second) failed with %s", bytesReceived, test.ipPayloadLen, err)
}
break
diff --git a/test/packetimpact/tests/ipv6_unknown_options_action_test.go b/test/packetimpact/tests/ipv6_unknown_options_action_test.go
index e79d74476..cb5396417 100644
--- a/test/packetimpact/tests/ipv6_unknown_options_action_test.go
+++ b/test/packetimpact/tests/ipv6_unknown_options_action_test.go
@@ -27,7 +27,7 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
func mkHopByHopOptionsExtHdr(optType byte) testbench.Layer {
@@ -141,8 +141,7 @@ func TestIPv6UnknownOptionAction(t *testing.T) {
} {
t.Run(tt.description, func(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
- ipv6Conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
+ ipv6Conn := dut.Net.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
conn := (*testbench.Connection)(&ipv6Conn)
defer ipv6Conn.Close(t)
diff --git a/test/packetimpact/tests/tcp_cork_mss_test.go b/test/packetimpact/tests/tcp_cork_mss_test.go
index 8feea4a82..a7ba5035e 100644
--- a/test/packetimpact/tests/tcp_cork_mss_test.go
+++ b/test/packetimpact/tests/tcp_cork_mss_test.go
@@ -25,16 +25,15 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
// TestTCPCorkMSS tests for segment coalesce and split as per MSS.
func TestTCPCorkMSS(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(t, listenFD)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
const mss = uint32(header.TCPDefaultMSS)
diff --git a/test/packetimpact/tests/tcp_handshake_window_size_test.go b/test/packetimpact/tests/tcp_handshake_window_size_test.go
index 22937d92f..5d1266f3c 100644
--- a/test/packetimpact/tests/tcp_handshake_window_size_test.go
+++ b/test/packetimpact/tests/tcp_handshake_window_size_test.go
@@ -25,17 +25,16 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
// TestTCPHandshakeWindowSize tests if the stack is honoring the window size
// communicated during handshake.
func TestTCPHandshakeWindowSize(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(t, listenFD)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
// Start handshake with zero window size.
diff --git a/test/packetimpact/tests/tcp_linger_test.go b/test/packetimpact/tests/tcp_linger_test.go
index b9a0409aa..bc4b64388 100644
--- a/test/packetimpact/tests/tcp_linger_test.go
+++ b/test/packetimpact/tests/tcp_linger_test.go
@@ -27,12 +27,12 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
func createSocket(t *testing.T, dut testbench.DUT) (int32, int32, testbench.TCPIPv4) {
listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
conn.Connect(t)
acceptFD, _ := dut.Accept(t, listenFD)
return acceptFD, listenFD, conn
@@ -41,7 +41,6 @@ func createSocket(t *testing.T, dut testbench.DUT) (int32, int32, testbench.TCPI
func closeAll(t *testing.T, dut testbench.DUT, listenFD int32, conn testbench.TCPIPv4) {
conn.Close(t)
dut.Close(t, listenFD)
- dut.TearDown()
}
// lingerDuration is the timeout value used with SO_LINGER socket option.
@@ -266,5 +265,4 @@ func TestTCPLingerNonEstablished(t *testing.T) {
if diff > lingerDuration {
t.Errorf("expected close to return within %s, but returned after %s", lingerDuration, diff)
}
- dut.TearDown()
}
diff --git a/test/packetimpact/tests/tcp_network_unreachable_test.go b/test/packetimpact/tests/tcp_network_unreachable_test.go
index 8a1fe1279..6cd6d2edf 100644
--- a/test/packetimpact/tests/tcp_network_unreachable_test.go
+++ b/test/packetimpact/tests/tcp_network_unreachable_test.go
@@ -17,7 +17,6 @@ package tcp_synsent_reset_test
import (
"context"
"flag"
- "net"
"syscall"
"testing"
"time"
@@ -28,7 +27,7 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
// TestTCPSynSentUnreachable verifies that TCP connections fail immediately when
@@ -37,17 +36,16 @@ func init() {
func TestTCPSynSentUnreachable(t *testing.T) {
// Create the DUT and connection.
dut := testbench.NewDUT(t)
- defer dut.TearDown()
- clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4))
+ clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, dut.Net.RemoteIPv4)
port := uint16(9001)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{SrcPort: &port, DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort, DstPort: &port})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{SrcPort: &port, DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort, DstPort: &port})
defer conn.Close(t)
// Bring the DUT to SYN-SENT state with a non-blocking connect.
ctx, cancel := context.WithTimeout(context.Background(), testbench.RPCTimeout)
defer cancel()
sa := unix.SockaddrInet4{Port: int(port)}
- copy(sa.Addr[:], net.IP(net.ParseIP(testbench.LocalIPv4)).To4())
+ copy(sa.Addr[:], dut.Net.LocalIPv4)
if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) {
t.Errorf("expected connect to fail with EINPROGRESS, but got %v", err)
}
@@ -91,9 +89,8 @@ func TestTCPSynSentUnreachable(t *testing.T) {
func TestTCPSynSentUnreachable6(t *testing.T) {
// Create the DUT and connection.
dut := testbench.NewDUT(t)
- defer dut.TearDown()
- clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv6))
- conn := testbench.NewTCPIPv6(t, testbench.TCP{DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort})
+ clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, dut.Net.RemoteIPv6)
+ conn := dut.Net.NewTCPIPv6(t, testbench.TCP{DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort})
defer conn.Close(t)
// Bring the DUT to SYN-SENT state with a non-blocking connect.
@@ -101,9 +98,9 @@ func TestTCPSynSentUnreachable6(t *testing.T) {
defer cancel()
sa := unix.SockaddrInet6{
Port: int(conn.SrcPort()),
- ZoneId: uint32(testbench.RemoteInterfaceID),
+ ZoneId: dut.Net.RemoteDevID,
}
- copy(sa.Addr[:], net.IP(net.ParseIP(testbench.LocalIPv6)).To16())
+ copy(sa.Addr[:], dut.Net.LocalIPv6)
if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) {
t.Errorf("expected connect to fail with EINPROGRESS, but got %v", err)
}
diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
index 82b7a85ff..f0af5352d 100644
--- a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
+++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
@@ -25,14 +25,13 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
func TestTcpNoAcceptCloseReset(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
conn.Connect(t)
defer conn.Close(t)
dut.Close(t, listenFd)
diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go
index 08f759f7c..1b041932a 100644
--- a/test/packetimpact/tests/tcp_outside_the_window_test.go
+++ b/test/packetimpact/tests/tcp_outside_the_window_test.go
@@ -27,7 +27,7 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
// TestTCPOutsideTheWindows tests the behavior of the DUT when packets arrive
@@ -62,10 +62,9 @@ func TestTCPOutsideTheWindow(t *testing.T) {
} {
t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(t, listenFD)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
conn.Connect(t)
acceptFD, _ := dut.Accept(t, listenFD)
diff --git a/test/packetimpact/tests/tcp_paws_mechanism_test.go b/test/packetimpact/tests/tcp_paws_mechanism_test.go
index 37f3b56dd..24d9ef4ec 100644
--- a/test/packetimpact/tests/tcp_paws_mechanism_test.go
+++ b/test/packetimpact/tests/tcp_paws_mechanism_test.go
@@ -26,15 +26,14 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
func TestPAWSMechanism(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(t, listenFD)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
options := make([]byte, header.TCPOptionTSLength)
diff --git a/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go
index d9f3ea0f2..646c93216 100644
--- a/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go
+++ b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go
@@ -20,7 +20,6 @@ import (
"encoding/hex"
"errors"
"flag"
- "net"
"sync"
"syscall"
"testing"
@@ -32,7 +31,7 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
// TestQueueReceiveInSynSent tests receive behavior when the TCP state
@@ -50,10 +49,9 @@ func TestQueueReceiveInSynSent(t *testing.T) {
} {
t.Run(tt.description, func(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
- socket, remotePort := dut.CreateBoundSocket(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4))
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ socket, remotePort := dut.CreateBoundSocket(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, dut.Net.RemoteIPv4)
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
sampleData := []byte("Sample Data")
diff --git a/test/packetimpact/tests/tcp_queue_send_in_syn_sent_test.go b/test/packetimpact/tests/tcp_queue_send_in_syn_sent_test.go
index 0ec8fd748..29e51cae3 100644
--- a/test/packetimpact/tests/tcp_queue_send_in_syn_sent_test.go
+++ b/test/packetimpact/tests/tcp_queue_send_in_syn_sent_test.go
@@ -18,7 +18,6 @@ import (
"context"
"errors"
"flag"
- "net"
"sync"
"syscall"
"testing"
@@ -30,7 +29,7 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
// TestQueueSendInSynSent tests send behavior when the TCP state
@@ -48,10 +47,9 @@ func TestQueueSendInSynSent(t *testing.T) {
} {
t.Run(tt.description, func(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
- socket, remotePort := dut.CreateBoundSocket(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4))
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ socket, remotePort := dut.CreateBoundSocket(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, dut.Net.RemoteIPv4)
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
sampleData := []byte("Sample Data")
diff --git a/test/packetimpact/tests/tcp_rcv_buf_space_test.go b/test/packetimpact/tests/tcp_rcv_buf_space_test.go
index cfbba1e8e..d6ad5cda6 100644
--- a/test/packetimpact/tests/tcp_rcv_buf_space_test.go
+++ b/test/packetimpact/tests/tcp_rcv_buf_space_test.go
@@ -26,7 +26,7 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
// TestReduceRecvBuf tests that a packet within window is still dropped
@@ -34,10 +34,9 @@ func init() {
// segment.
func TestReduceRecvBuf(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(t, listenFd)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
conn.Connect(t)
diff --git a/test/packetimpact/tests/tcp_reordering_test.go b/test/packetimpact/tests/tcp_reordering_test.go
index b4aeaab57..ca352dbc7 100644
--- a/test/packetimpact/tests/tcp_reordering_test.go
+++ b/test/packetimpact/tests/tcp_reordering_test.go
@@ -22,19 +22,18 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
func TestReorderingWindow(t *testing.T) {
- dut := tb.NewDUT(t)
- defer dut.TearDown()
+ dut := testbench.NewDUT(t)
listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(t, listenFd)
- conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
// Enable SACK.
@@ -54,13 +53,13 @@ func TestReorderingWindow(t *testing.T) {
acceptFd, _ := dut.Accept(t, listenFd)
defer dut.Close(t, acceptFd)
- if tb.Native {
+ if testbench.Native {
// Linux has changed its handling of reordering, force the old behavior.
dut.SetSockOpt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_CONGESTION, []byte("reno"))
}
pls := dut.GetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_MAXSEG)
- if !tb.Native {
+ if !testbench.Native {
// netstack does not impliment TCP_MAXSEG correctly. Fake it
// here. Netstack uses the max SACK size which is 32. The MSS
// option is 8 bytes, making the total 36 bytes.
@@ -75,14 +74,14 @@ func TestReorderingWindow(t *testing.T) {
for i, sn := 0, seqNum1; i < numPkts; i++ {
dut.Send(t, acceptFd, payload, 0)
- gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second)
+ gotOne, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(sn))}, time.Second)
sn.UpdateForward(seqnum.Size(len(payload)))
if err != nil {
- t.Errorf("Expect #%d: %s", i+1, err)
+ t.Fatalf("Expect #%d: %s", i+1, err)
continue
}
if gotOne == nil {
- t.Errorf("#%d: expected a packet within a second but got none", i+1)
+ t.Fatalf("#%d: expected a packet within a second but got none", i+1)
}
}
@@ -97,13 +96,13 @@ func TestReorderingWindow(t *testing.T) {
seqNum1.Add(seqnum.Size(len(payload))),
seqNum1.Add(seqnum.Size(4 * len(payload))),
}}, sackBlock[sbOff:])
- conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]})
// ACK first packet.
- conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1) + uint32(len(payload)))})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1) + uint32(len(payload)))})
// Check for retransmit.
- gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(seqNum1))}, time.Second)
+ gotOne, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(seqNum1))}, time.Second)
if err != nil {
t.Error("Expect for retransmit:", err)
}
@@ -123,29 +122,29 @@ func TestReorderingWindow(t *testing.T) {
seqNum1.Add(seqnum.Size(4 * len(payload))),
}}, dsackBlock[dsbOff:])
- conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum2)), Options: dsackBlock[:dsbOff]})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum2)), Options: dsackBlock[:dsbOff]})
// Send half of the original window of packets, checking that we
// received each.
for i, sn := 0, seqNum2; i < numPkts/2; i++ {
dut.Send(t, acceptFd, payload, 0)
- gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second)
+ gotOne, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(sn))}, time.Second)
sn.UpdateForward(seqnum.Size(len(payload)))
if err != nil {
- t.Errorf("Expect #%d: %s", i+1, err)
+ t.Fatalf("Expect #%d: %s", i+1, err)
continue
}
if gotOne == nil {
- t.Errorf("#%d: expected a packet within a second but got none", i+1)
+ t.Fatalf("#%d: expected a packet within a second but got none", i+1)
}
}
- if !tb.Native {
+ if !testbench.Native {
// The window should now be halved, so we should receive any
// more, even if we send them.
dut.Send(t, acceptFd, payload, 0)
- if got, err := conn.Expect(t, tb.TCP{}, 100*time.Millisecond); got != nil || err == nil {
+ if got, err := conn.Expect(t, testbench.TCP{}, 100*time.Millisecond); got != nil || err == nil {
t.Fatalf("expected no packets within 100 millisecond, but got one: %s", got)
}
return
@@ -155,20 +154,20 @@ func TestReorderingWindow(t *testing.T) {
for i, sn := 0, seqNum2.Add(seqnum.Size(numPkts/2*len(payload))); i < 2; i++ {
dut.Send(t, acceptFd, payload, 0)
- gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second)
+ gotOne, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(sn))}, time.Second)
sn.UpdateForward(seqnum.Size(len(payload)))
if err != nil {
- t.Errorf("Expect #%d: %s", i+1, err)
+ t.Fatalf("Expect #%d: %s", i+1, err)
continue
}
if gotOne == nil {
- t.Errorf("#%d: expected a packet within a second but got none", i+1)
+ t.Fatalf("#%d: expected a packet within a second but got none", i+1)
}
}
// The window should now be full.
dut.Send(t, acceptFd, payload, 0)
- if got, err := conn.Expect(t, tb.TCP{}, 100*time.Millisecond); got != nil || err == nil {
+ if got, err := conn.Expect(t, testbench.TCP{}, 100*time.Millisecond); got != nil || err == nil {
t.Fatalf("expected no packets within 100 millisecond, but got one: %s", got)
}
}
diff --git a/test/packetimpact/tests/tcp_retransmits_test.go b/test/packetimpact/tests/tcp_retransmits_test.go
index 072014ff8..27e9641b1 100644
--- a/test/packetimpact/tests/tcp_retransmits_test.go
+++ b/test/packetimpact/tests/tcp_retransmits_test.go
@@ -25,17 +25,16 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
// TestRetransmits tests retransmits occur at exponentially increasing
// time intervals.
func TestRetransmits(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(t, listenFd)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
conn.Connect(t)
diff --git a/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go
index f91b06ba1..418393796 100644
--- a/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go
+++ b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go
@@ -16,7 +16,6 @@ package tcp_send_window_sizes_piggyback_test
import (
"flag"
- "fmt"
"testing"
"time"
@@ -26,7 +25,7 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
// TestSendWindowSizesPiggyback tests cases where segment sizes are close to
@@ -58,13 +57,12 @@ func TestSendWindowSizesPiggyback(t *testing.T) {
// greater than available sender window.
{"WindowGreaterThanSegment", segmentSize + 1, sampleData, sampleData, true /* enqueue */},
} {
- t.Run(fmt.Sprintf("%s%d", tt.description, tt.windowSize), func(t *testing.T) {
+ t.Run(tt.description, func(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(t, listenFd)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort, WindowSize: testbench.Uint16(tt.windowSize)}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort, WindowSize: testbench.Uint16(tt.windowSize)}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
conn.Connect(t)
diff --git a/test/packetimpact/tests/tcp_synrcvd_reset_test.go b/test/packetimpact/tests/tcp_synrcvd_reset_test.go
index 57d034dd1..c5bbd29ee 100644
--- a/test/packetimpact/tests/tcp_synrcvd_reset_test.go
+++ b/test/packetimpact/tests/tcp_synrcvd_reset_test.go
@@ -25,16 +25,15 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
// TestTCPSynRcvdReset tests transition from SYN-RCVD to CLOSED.
func TestTCPSynRcvdReset(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(t, listenFD)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
// Expect dut connection to have transitioned to SYN-RCVD state.
diff --git a/test/packetimpact/tests/tcp_synsent_reset_test.go b/test/packetimpact/tests/tcp_synsent_reset_test.go
index eac8eb19d..2c8bb101b 100644
--- a/test/packetimpact/tests/tcp_synsent_reset_test.go
+++ b/test/packetimpact/tests/tcp_synsent_reset_test.go
@@ -16,34 +16,33 @@ package tcp_synsent_reset_test
import (
"flag"
- "net"
"testing"
"time"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/header"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
// dutSynSentState sets up the dut connection in SYN-SENT state.
-func dutSynSentState(t *testing.T) (*tb.DUT, *tb.TCPIPv4, uint16, uint16) {
+func dutSynSentState(t *testing.T) (*testbench.DUT, *testbench.TCPIPv4, uint16, uint16) {
t.Helper()
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
- clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(tb.RemoteIPv4))
+ clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, dut.Net.RemoteIPv4)
port := uint16(9001)
- conn := tb.NewTCPIPv4(t, tb.TCP{SrcPort: &port, DstPort: &clientPort}, tb.TCP{SrcPort: &clientPort, DstPort: &port})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{SrcPort: &port, DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort, DstPort: &port})
sa := unix.SockaddrInet4{Port: int(port)}
- copy(sa.Addr[:], net.IP(net.ParseIP(tb.LocalIPv4)).To4())
+ copy(sa.Addr[:], dut.Net.LocalIPv4)
// Bring the dut to SYN-SENT state with a non-blocking connect.
dut.Connect(t, clientFD, &sa)
- if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}, nil, time.Second); err != nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second); err != nil {
t.Fatalf("expected SYN\n")
}
@@ -52,14 +51,13 @@ func dutSynSentState(t *testing.T) (*tb.DUT, *tb.TCPIPv4, uint16, uint16) {
// TestTCPSynSentReset tests RFC793, p67: SYN-SENT to CLOSED transition.
func TestTCPSynSentReset(t *testing.T) {
- dut, conn, _, _ := dutSynSentState(t)
+ _, conn, _, _ := dutSynSentState(t)
defer conn.Close(t)
- defer dut.TearDown()
- conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)})
// Expect the connection to have closed.
// TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side.
- conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
- if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
t.Fatalf("expected a TCP RST")
}
}
@@ -68,23 +66,22 @@ func TestTCPSynSentReset(t *testing.T) {
// transitions.
func TestTCPSynSentRcvdReset(t *testing.T) {
dut, c, remotePort, clientPort := dutSynSentState(t)
- defer dut.TearDown()
defer c.Close(t)
- conn := tb.NewTCPIPv4(t, tb.TCP{SrcPort: &remotePort, DstPort: &clientPort}, tb.TCP{SrcPort: &clientPort, DstPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{SrcPort: &remotePort, DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort, DstPort: &remotePort})
defer conn.Close(t)
// Initiate new SYN connection with the same port pair
// (simultaneous open case), expect the dut connection to move to
// SYN-RCVD state
- conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)})
- if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
t.Fatalf("expected SYN-ACK %s\n", err)
}
- conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)})
// Expect the connection to have transitioned SYN-RCVD to CLOSED.
// TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side.
- conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
- if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
t.Fatalf("expected a TCP RST")
}
}
diff --git a/test/packetimpact/tests/tcp_timewait_reset_test.go b/test/packetimpact/tests/tcp_timewait_reset_test.go
index 2f76a6531..d1d2fb83d 100644
--- a/test/packetimpact/tests/tcp_timewait_reset_test.go
+++ b/test/packetimpact/tests/tcp_timewait_reset_test.go
@@ -25,16 +25,15 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
// TestTimeWaitReset tests handling of RST when in TIME_WAIT state.
func TestTimeWaitReset(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/)
defer dut.Close(t, listenFD)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
conn.Connect(t)
diff --git a/test/packetimpact/tests/tcp_unacc_seq_ack_test.go b/test/packetimpact/tests/tcp_unacc_seq_ack_test.go
index d078bbf15..ea962c818 100644
--- a/test/packetimpact/tests/tcp_unacc_seq_ack_test.go
+++ b/test/packetimpact/tests/tcp_unacc_seq_ack_test.go
@@ -28,7 +28,7 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
func TestEstablishedUnaccSeqAck(t *testing.T) {
@@ -48,10 +48,9 @@ func TestEstablishedUnaccSeqAck(t *testing.T) {
} {
t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/)
defer dut.Close(t, listenFD)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
conn.Connect(t)
@@ -102,10 +101,9 @@ func TestPassiveCloseUnaccSeqAck(t *testing.T) {
} {
t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/)
defer dut.Close(t, listenFD)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
conn.Connect(t)
@@ -164,10 +162,9 @@ func TestActiveCloseUnaccpSeqAck(t *testing.T) {
} {
t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/)
defer dut.Close(t, listenFD)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
conn.Connect(t)
diff --git a/test/packetimpact/tests/tcp_user_timeout_test.go b/test/packetimpact/tests/tcp_user_timeout_test.go
index 551dc78e7..b16e65366 100644
--- a/test/packetimpact/tests/tcp_user_timeout_test.go
+++ b/test/packetimpact/tests/tcp_user_timeout_test.go
@@ -25,7 +25,7 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
func sendPayload(t *testing.T, conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) {
@@ -64,10 +64,9 @@ func TestTCPUserTimeout(t *testing.T) {
t.Run(tt.description+ttf.description, func(t *testing.T) {
// Create a socket, listen, TCP handshake, and accept.
dut := testbench.NewDUT(t)
- defer dut.TearDown()
listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(t, listenFD)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
conn.Connect(t)
acceptFD, _ := dut.Accept(t, listenFD)
diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go
index 5b001fbec..093484721 100644
--- a/test/packetimpact/tests/tcp_window_shrink_test.go
+++ b/test/packetimpact/tests/tcp_window_shrink_test.go
@@ -25,15 +25,14 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
func TestWindowShrink(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(t, listenFd)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
conn.Connect(t)
diff --git a/test/packetimpact/tests/tcp_zero_receive_window_test.go b/test/packetimpact/tests/tcp_zero_receive_window_test.go
new file mode 100644
index 000000000..d06690705
--- /dev/null
+++ b/test/packetimpact/tests/tcp_zero_receive_window_test.go
@@ -0,0 +1,125 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_zero_receive_window_test
+
+import (
+ "flag"
+ "fmt"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.Initialize(flag.CommandLine)
+}
+
+// TestZeroReceiveWindow tests if the DUT sends a zero receive window eventually.
+func TestZeroReceiveWindow(t *testing.T) {
+ for _, payloadLen := range []int{64, 512, 1024} {
+ t.Run(fmt.Sprintf("TestZeroReceiveWindow_with_%dbytes_payload", payloadLen), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ conn.Connect(t)
+ acceptFd, _ := dut.Accept(t, listenFd)
+ defer dut.Close(t, acceptFd)
+
+ dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ samplePayload := &testbench.Payload{Bytes: testbench.GenerateRandomPayload(t, payloadLen)}
+ // Expect the DUT to eventually advertise zero receive window.
+ // The test would timeout otherwise.
+ for readOnce := false; ; {
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("expected packet was not received: %s", err)
+ }
+ // Read once to trigger the subsequent window update from the
+ // DUT to grow the right edge of the receive window from what
+ // was advertised in the SYN-ACK. This ensures that we test
+ // for the full default buffer size (1MB on gVisor at the time
+ // of writing this comment), thus testing for cases when the
+ // scaled receive window size ends up > 65535 (0xffff).
+ if !readOnce {
+ if got := dut.Recv(t, acceptFd, int32(payloadLen), 0); len(got) != payloadLen {
+ t.Fatalf("got dut.Recv(t, %d, %d, 0) = %d, want %d", acceptFd, payloadLen, len(got), payloadLen)
+ }
+ readOnce = true
+ }
+ windowSize := *gotTCP.WindowSize
+ t.Logf("got window size = %d", windowSize)
+ if windowSize == 0 {
+ break
+ }
+ }
+ })
+ }
+}
+
+// TestNonZeroReceiveWindow tests for the DUT to never send a zero receive
+// window when the data is being read from the socket buffer.
+func TestNonZeroReceiveWindow(t *testing.T) {
+ for _, payloadLen := range []int{64, 512, 1024} {
+ t.Run(fmt.Sprintf("TestZeroReceiveWindow_with_%dbytes_payload", payloadLen), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ conn.Connect(t)
+ acceptFd, _ := dut.Accept(t, listenFd)
+ defer dut.Close(t, acceptFd)
+
+ dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ samplePayload := &testbench.Payload{Bytes: testbench.GenerateRandomPayload(t, payloadLen)}
+ var rcvWindow uint16
+ initRcv := false
+ // This loop keeps a running rcvWindow value from the initial ACK for the data
+ // we sent. Once we have received ACKs with non-zero receive windows, we break
+ // the loop.
+ for {
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("expected packet was not received: %s", err)
+ }
+ if got := dut.Recv(t, acceptFd, int32(payloadLen), 0); len(got) != payloadLen {
+ t.Fatalf("got dut.Recv(t, %d, %d, 0) = %d, want %d", acceptFd, payloadLen, len(got), payloadLen)
+ }
+ if *gotTCP.WindowSize == 0 {
+ t.Fatalf("expected non-zero receive window.")
+ }
+ if !initRcv {
+ rcvWindow = uint16(*gotTCP.WindowSize)
+ initRcv = true
+ }
+ if rcvWindow <= uint16(payloadLen) {
+ break
+ }
+ rcvWindow -= uint16(payloadLen)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
index da93267d6..1ab9ee1b2 100644
--- a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
+++ b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
@@ -25,17 +25,16 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
// TestZeroWindowProbeRetransmit tests retransmits of zero window probes
// to be sent at exponentially inreasing time intervals.
func TestZeroWindowProbeRetransmit(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(t, listenFd)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
conn.Connect(t)
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_test.go b/test/packetimpact/tests/tcp_zero_window_probe_test.go
index 44cac42f8..650a569cc 100644
--- a/test/packetimpact/tests/tcp_zero_window_probe_test.go
+++ b/test/packetimpact/tests/tcp_zero_window_probe_test.go
@@ -25,17 +25,16 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
// TestZeroWindowProbe tests few cases of zero window probing over the
// same connection.
func TestZeroWindowProbe(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(t, listenFd)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
conn.Connect(t)
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go
index 09a1c653f..079fea68c 100644
--- a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go
+++ b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go
@@ -25,17 +25,16 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
// TestZeroWindowProbeUserTimeout sanity tests user timeout when we are
// retransmitting zero window probes.
func TestZeroWindowProbeUserTimeout(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(t, listenFd)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
defer conn.Close(t)
conn.Connect(t)
diff --git a/test/packetimpact/tests/udp_any_addr_recv_unicast_test.go b/test/packetimpact/tests/udp_any_addr_recv_unicast_test.go
index 17f32ef65..f4ae00a81 100644
--- a/test/packetimpact/tests/udp_any_addr_recv_unicast_test.go
+++ b/test/packetimpact/tests/udp_any_addr_recv_unicast_test.go
@@ -26,21 +26,20 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
func TestAnyRecvUnicastUDP(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero)
defer dut.Close(t, boundFD)
- conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ conn := dut.Net.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
defer conn.Close(t)
payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */)
conn.SendIP(
t,
- testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(net.ParseIP(testbench.RemoteIPv4).To4()))},
+ testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(dut.Net.RemoteIPv4))},
testbench.UDP{},
&testbench.Payload{Bytes: payload},
)
diff --git a/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go
index 3d2791a6e..52c6f9d91 100644
--- a/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go
+++ b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go
@@ -30,16 +30,15 @@ import (
var oneSecond = unix.Timeval{Sec: 1, Usec: 0}
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
func TestDiscardsUDPPacketsWithMcastSourceAddressV4(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
- remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv4))
+ remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, dut.Net.RemoteIPv4)
defer dut.Close(t, remoteFD)
dut.SetSockOptTimeval(t, remoteFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &oneSecond)
- conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ conn := dut.Net.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
defer conn.Close(t)
for _, mcastAddr := range []net.IP{
@@ -66,11 +65,10 @@ func TestDiscardsUDPPacketsWithMcastSourceAddressV4(t *testing.T) {
func TestDiscardsUDPPacketsWithMcastSourceAddressV6(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
- remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv6))
+ remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, dut.Net.RemoteIPv6)
defer dut.Close(t, remoteFD)
dut.SetSockOptTimeval(t, remoteFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &oneSecond)
- conn := testbench.NewUDPIPv6(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ conn := dut.Net.NewUDPIPv6(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
defer conn.Close(t)
for _, mcastAddr := range []net.IP{
diff --git a/test/packetimpact/tests/udp_icmp_error_propagation_test.go b/test/packetimpact/tests/udp_icmp_error_propagation_test.go
index df35d16c8..cd4523e88 100644
--- a/test/packetimpact/tests/udp_icmp_error_propagation_test.go
+++ b/test/packetimpact/tests/udp_icmp_error_propagation_test.go
@@ -30,7 +30,7 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
type connectionMode bool
@@ -229,7 +229,6 @@ func TestUDPICMPErrorPropagation(t *testing.T) {
} {
t.Run(fmt.Sprintf("%s/%s/%s", connect, icmpErr, errDetect.name), func(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero)
defer dut.Close(t, remoteFD)
@@ -239,7 +238,7 @@ func TestUDPICMPErrorPropagation(t *testing.T) {
cleanFD, cleanPort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero)
defer dut.Close(t, cleanFD)
- conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ conn := dut.Net.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
defer conn.Close(t)
if connect {
@@ -261,7 +260,7 @@ func TestUDPICMPErrorPropagation(t *testing.T) {
// involved in the generation of the ICMP error. As such,
// interactions between it and the the DUT should be independent of
// the ICMP error at least at the port level.
- connClean := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ connClean := dut.Net.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
defer connClean.Close(t)
errDetectConn = &connClean
@@ -283,7 +282,6 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) {
t.Run(fmt.Sprintf("%s/%s", connect, icmpErr), func(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero)
defer dut.Close(t, remoteFD)
@@ -293,7 +291,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) {
cleanFD, cleanPort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero)
defer dut.Close(t, cleanFD)
- conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ conn := dut.Net.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
defer conn.Close(t)
if connect {
diff --git a/test/packetimpact/tests/udp_recv_mcast_bcast_test.go b/test/packetimpact/tests/udp_recv_mcast_bcast_test.go
index 526173969..b29c07825 100644
--- a/test/packetimpact/tests/udp_recv_mcast_bcast_test.go
+++ b/test/packetimpact/tests/udp_recv_mcast_bcast_test.go
@@ -29,12 +29,12 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
func TestUDPRecvMcastBcast(t *testing.T) {
- subnetBcastAddr := broadcastAddr(net.ParseIP(testbench.RemoteIPv4), net.CIDRMask(testbench.IPv4PrefixLength, 32))
-
+ dut := testbench.NewDUT(t)
+ subnetBcastAddr := broadcastAddr(dut.Net.RemoteIPv4, net.CIDRMask(dut.Net.IPv4PrefixLength, 32))
for _, v := range []struct {
bound, to net.IP
}{
@@ -43,17 +43,22 @@ func TestUDPRecvMcastBcast(t *testing.T) {
{bound: net.IPv4zero, to: net.IPv4allsys},
{bound: subnetBcastAddr, to: subnetBcastAddr},
- {bound: subnetBcastAddr, to: net.IPv4bcast},
+
+ // FIXME(gvisor.dev/issue/4896): Previously by the time subnetBcastAddr is
+ // created, IPv4PrefixLength is still 0 because genPseudoFlags is not called
+ // yet, it was only called in NewDUT, so the test didn't do what the author
+ // original intended to and becomes failing because we process all flags at
+ // the very beginning.
+ //
+ // {bound: subnetBcastAddr, to: net.IPv4bcast},
{bound: net.IPv4bcast, to: net.IPv4bcast},
{bound: net.IPv4allsys, to: net.IPv4allsys},
} {
t.Run(fmt.Sprintf("bound=%s,to=%s", v.bound, v.to), func(t *testing.T) {
- dut := testbench.NewDUT(t)
- defer dut.TearDown()
boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, v.bound)
defer dut.Close(t, boundFD)
- conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ conn := dut.Net.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
defer conn.Close(t)
payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */)
@@ -73,15 +78,14 @@ func TestUDPRecvMcastBcast(t *testing.T) {
func TestUDPDoesntRecvMcastBcastOnUnicastAddr(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
- boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv4))
+ boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, dut.Net.RemoteIPv4)
dut.SetSockOptTimeval(t, boundFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &unix.Timeval{Sec: 1, Usec: 0})
defer dut.Close(t, boundFD)
- conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ conn := dut.Net.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
defer conn.Close(t)
for _, to := range []net.IP{
- broadcastAddr(net.ParseIP(testbench.RemoteIPv4), net.CIDRMask(testbench.IPv4PrefixLength, 32)),
+ broadcastAddr(dut.Net.RemoteIPv4, net.CIDRMask(dut.Net.IPv4PrefixLength, 32)),
net.IPv4(255, 255, 255, 255),
net.IPv4(224, 0, 0, 1),
} {
@@ -102,9 +106,10 @@ func TestUDPDoesntRecvMcastBcastOnUnicastAddr(t *testing.T) {
}
func broadcastAddr(ip net.IP, mask net.IPMask) net.IP {
+ result := make(net.IP, net.IPv4len)
ip4 := ip.To4()
for i := range ip4 {
- ip4[i] |= ^mask[i]
+ result[i] = ip4[i] | ^mask[i]
}
- return ip4
+ return result
}
diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go
index 91b967400..7ee2c8014 100644
--- a/test/packetimpact/tests/udp_send_recv_dgram_test.go
+++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go
@@ -26,7 +26,7 @@ import (
)
func init() {
- testbench.RegisterFlags(flag.CommandLine)
+ testbench.Initialize(flag.CommandLine)
}
type udpConn interface {
@@ -38,7 +38,6 @@ type udpConn interface {
func TestUDP(t *testing.T) {
dut := testbench.NewDUT(t)
- defer dut.TearDown()
for _, isIPv4 := range []bool{true, false} {
ipVersionName := "IPv6"
@@ -46,24 +45,24 @@ func TestUDP(t *testing.T) {
ipVersionName = "IPv4"
}
t.Run(ipVersionName, func(t *testing.T) {
- var addr string
+ var addr net.IP
if isIPv4 {
- addr = testbench.RemoteIPv4
+ addr = dut.Net.RemoteIPv4
} else {
- addr = testbench.RemoteIPv6
+ addr = dut.Net.RemoteIPv6
}
- boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(addr))
+ boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, addr)
defer dut.Close(t, boundFD)
var conn udpConn
var localAddr unix.Sockaddr
if isIPv4 {
- v4Conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ v4Conn := dut.Net.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
localAddr = v4Conn.LocalAddr(t)
conn = &v4Conn
} else {
- v6Conn := testbench.NewUDPIPv6(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
- localAddr = v6Conn.LocalAddr(t)
+ v6Conn := dut.Net.NewUDPIPv6(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ localAddr = v6Conn.LocalAddr(t, dut.Net.RemoteDevID)
conn = &v6Conn
}
defer conn.Close(t)
diff --git a/test/perf/BUILD b/test/perf/BUILD
index b763be50e..e25f090ae 100644
--- a/test/perf/BUILD
+++ b/test/perf/BUILD
@@ -1,3 +1,4 @@
+load("//tools:defs.bzl", "more_shards")
load("//test/runner:defs.bzl", "syscall_test")
package(licenses = ["notice"])
@@ -37,7 +38,7 @@ syscall_test(
syscall_test(
size = "enormous",
debug = False,
- shard_count = 10,
+ shard_count = more_shards,
tags = ["nogotsan"],
test = "//test/perf/linux:getdents_benchmark",
)
diff --git a/test/root/BUILD b/test/root/BUILD
index a9130b34f..8d9fff578 100644
--- a/test/root/BUILD
+++ b/test/root/BUILD
@@ -1,5 +1,4 @@
load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/vm:defs.bzl", "vm_test")
package(licenses = ["notice"])
@@ -24,12 +23,8 @@ go_test(
],
library = ":root",
tags = [
- # Requires docker and runsc to be configured before the test runs.
- # Also, the test needs to be run as root. Note that below, the
- # root_vm_test relies on the default runtime 'runsc' being installed by
- # the default installer.
- "manual",
"local",
+ "manual",
],
visibility = ["//:sandbox"],
deps = [
@@ -46,10 +41,3 @@ go_test(
"@org_golang_x_sys//unix:go_default_library",
],
)
-
-vm_test(
- name = "root_vm_test",
- size = "large",
- shard_count = 1,
- targets = [":root_test"],
-)
diff --git a/test/root/crictl_test.go b/test/root/crictl_test.go
index 11ac5cb52..df52dd381 100644
--- a/test/root/crictl_test.go
+++ b/test/root/crictl_test.go
@@ -315,7 +315,7 @@ const (
// v1 is the containerd API v1.
v1 string = "v1"
- // v1 is the containerd API v21.
+ // v2 is the containerd API v2.
v2 string = "v2"
)
@@ -480,7 +480,7 @@ func setup(t *testing.T, version string) (*criutil.Crictl, func(), error) {
}
// Wait for containerd to boot.
- if err := testutil.WaitUntilRead(startupR, "Start streaming server", nil, 10*time.Second); err != nil {
+ if err := testutil.WaitUntilRead(startupR, "Start streaming server", 10*time.Second); err != nil {
t.Fatalf("failed to start containerd: %v", err)
}
diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl
index 7618f6a21..829247657 100644
--- a/test/runner/defs.bzl
+++ b/test/runner/defs.bzl
@@ -12,7 +12,7 @@ def _runner_test_impl(ctx):
" mkdir -p \"${TEST_UNDECLARED_OUTPUTS_DIR}\"",
" chmod a+rwx \"${TEST_UNDECLARED_OUTPUTS_DIR}\"",
"fi",
- "exec %s %s %s\n" % (
+ "exec %s %s \"$@\" %s\n" % (
ctx.files.runner[0].short_path,
" ".join(ctx.attr.runner_args),
ctx.files.test[0].short_path,
@@ -52,8 +52,6 @@ _runner_test = rule(
def _syscall_test(
test,
- shard_count,
- size,
platform,
use_tmpfs,
tags,
@@ -63,7 +61,8 @@ def _syscall_test(
overlay = False,
add_uds_tree = False,
vfs2 = False,
- fuse = False):
+ fuse = False,
+ **kwargs):
# Prepend "runsc" to non-native platform names.
full_platform = platform if platform == "native" else "runsc_" + platform
@@ -126,15 +125,12 @@ def _syscall_test(
name = name,
test = test,
runner_args = runner_args,
- size = size,
tags = tags,
- shard_count = shard_count,
+ **kwargs
)
def syscall_test(
test,
- shard_count = 5,
- size = "small",
use_tmpfs = False,
add_overlay = False,
add_uds_tree = False,
@@ -142,18 +138,21 @@ def syscall_test(
vfs2 = True,
fuse = False,
debug = True,
- tags = None):
+ tags = None,
+ **kwargs):
"""syscall_test is a macro that will create targets for all platforms.
Args:
test: the test target.
- shard_count: shards for defined tests.
- size: the defined test size.
use_tmpfs: use tmpfs in the defined tests.
add_overlay: add an overlay test.
add_uds_tree: add a UDS test.
add_hostinet: add a hostinet test.
+ vfs2: enable VFS2 support.
+ fuse: enable FUSE support.
+ debug: enable debug output.
tags: starting test tags.
+ **kwargs: additional test arguments.
"""
if not tags:
tags = []
@@ -173,8 +172,6 @@ def syscall_test(
_syscall_test(
test = test,
- shard_count = shard_count,
- size = size,
platform = default_platform,
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
@@ -182,6 +179,7 @@ def syscall_test(
debug = debug,
vfs2 = True,
fuse = fuse,
+ **kwargs
)
if fuse:
# Only generate *_vfs2_fuse target if fuse parameter is enabled.
@@ -189,38 +187,35 @@ def syscall_test(
_syscall_test(
test = test,
- shard_count = shard_count,
- size = size,
platform = "native",
use_tmpfs = False,
add_uds_tree = add_uds_tree,
tags = list(tags),
debug = debug,
+ **kwargs
)
for (platform, platform_tags) in platforms.items():
_syscall_test(
test = test,
- shard_count = shard_count,
- size = size,
platform = platform,
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
tags = platform_tags + tags,
debug = debug,
+ **kwargs
)
if add_overlay:
_syscall_test(
test = test,
- shard_count = shard_count,
- size = size,
platform = default_platform,
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
tags = platforms[default_platform] + tags,
debug = debug,
overlay = True,
+ **kwargs
)
# TODO(gvisor.dev/issue/4407): Remove tags to enable VFS2 overlay tests.
@@ -230,8 +225,6 @@ def syscall_test(
overlay_vfs2_tags.append("notap")
_syscall_test(
test = test,
- shard_count = shard_count,
- size = size,
platform = default_platform,
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
@@ -239,38 +232,35 @@ def syscall_test(
debug = debug,
overlay = True,
vfs2 = True,
+ **kwargs
)
if add_hostinet:
_syscall_test(
test = test,
- shard_count = shard_count,
- size = size,
platform = default_platform,
use_tmpfs = use_tmpfs,
network = "host",
add_uds_tree = add_uds_tree,
tags = platforms[default_platform] + tags,
debug = debug,
+ **kwargs
)
if not use_tmpfs:
# Also test shared gofer access.
_syscall_test(
test = test,
- shard_count = shard_count,
- size = size,
platform = default_platform,
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
tags = platforms[default_platform] + tags,
debug = debug,
file_access = "shared",
+ **kwargs
)
_syscall_test(
test = test,
- shard_count = shard_count,
- size = size,
platform = default_platform,
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
@@ -278,4 +268,5 @@ def syscall_test(
debug = debug,
file_access = "shared",
vfs2 = True,
+ **kwargs
)
diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD
index 22b526f59..510ffe013 100644
--- a/test/runtimes/BUILD
+++ b/test/runtimes/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "bzl_library")
+load("//tools:defs.bzl", "bzl_library", "more_shards", "most_shards")
load("//test/runtimes:defs.bzl", "runtime_test")
package(licenses = ["notice"])
@@ -7,7 +7,7 @@ runtime_test(
name = "go1.12",
exclude_file = "exclude/go1.12.csv",
lang = "go",
- shard_count = 8,
+ shard_count = more_shards,
)
runtime_test(
@@ -15,28 +15,28 @@ runtime_test(
batch = 100,
exclude_file = "exclude/java11.csv",
lang = "java",
- shard_count = 16,
+ shard_count = most_shards,
)
runtime_test(
name = "nodejs12.4.0",
exclude_file = "exclude/nodejs12.4.0.csv",
lang = "nodejs",
- shard_count = 8,
+ shard_count = most_shards,
)
runtime_test(
name = "php7.3.6",
exclude_file = "exclude/php7.3.6.csv",
lang = "php",
- shard_count = 8,
+ shard_count = more_shards,
)
runtime_test(
name = "python3.7.3",
exclude_file = "exclude/python3.7.3.csv",
lang = "python",
- shard_count = 8,
+ shard_count = more_shards,
)
bzl_library(
diff --git a/test/runtimes/runner/lib/lib.go b/test/runtimes/runner/lib/lib.go
index 64e6e14db..9272137ff 100644
--- a/test/runtimes/runner/lib/lib.go
+++ b/test/runtimes/runner/lib/lib.go
@@ -34,12 +34,7 @@ import (
// RunTests is a helper that is called by main. It exists so that we can run
// defered functions before exiting. It returns an exit code that should be
// passed to os.Exit.
-func RunTests(lang, image, excludeFile string, partitionNum, totalPartitions, batchSize int, timeout time.Duration) int {
- if partitionNum <= 0 || totalPartitions <= 0 || partitionNum > totalPartitions {
- fmt.Fprintf(os.Stderr, "invalid partition %d of %d", partitionNum, totalPartitions)
- return 1
- }
-
+func RunTests(lang, image, excludeFile string, batchSize int, timeout time.Duration) int {
// TODO(gvisor.dev/issue/1624): Remove those tests from all exclude lists
// that only fail with VFS1.
@@ -63,7 +58,7 @@ func RunTests(lang, image, excludeFile string, partitionNum, totalPartitions, ba
// Get a slice of tests to run. This will also start a single Docker
// container that will be used to run each test. The final test will
// stop the Docker container.
- tests, err := getTests(ctx, d, lang, image, partitionNum, totalPartitions, batchSize, timeout, excludes)
+ tests, err := getTests(ctx, d, lang, image, batchSize, timeout, excludes)
if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err.Error())
return 1
@@ -74,7 +69,7 @@ func RunTests(lang, image, excludeFile string, partitionNum, totalPartitions, ba
}
// getTests executes all tests as table tests.
-func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, partitionNum, totalPartitions, batchSize int, timeout time.Duration, excludes map[string]struct{}) ([]testing.InternalTest, error) {
+func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, batchSize int, timeout time.Duration, excludes map[string]struct{}) ([]testing.InternalTest, error) {
// Start the container.
opts := dockerutil.RunOpts{
Image: fmt.Sprintf("runtimes/%s", image),
@@ -90,18 +85,9 @@ func getTests(ctx context.Context, d *dockerutil.Container, lang, image string,
return nil, fmt.Errorf("docker exec failed: %v", err)
}
- // Calculate a subset of tests to run corresponding to the current
- // shard.
+ // Calculate a subset of tests.
tests := strings.Fields(list)
sort.Strings(tests)
-
- partitionSize := len(tests) / totalPartitions
- if partitionNum == totalPartitions {
- tests = tests[(partitionNum-1)*partitionSize:]
- } else {
- tests = tests[(partitionNum-1)*partitionSize : partitionNum*partitionSize]
- }
-
indices, err := testutil.TestIndicesForShard(len(tests))
if err != nil {
return nil, fmt.Errorf("TestsForShard() failed: %v", err)
@@ -122,6 +108,10 @@ func getTests(ctx context.Context, d *dockerutil.Container, lang, image string,
}
tcs = append(tcs, tests[tc])
}
+ if len(tcs) == 0 {
+ // No tests to add to this batch.
+ continue
+ }
itests = append(itests, testing.InternalTest{
Name: strings.Join(tcs, ", "),
F: func(t *testing.T) {
diff --git a/test/runtimes/runner/main.go b/test/runtimes/runner/main.go
index 5b3443e36..ec79a22c2 100644
--- a/test/runtimes/runner/main.go
+++ b/test/runtimes/runner/main.go
@@ -25,13 +25,11 @@ import (
)
var (
- lang = flag.String("lang", "", "language runtime to test")
- image = flag.String("image", "", "docker image with runtime tests")
- excludeFile = flag.String("exclude_file", "", "file containing list of tests to exclude, in CSV format with fields: test name, bug id, comment")
- partition = flag.Int("partition", 1, "partition number, this is 1-indexed")
- totalPartitions = flag.Int("total_partitions", 1, "total number of partitions")
- batchSize = flag.Int("batch", 50, "number of test cases run in one command")
- timeout = flag.Duration("timeout", 90*time.Minute, "batch timeout")
+ lang = flag.String("lang", "", "language runtime to test")
+ image = flag.String("image", "", "docker image with runtime tests")
+ excludeFile = flag.String("exclude_file", "", "file containing list of tests to exclude, in CSV format with fields: test name, bug id, comment")
+ batchSize = flag.Int("batch", 50, "number of test cases run in one command")
+ timeout = flag.Duration("timeout", 90*time.Minute, "batch timeout")
)
func main() {
@@ -40,5 +38,5 @@ func main() {
fmt.Fprintf(os.Stderr, "lang and image flags must not be empty\n")
os.Exit(1)
}
- os.Exit(lib.RunTests(*lang, *image, *excludeFile, *partition, *totalPartitions, *batchSize, *timeout))
+ os.Exit(lib.RunTests(*lang, *image, *excludeFile, *batchSize, *timeout))
}
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index b5a4ef4df..a5b9233f7 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -1,3 +1,4 @@
+load("//tools:defs.bzl", "more_shards", "most_shards")
load("//test/runner:defs.bzl", "syscall_test")
package(licenses = ["notice"])
@@ -12,7 +13,7 @@ syscall_test(
syscall_test(
size = "large",
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:accept_bind_test",
)
@@ -32,7 +33,7 @@ syscall_test(
syscall_test(
size = "medium",
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:alarm_test",
)
@@ -66,7 +67,7 @@ syscall_test(
size = "large",
# Produce too many logs in the debug mode.
debug = False,
- shard_count = 50,
+ shard_count = most_shards,
# Takes too long for TSAN. Since this is kind of a stress test that doesn't
# involve much concurrency, TSAN's usefulness here is limited anyway.
tags = ["nogotsan"],
@@ -211,7 +212,7 @@ syscall_test(
syscall_test(
size = "medium",
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:futex_test",
)
@@ -258,7 +259,7 @@ syscall_test(
syscall_test(
size = "large",
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:itimer_test",
)
@@ -313,7 +314,7 @@ syscall_test(
syscall_test(
size = "medium",
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:mmap_test",
)
@@ -347,6 +348,7 @@ syscall_test(
syscall_test(
add_overlay = True,
+ shard_count = more_shards,
test = "//test/syscalls/linux:open_test",
)
@@ -376,7 +378,7 @@ syscall_test(
syscall_test(
size = "large",
add_overlay = True,
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:pipe_test",
)
@@ -448,7 +450,7 @@ syscall_test(
syscall_test(
size = "medium",
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:pty_test",
)
@@ -475,6 +477,7 @@ syscall_test(
)
syscall_test(
+ shard_count = more_shards,
test = "//test/syscalls/linux:raw_socket_test",
)
@@ -490,7 +493,7 @@ syscall_test(
syscall_test(
size = "medium",
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:readv_socket_test",
)
@@ -539,7 +542,7 @@ syscall_test(
)
syscall_test(
- shard_count = 20,
+ shard_count = more_shards,
test = "//test/syscalls/linux:semaphore_test",
)
@@ -594,7 +597,7 @@ syscall_test(
syscall_test(
size = "large",
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_abstract_test",
)
@@ -605,7 +608,7 @@ syscall_test(
syscall_test(
size = "large",
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_domain_test",
)
@@ -618,19 +621,19 @@ syscall_test(
syscall_test(
size = "large",
add_overlay = True,
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_filesystem_test",
)
syscall_test(
size = "large",
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_inet_loopback_test",
)
syscall_test(
size = "large",
- shard_count = 50,
+ shard_count = most_shards,
# Takes too long for TSAN. Creates a lot of TCP sockets.
tags = ["nogotsan"],
test = "//test/syscalls/linux:socket_inet_loopback_nogotsan_test",
@@ -638,35 +641,38 @@ syscall_test(
syscall_test(
size = "large",
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_ip_tcp_generic_loopback_test",
)
syscall_test(
size = "medium",
+ add_hostinet = True,
test = "//test/syscalls/linux:socket_ip_tcp_loopback_non_blocking_test",
)
syscall_test(
size = "large",
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_ip_tcp_loopback_test",
)
syscall_test(
size = "medium",
- shard_count = 50,
+ add_hostinet = True,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_ip_tcp_udp_generic_loopback_test",
)
syscall_test(
size = "medium",
+ add_hostinet = True,
test = "//test/syscalls/linux:socket_ip_udp_loopback_non_blocking_test",
)
syscall_test(
size = "large",
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_ip_udp_loopback_test",
)
@@ -677,6 +683,8 @@ syscall_test(
syscall_test(
size = "medium",
+ add_hostinet = True,
+ shard_count = more_shards,
# Takes too long under gotsan to run.
tags = ["nogotsan"],
test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_nogotsan_test",
@@ -691,6 +699,7 @@ syscall_test(
)
syscall_test(
+ shard_count = more_shards,
test = "//test/syscalls/linux:socket_ip_unbound_test",
)
@@ -723,6 +732,7 @@ syscall_test(
)
syscall_test(
+ add_hostinet = True,
test = "//test/syscalls/linux:socket_non_stream_blocking_local_test",
)
@@ -753,7 +763,7 @@ syscall_test(
syscall_test(
# NOTE(b/116636318): Large sendmsg may stall a long time.
size = "enormous",
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:socket_unix_dgram_local_test",
)
@@ -765,14 +775,14 @@ syscall_test(
syscall_test(
size = "large",
add_overlay = True,
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_unix_pair_test",
)
syscall_test(
# NOTE(b/116636318): Large sendmsg may stall a long time.
size = "enormous",
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:socket_unix_seqpacket_local_test",
)
@@ -798,13 +808,13 @@ syscall_test(
syscall_test(
size = "medium",
- shard_count = 10,
+ shard_count = more_shards,
test = "//test/syscalls/linux:socket_unix_unbound_seqpacket_test",
)
syscall_test(
size = "large",
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_unix_unbound_stream_test",
)
@@ -858,7 +868,7 @@ syscall_test(
syscall_test(
size = "medium",
- shard_count = 10,
+ shard_count = more_shards,
test = "//test/syscalls/linux:tcp_socket_test",
)
@@ -867,6 +877,7 @@ syscall_test(
)
syscall_test(
+ shard_count = more_shards,
test = "//test/syscalls/linux:timerfd_test",
)
@@ -897,13 +908,14 @@ syscall_test(
)
syscall_test(
+ add_hostinet = True,
test = "//test/syscalls/linux:udp_bind_test",
)
syscall_test(
size = "medium",
add_hostinet = True,
- shard_count = 10,
+ shard_count = more_shards,
test = "//test/syscalls/linux:udp_socket_test",
)
@@ -947,7 +959,7 @@ syscall_test(
syscall_test(
size = "medium",
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:wait_test",
)
@@ -961,6 +973,7 @@ syscall_test(
)
syscall_test(
+ add_hostinet = True,
test = "//test/syscalls/linux:proc_net_tcp_test",
)
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 50baafbf7..760456a98 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -432,6 +432,9 @@ cc_binary(
testonly = 1,
srcs = ["chown.cc"],
linkstatic = 1,
+ # We require additional UIDs for this test, so don't include the bazel
+ # sandbox as standard.
+ tags = ["no-sandbox"],
deps = [
"//test/util:capability_util",
"//test/util:file_descriptor",
@@ -618,10 +621,7 @@ cc_binary(
cc_binary(
name = "exceptions_test",
testonly = 1,
- srcs = select_arch(
- amd64 = ["exceptions.cc"],
- arm64 = [],
- ),
+ srcs = ["exceptions.cc"],
linkstatic = 1,
deps = [
gtest,
@@ -796,8 +796,8 @@ cc_binary(
deps = [
":socket_test_util",
"//test/util:cleanup",
- "//test/util:epoll_util",
"//test/util:eventfd_util",
+ "//test/util:file_descriptor",
"//test/util:fs_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/flags:flag",
@@ -808,6 +808,7 @@ cc_binary(
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:save_util",
+ "//test/util:signal_util",
"//test/util:temp_path",
"//test/util:test_util",
"//test/util:thread_util",
@@ -2450,6 +2451,27 @@ cc_library(
)
cc_library(
+ name = "socket_ipv6_udp_unbound_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_ipv6_udp_unbound.cc",
+ ],
+ hdrs = [
+ "socket_ipv6_udp_unbound.h",
+ ],
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ "@com_google_absl//absl/memory",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:save_util",
+ "//test/util:test_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
name = "socket_ipv4_udp_unbound_netlink_test_cases",
testonly = 1,
srcs = [
@@ -2789,6 +2811,22 @@ cc_binary(
)
cc_binary(
+ name = "socket_ipv6_udp_unbound_loopback_test",
+ testonly = 1,
+ srcs = [
+ "socket_ipv6_udp_unbound_loopback.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_ipv6_udp_unbound_test_cases",
+ ":socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
name = "socket_ipv4_udp_unbound_loopback_nogotsan_test",
testonly = 1,
srcs = [
diff --git a/test/syscalls/linux/chown.cc b/test/syscalls/linux/chown.cc
index 7a28b674d..5530ad18f 100644
--- a/test/syscalls/linux/chown.cc
+++ b/test/syscalls/linux/chown.cc
@@ -75,7 +75,16 @@ TEST_P(ChownParamTest, ChownFileSucceeds) {
if (num_groups > 0) {
std::vector<gid_t> list(num_groups);
EXPECT_THAT(getgroups(list.size(), list.data()), SyscallSucceeds());
- gid = list[0];
+ // Scan the list of groups for a valid gid. Note that if a group is not
+ // defined in this local user namespace, then we will see 65534, and the
+ // group will not chown below as expected. So only change if we find a
+ // valid group in this list.
+ for (const gid_t other_gid : list) {
+ if (other_gid != 65534) {
+ gid = other_gid;
+ break;
+ }
+ }
}
EXPECT_NO_ERRNO(GetParam()(file.path(), geteuid(), gid));
@@ -90,6 +99,7 @@ TEST_P(ChownParamTest, ChownFilePermissionDenied) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID)));
const auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0777));
+ EXPECT_THAT(chmod(GetAbsoluteTestTmpdir().c_str(), 0777), SyscallSucceeds());
// Drop privileges and change IDs only in child thread, or else this parent
// thread won't be able to open some log files after the test ends.
@@ -119,6 +129,7 @@ TEST_P(ChownParamTest, ChownFileSucceedsAsRoot) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_SETUID))));
const std::string filename = NewTempAbsPath();
+ EXPECT_THAT(chmod(GetAbsoluteTestTmpdir().c_str(), 0777), SyscallSucceeds());
absl::Notification fileCreated, fileChowned;
// Change UID only in child thread, or else this parent thread won't be able
diff --git a/test/syscalls/linux/exceptions.cc b/test/syscalls/linux/exceptions.cc
index 420b9543f..11dc1c651 100644
--- a/test/syscalls/linux/exceptions.cc
+++ b/test/syscalls/linux/exceptions.cc
@@ -23,6 +23,7 @@
namespace gvisor {
namespace testing {
+#if defined(__x86_64__)
// Default value for the x87 FPU control word. See Intel SDM Vol 1, Ch 8.1.5
// "x87 FPU Control Word".
constexpr uint16_t kX87ControlWordDefault = 0x37f;
@@ -93,6 +94,9 @@ void InIOHelper(int width, int value) {
},
::testing::KilledBySignal(SIGSEGV), "");
}
+#elif defined(__aarch64__)
+void inline Halt() { asm("hlt #0\r\n"); }
+#endif
TEST(ExceptionTest, Halt) {
// In order to prevent the regular handler from messing with things (and
@@ -102,9 +106,14 @@ TEST(ExceptionTest, Halt) {
sa.sa_handler = SIG_DFL;
auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGSEGV, sa));
+#if defined(__x86_64__)
EXPECT_EXIT(Halt(), ::testing::KilledBySignal(SIGSEGV), "");
+#elif defined(__aarch64__)
+ EXPECT_EXIT(Halt(), ::testing::KilledBySignal(SIGILL), "");
+#endif
}
+#if defined(__x86_64__)
TEST(ExceptionTest, DivideByZero) {
// See above.
struct sigaction sa = {};
@@ -362,6 +371,7 @@ TEST(ExceptionTest, Int3Compact) {
EXPECT_EXIT(Int3Compact(), ::testing::KilledBySignal(SIGTRAP), "");
}
+#endif
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc
index 34016d4bd..4b581045b 100644
--- a/test/syscalls/linux/fcntl.cc
+++ b/test/syscalls/linux/fcntl.cc
@@ -14,10 +14,13 @@
#include <fcntl.h>
#include <signal.h>
+#include <sys/epoll.h>
#include <sys/types.h>
#include <syscall.h>
#include <unistd.h>
+#include <atomic>
+#include <deque>
#include <iostream>
#include <list>
#include <string>
@@ -34,25 +37,27 @@
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/cleanup.h"
#include "test/util/eventfd_util.h"
+#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
#include "test/util/multiprocess_util.h"
#include "test/util/posix_error.h"
#include "test/util/save_util.h"
+#include "test/util/signal_util.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
#include "test/util/timer_util.h"
-ABSL_FLAG(std::string, child_setlock_on, "",
+ABSL_FLAG(std::string, child_set_lock_on, "",
"Contains the path to try to set a file lock on.");
-ABSL_FLAG(bool, child_setlock_write, false,
+ABSL_FLAG(bool, child_set_lock_write, false,
"Whether to set a writable lock (otherwise readable)");
ABSL_FLAG(bool, blocking, false,
"Whether to set a blocking lock (otherwise non-blocking).");
ABSL_FLAG(bool, retry_eintr, false,
"Whether to retry in the subprocess on EINTR.");
-ABSL_FLAG(uint64_t, child_setlock_start, 0, "The value of struct flock start");
-ABSL_FLAG(uint64_t, child_setlock_len, 0, "The value of struct flock len");
+ABSL_FLAG(uint64_t, child_set_lock_start, 0, "The value of struct flock start");
+ABSL_FLAG(uint64_t, child_set_lock_len, 0, "The value of struct flock len");
ABSL_FLAG(int32_t, socket_fd, -1,
"A socket to use for communicating more state back "
"to the parent.");
@@ -60,6 +65,11 @@ ABSL_FLAG(int32_t, socket_fd, -1,
namespace gvisor {
namespace testing {
+std::function<void(int, siginfo_t*, void*)> setsig_signal_handle;
+void setsig_signal_handler(int signum, siginfo_t* siginfo, void* ucontext) {
+ setsig_signal_handle(signum, siginfo, ucontext);
+}
+
class FcntlLockTest : public ::testing::Test {
public:
void SetUp() override {
@@ -84,18 +94,93 @@ class FcntlLockTest : public ::testing::Test {
int fds_[2] = {};
};
+struct SignalDelivery {
+ int num;
+ siginfo_t info;
+};
+
+class FcntlSignalTest : public ::testing::Test {
+ public:
+ void SetUp() override {
+ int pipe_fds[2];
+ ASSERT_THAT(pipe2(pipe_fds, O_NONBLOCK), SyscallSucceeds());
+ pipe_read_fd_ = pipe_fds[0];
+ pipe_write_fd_ = pipe_fds[1];
+ }
+
+ PosixErrorOr<Cleanup> RegisterSignalHandler(int signum) {
+ struct sigaction handler;
+ handler.sa_sigaction = setsig_signal_handler;
+ setsig_signal_handle = [&](int signum, siginfo_t* siginfo,
+ void* unused_ucontext) {
+ SignalDelivery sig;
+ sig.num = signum;
+ sig.info = *siginfo;
+ signals_received_.push_back(sig);
+ num_signals_received_++;
+ };
+ sigemptyset(&handler.sa_mask);
+ handler.sa_flags = SA_SIGINFO;
+ return ScopedSigaction(signum, handler);
+ }
+
+ void FlushAndCloseFD(int fd) {
+ char buf;
+ int read_bytes;
+ do {
+ read_bytes = read(fd, &buf, 1);
+ } while (read_bytes > 0);
+ // read() can also fail with EWOULDBLOCK since the pipe is open in
+ // non-blocking mode. This is not an error.
+ EXPECT_TRUE(read_bytes == 0 || (read_bytes == -1 && errno == EWOULDBLOCK));
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+ }
+
+ void DupReadFD() {
+ ASSERT_THAT(pipe_read_fd_dup_ = dup(pipe_read_fd_), SyscallSucceeds());
+ max_expected_signals++;
+ }
+
+ void RegisterFD(int fd, int signum) {
+ ASSERT_THAT(fcntl(fd, F_SETOWN, getpid()), SyscallSucceeds());
+ ASSERT_THAT(fcntl(fd, F_SETSIG, signum), SyscallSucceeds());
+ int old_flags;
+ ASSERT_THAT(old_flags = fcntl(fd, F_GETFL), SyscallSucceeds());
+ ASSERT_THAT(fcntl(fd, F_SETFL, old_flags | O_ASYNC), SyscallSucceeds());
+ }
+
+ void GenerateIOEvent() {
+ ASSERT_THAT(write(pipe_write_fd_, "test", 4), SyscallSucceedsWithValue(4));
+ }
+
+ void WaitForSignalDelivery(absl::Duration timeout) {
+ absl::Time wait_start = absl::Now();
+ while (num_signals_received_ < max_expected_signals &&
+ absl::Now() - wait_start < timeout) {
+ absl::SleepFor(absl::Milliseconds(10));
+ }
+ }
+
+ int pipe_read_fd_ = -1;
+ int pipe_read_fd_dup_ = -1;
+ int pipe_write_fd_ = -1;
+ int max_expected_signals = 1;
+ std::deque<SignalDelivery> signals_received_;
+ std::atomic<int> num_signals_received_ = 0;
+};
+
namespace {
PosixErrorOr<Cleanup> SubprocessLock(std::string const& path, bool for_write,
bool blocking, bool retry_eintr, int fd,
off_t start, off_t length, pid_t* child) {
std::vector<std::string> args = {
- "/proc/self/exe", "--child_setlock_on", path,
- "--child_setlock_start", absl::StrCat(start), "--child_setlock_len",
- absl::StrCat(length), "--socket_fd", absl::StrCat(fd)};
+ "/proc/self/exe", "--child_set_lock_on", path,
+ "--child_set_lock_start", absl::StrCat(start), "--child_set_lock_len",
+ absl::StrCat(length), "--socket_fd", absl::StrCat(fd)};
if (for_write) {
- args.push_back("--child_setlock_write");
+ args.push_back("--child_set_lock_write");
}
if (blocking) {
@@ -965,7 +1050,6 @@ TEST(FcntlTest, GetOwnNone) {
// into F_{GET,SET}OWN_EX.
EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
SyscallSucceedsWithValue(0));
- MaybeSave();
}
TEST(FcntlTest, GetOwnExNone) {
@@ -1009,7 +1093,6 @@ TEST(FcntlTest, SetOwnPid) {
EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
SyscallSucceedsWithValue(pid));
- MaybeSave();
}
TEST(FcntlTest, SetOwnPgrp) {
@@ -1030,7 +1113,6 @@ TEST(FcntlTest, SetOwnPgrp) {
SyscallSucceedsWithValue(0));
EXPECT_EQ(got_owner.type, F_OWNER_PGRP);
EXPECT_EQ(got_owner.pid, pgid);
- MaybeSave();
}
TEST(FcntlTest, SetOwnUnset) {
@@ -1058,7 +1140,6 @@ TEST(FcntlTest, SetOwnUnset) {
EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
SyscallSucceedsWithValue(0));
- MaybeSave();
}
// F_SETOWN flips the sign of negative values, an operation that is guarded
@@ -1130,7 +1211,6 @@ TEST(FcntlTest, SetOwnExTid) {
EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
SyscallSucceedsWithValue(owner.pid));
- MaybeSave();
}
TEST(FcntlTest, SetOwnExPid) {
@@ -1146,7 +1226,6 @@ TEST(FcntlTest, SetOwnExPid) {
EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
SyscallSucceedsWithValue(owner.pid));
- MaybeSave();
}
TEST(FcntlTest, SetOwnExPgrp) {
@@ -1168,7 +1247,6 @@ TEST(FcntlTest, SetOwnExPgrp) {
SyscallSucceedsWithValue(0));
EXPECT_EQ(got_owner.type, set_owner.type);
EXPECT_EQ(got_owner.pid, set_owner.pid);
- MaybeSave();
}
TEST(FcntlTest, SetOwnExUnset) {
@@ -1201,7 +1279,6 @@ TEST(FcntlTest, SetOwnExUnset) {
EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
SyscallSucceedsWithValue(0));
- MaybeSave();
}
TEST(FcntlTest, GetOwnExTid) {
@@ -1258,9 +1335,269 @@ TEST(FcntlTest, GetOwnExPgrp) {
EXPECT_EQ(got_owner.pid, set_owner.pid);
}
+TEST(FcntlTest, SetSig) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETSIG, SIGUSR1),
+ SyscallSucceedsWithValue(0));
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETSIG),
+ SyscallSucceedsWithValue(SIGUSR1));
+}
+
+TEST(FcntlTest, SetSigDefaultsToZero) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ // Defaults to returning the zero value, indicating default behavior (SIGIO).
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETSIG),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST(FcntlTest, SetSigToDefault) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETSIG, SIGIO),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETSIG),
+ SyscallSucceedsWithValue(SIGIO));
+
+ // Can be reset to the default behavior.
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETSIG, 0),
+ SyscallSucceedsWithValue(0));
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETSIG),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST(FcntlTest, SetSigInvalid) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETSIG, SIGRTMAX + 1),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETSIG),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST(FcntlTest, SetSigInvalidDoesNotResetPreviousChoice) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETSIG, SIGUSR1),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETSIG, SIGRTMAX + 1),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETSIG),
+ SyscallSucceedsWithValue(SIGUSR1));
+}
+
+TEST_F(FcntlSignalTest, SetSigDefault) {
+ const auto signal_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGIO));
+ RegisterFD(pipe_read_fd_, 0); // Zero = default behavior
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ signals_received_.pop_front();
+ EXPECT_EQ(sig.num, SIGIO);
+ EXPECT_EQ(sig.info.si_signo, SIGIO);
+ // siginfo contents is undefined in this case.
+}
+
+TEST_F(FcntlSignalTest, SetSigCustom) {
+ const auto signal_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR1));
+ RegisterFD(pipe_read_fd_, SIGUSR1);
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ signals_received_.pop_front();
+ EXPECT_EQ(sig.num, SIGUSR1);
+ EXPECT_EQ(sig.info.si_signo, SIGUSR1);
+ EXPECT_EQ(sig.info.si_fd, pipe_read_fd_);
+ EXPECT_EQ(sig.info.si_band, EPOLLIN | EPOLLRDNORM);
+}
+
+TEST_F(FcntlSignalTest, SetSigUnregisterStillGetsSigio) {
+ const auto sigio_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGIO));
+ const auto sigusr1_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR1));
+ RegisterFD(pipe_read_fd_, SIGUSR1);
+ RegisterFD(pipe_read_fd_, 0);
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ signals_received_.pop_front();
+ EXPECT_EQ(sig.num, SIGIO);
+ // siginfo contents is undefined in this case.
+}
+
+TEST_F(FcntlSignalTest, SetSigWithSigioStillGetsSiginfo) {
+ const auto signal_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGIO));
+ RegisterFD(pipe_read_fd_, SIGIO);
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ EXPECT_EQ(sig.num, SIGIO);
+ EXPECT_EQ(sig.info.si_signo, SIGIO);
+ EXPECT_EQ(sig.info.si_fd, pipe_read_fd_);
+ EXPECT_EQ(sig.info.si_band, EPOLLIN | EPOLLRDNORM);
+}
+
+TEST_F(FcntlSignalTest, SetSigDupThenCloseOld) {
+ const auto sigusr1_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR1));
+ RegisterFD(pipe_read_fd_, SIGUSR1);
+ DupReadFD();
+ FlushAndCloseFD(pipe_read_fd_);
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ // We get a signal with the **old** FD (even though it is closed).
+ EXPECT_EQ(sig.num, SIGUSR1);
+ EXPECT_EQ(sig.info.si_signo, SIGUSR1);
+ EXPECT_EQ(sig.info.si_fd, pipe_read_fd_);
+ EXPECT_EQ(sig.info.si_band, EPOLLIN | EPOLLRDNORM);
+}
+
+TEST_F(FcntlSignalTest, SetSigDupThenCloseNew) {
+ const auto sigusr1_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR1));
+ RegisterFD(pipe_read_fd_, SIGUSR1);
+ DupReadFD();
+ FlushAndCloseFD(pipe_read_fd_dup_);
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ // We get a signal with the old FD.
+ EXPECT_EQ(sig.num, SIGUSR1);
+ EXPECT_EQ(sig.info.si_signo, SIGUSR1);
+ EXPECT_EQ(sig.info.si_fd, pipe_read_fd_);
+ EXPECT_EQ(sig.info.si_band, EPOLLIN | EPOLLRDNORM);
+}
+
+TEST_F(FcntlSignalTest, SetSigDupOldRegistered) {
+ const auto sigusr1_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR1));
+ RegisterFD(pipe_read_fd_, SIGUSR1);
+ DupReadFD();
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ // We get a signal with the old FD.
+ EXPECT_EQ(sig.num, SIGUSR1);
+ EXPECT_EQ(sig.info.si_signo, SIGUSR1);
+ EXPECT_EQ(sig.info.si_fd, pipe_read_fd_);
+ EXPECT_EQ(sig.info.si_band, EPOLLIN | EPOLLRDNORM);
+}
+
+TEST_F(FcntlSignalTest, SetSigDupNewRegistered) {
+ const auto sigusr2_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR2));
+ DupReadFD();
+ RegisterFD(pipe_read_fd_dup_, SIGUSR2);
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ // We get a signal with the new FD.
+ EXPECT_EQ(sig.num, SIGUSR2);
+ EXPECT_EQ(sig.info.si_signo, SIGUSR2);
+ EXPECT_EQ(sig.info.si_fd, pipe_read_fd_dup_);
+ EXPECT_EQ(sig.info.si_band, EPOLLIN | EPOLLRDNORM);
+}
+
+TEST_F(FcntlSignalTest, SetSigDupBothRegistered) {
+ const auto sigusr1_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR1));
+ const auto sigusr2_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR2));
+ RegisterFD(pipe_read_fd_, SIGUSR1);
+ DupReadFD();
+ RegisterFD(pipe_read_fd_dup_, SIGUSR2);
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ // We get a signal with the **new** signal number, but the **old** FD.
+ EXPECT_EQ(sig.num, SIGUSR2);
+ EXPECT_EQ(sig.info.si_signo, SIGUSR2);
+ EXPECT_EQ(sig.info.si_fd, pipe_read_fd_);
+ EXPECT_EQ(sig.info.si_band, EPOLLIN | EPOLLRDNORM);
+}
+
+TEST_F(FcntlSignalTest, SetSigDupBothRegisteredAfterDup) {
+ const auto sigusr1_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR1));
+ const auto sigusr2_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR2));
+ DupReadFD();
+ RegisterFD(pipe_read_fd_, SIGUSR1);
+ RegisterFD(pipe_read_fd_dup_, SIGUSR2);
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ // We get a signal with the **new** signal number, but the **old** FD.
+ EXPECT_EQ(sig.num, SIGUSR2);
+ EXPECT_EQ(sig.info.si_signo, SIGUSR2);
+ EXPECT_EQ(sig.info.si_fd, pipe_read_fd_);
+ EXPECT_EQ(sig.info.si_band, EPOLLIN | EPOLLRDNORM);
+}
+
+TEST_F(FcntlSignalTest, SetSigDupUnregisterOld) {
+ const auto sigio_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGIO));
+ const auto sigusr1_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR1));
+ const auto sigusr2_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR2));
+ RegisterFD(pipe_read_fd_, SIGUSR1);
+ DupReadFD();
+ RegisterFD(pipe_read_fd_dup_, SIGUSR2);
+ RegisterFD(pipe_read_fd_, 0); // Should go back to SIGIO behavior.
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ // We get a signal with SIGIO.
+ EXPECT_EQ(sig.num, SIGIO);
+ // siginfo is undefined in this case.
+}
+
+TEST_F(FcntlSignalTest, SetSigDupUnregisterNew) {
+ const auto sigio_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGIO));
+ const auto sigusr1_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR1));
+ const auto sigusr2_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR2));
+ RegisterFD(pipe_read_fd_, SIGUSR1);
+ DupReadFD();
+ RegisterFD(pipe_read_fd_dup_, SIGUSR2);
+ RegisterFD(pipe_read_fd_dup_, 0); // Should go back to SIGIO behavior.
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ // We get a signal with SIGIO.
+ EXPECT_EQ(sig.num, SIGIO);
+ // siginfo is undefined in this case.
+}
+
// Make sure that making multiple concurrent changes to async signal generation
// does not cause any race issues.
-TEST(FcntlTest, SetFlSetOwnDoNotRace) {
+TEST(FcntlTest, SetFlSetOwnSetSigDoNotRace) {
FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
@@ -1268,32 +1605,40 @@ TEST(FcntlTest, SetFlSetOwnDoNotRace) {
EXPECT_THAT(pid = getpid(), SyscallSucceeds());
constexpr absl::Duration runtime = absl::Milliseconds(300);
- auto setAsync = [&s, &runtime] {
+ auto set_async = [&s, &runtime] {
for (auto start = absl::Now(); absl::Now() - start < runtime;) {
ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETFL, O_ASYNC),
SyscallSucceeds());
sched_yield();
}
};
- auto resetAsync = [&s, &runtime] {
+ auto reset_async = [&s, &runtime] {
for (auto start = absl::Now(); absl::Now() - start < runtime;) {
ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETFL, 0), SyscallSucceeds());
sched_yield();
}
};
- auto setOwn = [&s, &pid, &runtime] {
+ auto set_own = [&s, &pid, &runtime] {
for (auto start = absl::Now(); absl::Now() - start < runtime;) {
ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, pid),
SyscallSucceeds());
sched_yield();
}
};
+ auto set_sig = [&s, &runtime] {
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETSIG, SIGUSR1),
+ SyscallSucceeds());
+ sched_yield();
+ }
+ };
std::list<ScopedThread> threads;
for (int i = 0; i < 10; i++) {
- threads.emplace_back(setAsync);
- threads.emplace_back(resetAsync);
- threads.emplace_back(setOwn);
+ threads.emplace_back(set_async);
+ threads.emplace_back(reset_async);
+ threads.emplace_back(set_own);
+ threads.emplace_back(set_sig);
}
}
@@ -1302,57 +1647,60 @@ TEST(FcntlTest, SetFlSetOwnDoNotRace) {
} // namespace testing
} // namespace gvisor
-int main(int argc, char** argv) {
- gvisor::testing::TestInit(&argc, &argv);
-
- const std::string setlock_on = absl::GetFlag(FLAGS_child_setlock_on);
- if (!setlock_on.empty()) {
- int socket_fd = absl::GetFlag(FLAGS_socket_fd);
- int fd = open(setlock_on.c_str(), O_RDWR, 0666);
- if (fd == -1 && errno != 0) {
- int err = errno;
- std::cerr << "CHILD open " << setlock_on << " failed " << err
- << std::endl;
- exit(err);
- }
+int set_lock() {
+ const std::string set_lock_on = absl::GetFlag(FLAGS_child_set_lock_on);
+ int socket_fd = absl::GetFlag(FLAGS_socket_fd);
+ int fd = open(set_lock_on.c_str(), O_RDWR, 0666);
+ if (fd == -1 && errno != 0) {
+ int err = errno;
+ std::cerr << "CHILD open " << set_lock_on << " failed: " << err
+ << std::endl;
+ return err;
+ }
- struct flock fl;
- if (absl::GetFlag(FLAGS_child_setlock_write)) {
- fl.l_type = F_WRLCK;
- } else {
- fl.l_type = F_RDLCK;
- }
- fl.l_whence = SEEK_SET;
- fl.l_start = absl::GetFlag(FLAGS_child_setlock_start);
- fl.l_len = absl::GetFlag(FLAGS_child_setlock_len);
+ struct flock fl;
+ if (absl::GetFlag(FLAGS_child_set_lock_write)) {
+ fl.l_type = F_WRLCK;
+ } else {
+ fl.l_type = F_RDLCK;
+ }
+ fl.l_whence = SEEK_SET;
+ fl.l_start = absl::GetFlag(FLAGS_child_set_lock_start);
+ fl.l_len = absl::GetFlag(FLAGS_child_set_lock_len);
+
+ // Test the fcntl.
+ int err = 0;
+ int ret = 0;
+
+ gvisor::testing::MonotonicTimer timer;
+ timer.Start();
+ do {
+ ret = fcntl(fd, absl::GetFlag(FLAGS_blocking) ? F_SETLKW : F_SETLK, &fl);
+ } while (absl::GetFlag(FLAGS_retry_eintr) && ret == -1 && errno == EINTR);
+ auto usec = absl::ToInt64Microseconds(timer.Duration());
+
+ if (ret == -1 && errno != 0) {
+ err = errno;
+ std::cerr << "CHILD lock " << set_lock_on << " failed " << err << std::endl;
+ }
- // Test the fcntl.
- int err = 0;
- int ret = 0;
+ // If there is a socket fd let's send back the time in microseconds it took
+ // to execute this syscall.
+ if (socket_fd != -1) {
+ gvisor::testing::WriteFd(socket_fd, reinterpret_cast<void*>(&usec),
+ sizeof(usec));
+ close(socket_fd);
+ }
- gvisor::testing::MonotonicTimer timer;
- timer.Start();
- do {
- ret = fcntl(fd, absl::GetFlag(FLAGS_blocking) ? F_SETLKW : F_SETLK, &fl);
- } while (absl::GetFlag(FLAGS_retry_eintr) && ret == -1 && errno == EINTR);
- auto usec = absl::ToInt64Microseconds(timer.Duration());
-
- if (ret == -1 && errno != 0) {
- err = errno;
- std::cerr << "CHILD lock " << setlock_on << " failed " << err
- << std::endl;
- }
+ close(fd);
+ return err;
+}
- // If there is a socket fd let's send back the time in microseconds it took
- // to execute this syscall.
- if (socket_fd != -1) {
- gvisor::testing::WriteFd(socket_fd, reinterpret_cast<void*>(&usec),
- sizeof(usec));
- close(socket_fd);
- }
+int main(int argc, char** argv) {
+ gvisor::testing::TestInit(&argc, &argv);
- close(fd);
- exit(err);
+ if (!absl::GetFlag(FLAGS_child_set_lock_on).empty()) {
+ exit(set_lock());
}
return gvisor::testing::RunAllTests();
diff --git a/test/syscalls/linux/kill.cc b/test/syscalls/linux/kill.cc
index db29bd59c..5d1735853 100644
--- a/test/syscalls/linux/kill.cc
+++ b/test/syscalls/linux/kill.cc
@@ -58,6 +58,12 @@ void SigHandler(int sig, siginfo_t* info, void* context) { _exit(0); }
// If pid equals -1, then sig is sent to every process for which the calling
// process has permission to send signals, except for process 1 (init).
TEST(KillTest, CanKillAllPIDs) {
+ // If we're not running inside the sandbox, then we skip this test
+ // as our namespace may contain may more processes that cannot tolerate
+ // the signal below. We also cannot reliably create a new pid namespace
+ // for ourselves and test the same functionality.
+ SKIP_IF(!IsRunningOnGvisor());
+
int pipe_fds[2];
ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds());
FileDescriptor read_fd(pipe_fds[0]);
diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc
index 78c36f98f..9d63782fb 100644
--- a/test/syscalls/linux/open_create.cc
+++ b/test/syscalls/linux/open_create.cc
@@ -112,14 +112,6 @@ TEST(CreateTest, CreatFileWithOTruncAndReadOnly) {
ASSERT_THAT(close(dirfd), SyscallSucceeds());
}
-TEST(CreateTest, CreateFailsOnUnpermittedDir) {
- // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to
- // always override directory permissions.
- ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
- ASSERT_THAT(open("/foo", O_CREAT | O_RDWR, 0644),
- SyscallFailsWithErrno(EACCES));
-}
-
TEST(CreateTest, CreateFailsOnDirWithoutWritePerms) {
// Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to
// always override directory permissions.
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index 7a0f33dff..575be014c 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -17,6 +17,7 @@
#include <fcntl.h>
#include <limits.h>
#include <linux/magic.h>
+#include <linux/sem.h>
#include <sched.h>
#include <signal.h>
#include <stddef.h>
@@ -2409,6 +2410,28 @@ TEST(ProcFilesystems, PresenceOfShmMaxMniAll) {
ASSERT_LE(shmall, ULONG_MAX - (1UL << 24));
}
+TEST(ProcFilesystems, PresenceOfSem) {
+ uint32_t semmsl = 0;
+ uint32_t semmns = 0;
+ uint32_t semopm = 0;
+ uint32_t semmni = 0;
+ std::string proc_file;
+ proc_file = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/sem"));
+ ASSERT_FALSE(proc_file.empty());
+ std::vector<absl::string_view> sem_limits =
+ absl::StrSplit(proc_file, absl::ByAnyChar("\t"), absl::SkipWhitespace());
+ ASSERT_EQ(sem_limits.size(), 4);
+ ASSERT_TRUE(absl::SimpleAtoi(sem_limits[0], &semmsl));
+ ASSERT_TRUE(absl::SimpleAtoi(sem_limits[1], &semmns));
+ ASSERT_TRUE(absl::SimpleAtoi(sem_limits[2], &semopm));
+ ASSERT_TRUE(absl::SimpleAtoi(sem_limits[3], &semmni));
+
+ ASSERT_EQ(semmsl, SEMMSL);
+ ASSERT_EQ(semmns, SEMMNS);
+ ASSERT_EQ(semopm, SEMOPM);
+ ASSERT_EQ(semmni, SEMMNI);
+}
+
// Check that /proc/mounts is a symlink to self/mounts.
TEST(ProcMounts, IsSymlink) {
auto link = ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/mounts"));
@@ -2459,7 +2482,7 @@ void CheckDuplicatesRecursively(std::string path) {
return;
}
auto dir_closer = Cleanup([&dir]() { closedir(dir); });
- std::unordered_set<std::string> children;
+ absl::node_hash_set<std::string> children;
while (true) {
// Readdir(3): If the end of the directory stream is reached, NULL is
// returned and errno is not changed. If an error occurs, NULL is
@@ -2478,6 +2501,10 @@ void CheckDuplicatesRecursively(std::string path) {
absl::EndsWith(path, "/net")) {
break;
}
+ // We may also see permission failures traversing some files.
+ if (errno == EACCES && absl::StartsWith(path, "/proc/")) {
+ break;
+ }
// Otherwise, no errors are allowed.
ASSERT_EQ(errno, 0) << path;
diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc
index 54709371c..955bcee4b 100644
--- a/test/syscalls/linux/raw_socket.cc
+++ b/test/syscalls/linux/raw_socket.cc
@@ -852,6 +852,51 @@ TEST(RawSocketTest, IPv6ProtoRaw) {
SyscallFailsWithErrno(EINVAL));
}
+TEST(RawSocketTest, IPv6SendMsg) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int sock;
+ ASSERT_THAT(sock = socket(AF_INET6, SOCK_RAW, IPPROTO_TCP),
+ SyscallSucceeds());
+
+ char kBuf[] = "hello";
+ struct iovec iov = {};
+ iov.iov_base = static_cast<void*>(const_cast<char*>(kBuf));
+ iov.iov_len = static_cast<size_t>(sizeof(kBuf));
+
+ struct sockaddr_storage addr = {};
+ struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&addr);
+ sin->sin_family = AF_INET;
+ sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+
+ struct msghdr msg = {};
+ msg.msg_name = static_cast<void*>(&addr);
+ msg.msg_namelen = sizeof(sockaddr_in);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = NULL;
+ msg.msg_controllen = 0;
+ msg.msg_flags = 0;
+ ASSERT_THAT(sendmsg(sock, &msg, 0), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(RawSocketTest, ConnectOnIPv6Socket) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int sock;
+ ASSERT_THAT(sock = socket(AF_INET6, SOCK_RAW, IPPROTO_TCP),
+ SyscallSucceeds());
+
+ struct sockaddr_storage addr = {};
+ struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&addr);
+ sin->sin_family = AF_INET;
+ sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+
+ ASSERT_THAT(connect(sock, reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(sockaddr_in6)),
+ SyscallFailsWithErrno(EAFNOSUPPORT));
+}
+
INSTANTIATE_TEST_SUITE_P(
AllInetTests, RawSocketTest,
::testing::Combine(::testing::Values(IPPROTO_TCP, IPPROTO_UDP),
diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc
index 890f4a246..c2f080917 100644
--- a/test/syscalls/linux/semaphore.cc
+++ b/test/syscalls/linux/semaphore.cc
@@ -20,6 +20,7 @@
#include <atomic>
#include <cerrno>
#include <ctime>
+#include <set>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -31,10 +32,23 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
+using ::testing::Contains;
+
namespace gvisor {
namespace testing {
namespace {
+constexpr int kSemMap = 1024000000;
+constexpr int kSemMni = 32000;
+constexpr int kSemMns = 1024000000;
+constexpr int kSemMnu = 1024000000;
+constexpr int kSemMsl = 32000;
+constexpr int kSemOpm = 500;
+constexpr int kSemUme = 500;
+constexpr int kSemUsz = 20;
+constexpr int kSemVmx = 32767;
+constexpr int kSemAem = 32767;
+
class AutoSem {
public:
explicit AutoSem(int id) : id_(id) {}
@@ -773,6 +787,154 @@ TEST(SemaphoreTest, SemopGetncntOnSignal_NoRandomSave) {
EXPECT_EQ(semctl(sem.get(), 0, GETNCNT), 0);
}
+TEST(SemaphoreTest, IpcInfo) {
+ constexpr int kLoops = 5;
+ std::set<int> sem_ids;
+ struct seminfo info;
+ // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false));
+ ASSERT_THAT(semctl(0, 0, IPC_INFO, &info), SyscallSucceedsWithValue(0));
+ for (int i = 0; i < kLoops; i++) {
+ AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+ sem_ids.insert(sem.release());
+ }
+ ASSERT_EQ(sem_ids.size(), kLoops);
+
+ int max_used_index = 0;
+ EXPECT_THAT(max_used_index = semctl(0, 0, IPC_INFO, &info),
+ SyscallSucceeds());
+
+ int index_count = 0;
+ for (int i = 0; i <= max_used_index; i++) {
+ struct semid_ds ds = {};
+ int sem_id = semctl(i, 0, SEM_STAT, &ds);
+ // Only if index i is used within the registry.
+ if (sem_id != -1) {
+ ASSERT_THAT(sem_ids, Contains(sem_id));
+ struct semid_ds ipc_stat_ds;
+ ASSERT_THAT(semctl(sem_id, 0, IPC_STAT, &ipc_stat_ds), SyscallSucceeds());
+ EXPECT_EQ(ds.sem_perm.__key, ipc_stat_ds.sem_perm.__key);
+ EXPECT_EQ(ds.sem_perm.uid, ipc_stat_ds.sem_perm.uid);
+ EXPECT_EQ(ds.sem_perm.gid, ipc_stat_ds.sem_perm.gid);
+ EXPECT_EQ(ds.sem_perm.cuid, ipc_stat_ds.sem_perm.cuid);
+ EXPECT_EQ(ds.sem_perm.cgid, ipc_stat_ds.sem_perm.cgid);
+ EXPECT_EQ(ds.sem_perm.mode, ipc_stat_ds.sem_perm.mode);
+ EXPECT_EQ(ds.sem_otime, ipc_stat_ds.sem_otime);
+ EXPECT_EQ(ds.sem_ctime, ipc_stat_ds.sem_ctime);
+ EXPECT_EQ(ds.sem_nsems, ipc_stat_ds.sem_nsems);
+
+ // Remove the semaphore set's read permission.
+ struct semid_ds ipc_set_ds;
+ ipc_set_ds.sem_perm.uid = getuid();
+ ipc_set_ds.sem_perm.gid = getgid();
+ // Keep the semaphore set's write permission so that it could be removed.
+ ipc_set_ds.sem_perm.mode = 0200;
+ ASSERT_THAT(semctl(sem_id, 0, IPC_SET, &ipc_set_ds), SyscallSucceeds());
+ ASSERT_THAT(semctl(i, 0, SEM_STAT, &ds), SyscallFailsWithErrno(EACCES));
+
+ index_count += 1;
+ }
+ }
+ EXPECT_EQ(index_count, kLoops);
+ ASSERT_THAT(semctl(0, 0, IPC_INFO, &info),
+ SyscallSucceedsWithValue(max_used_index));
+ for (const int sem_id : sem_ids) {
+ ASSERT_THAT(semctl(sem_id, 0, IPC_RMID), SyscallSucceeds());
+ }
+
+ ASSERT_THAT(semctl(0, 0, IPC_INFO, &info), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(info.semmap, kSemMap);
+ EXPECT_EQ(info.semmni, kSemMni);
+ EXPECT_EQ(info.semmns, kSemMns);
+ EXPECT_EQ(info.semmnu, kSemMnu);
+ EXPECT_EQ(info.semmsl, kSemMsl);
+ EXPECT_EQ(info.semopm, kSemOpm);
+ EXPECT_EQ(info.semume, kSemUme);
+ EXPECT_EQ(info.semusz, kSemUsz);
+ EXPECT_EQ(info.semvmx, kSemVmx);
+ EXPECT_EQ(info.semaem, kSemAem);
+}
+
+TEST(SemaphoreTest, SemInfo) {
+ constexpr int kLoops = 5;
+ constexpr int kSemSetSize = 3;
+ std::set<int> sem_ids;
+ struct seminfo info;
+ // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false));
+ ASSERT_THAT(semctl(0, 0, IPC_INFO, &info), SyscallSucceedsWithValue(0));
+ for (int i = 0; i < kLoops; i++) {
+ AutoSem sem(semget(IPC_PRIVATE, kSemSetSize, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+ sem_ids.insert(sem.release());
+ }
+ ASSERT_EQ(sem_ids.size(), kLoops);
+ int max_used_index = 0;
+ EXPECT_THAT(max_used_index = semctl(0, 0, SEM_INFO, &info),
+ SyscallSucceeds());
+ EXPECT_EQ(info.semmap, kSemMap);
+ EXPECT_EQ(info.semmni, kSemMni);
+ EXPECT_EQ(info.semmns, kSemMns);
+ EXPECT_EQ(info.semmnu, kSemMnu);
+ EXPECT_EQ(info.semmsl, kSemMsl);
+ EXPECT_EQ(info.semopm, kSemOpm);
+ EXPECT_EQ(info.semume, kSemUme);
+ EXPECT_EQ(info.semusz, sem_ids.size());
+ EXPECT_EQ(info.semvmx, kSemVmx);
+ EXPECT_EQ(info.semaem, sem_ids.size() * kSemSetSize);
+
+ int index_count = 0;
+ for (int i = 0; i <= max_used_index; i++) {
+ struct semid_ds ds = {};
+ int sem_id = semctl(i, 0, SEM_STAT, &ds);
+ // Only if index i is used within the registry.
+ if (sem_id != -1) {
+ ASSERT_THAT(sem_ids, Contains(sem_id));
+ struct semid_ds ipc_stat_ds;
+ ASSERT_THAT(semctl(sem_id, 0, IPC_STAT, &ipc_stat_ds), SyscallSucceeds());
+ EXPECT_EQ(ds.sem_perm.__key, ipc_stat_ds.sem_perm.__key);
+ EXPECT_EQ(ds.sem_perm.uid, ipc_stat_ds.sem_perm.uid);
+ EXPECT_EQ(ds.sem_perm.gid, ipc_stat_ds.sem_perm.gid);
+ EXPECT_EQ(ds.sem_perm.cuid, ipc_stat_ds.sem_perm.cuid);
+ EXPECT_EQ(ds.sem_perm.cgid, ipc_stat_ds.sem_perm.cgid);
+ EXPECT_EQ(ds.sem_perm.mode, ipc_stat_ds.sem_perm.mode);
+ EXPECT_EQ(ds.sem_otime, ipc_stat_ds.sem_otime);
+ EXPECT_EQ(ds.sem_ctime, ipc_stat_ds.sem_ctime);
+ EXPECT_EQ(ds.sem_nsems, ipc_stat_ds.sem_nsems);
+
+ // Remove the semaphore set's read permission.
+ struct semid_ds ipc_set_ds;
+ ipc_set_ds.sem_perm.uid = getuid();
+ ipc_set_ds.sem_perm.gid = getgid();
+ // Keep the semaphore set's write permission so that it could be removed.
+ ipc_set_ds.sem_perm.mode = 0200;
+ ASSERT_THAT(semctl(sem_id, 0, IPC_SET, &ipc_set_ds), SyscallSucceeds());
+ ASSERT_THAT(semctl(i, 0, SEM_STAT, &ds), SyscallFailsWithErrno(EACCES));
+
+ index_count += 1;
+ }
+ }
+ EXPECT_EQ(index_count, kLoops);
+ ASSERT_THAT(semctl(0, 0, SEM_INFO, &info),
+ SyscallSucceedsWithValue(max_used_index));
+ for (const int sem_id : sem_ids) {
+ ASSERT_THAT(semctl(sem_id, 0, IPC_RMID), SyscallSucceeds());
+ }
+
+ ASSERT_THAT(semctl(0, 0, SEM_INFO, &info), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(info.semmap, kSemMap);
+ EXPECT_EQ(info.semmni, kSemMni);
+ EXPECT_EQ(info.semmns, kSemMns);
+ EXPECT_EQ(info.semmnu, kSemMnu);
+ EXPECT_EQ(info.semmsl, kSemMsl);
+ EXPECT_EQ(info.semopm, kSemOpm);
+ EXPECT_EQ(info.semume, kSemUme);
+ EXPECT_EQ(info.semusz, 0);
+ EXPECT_EQ(info.semvmx, kSemVmx);
+ EXPECT_EQ(info.semaem, 0);
+}
+
} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/signalfd.cc b/test/syscalls/linux/signalfd.cc
index 389e5fca2..c86cd2755 100644
--- a/test/syscalls/linux/signalfd.cc
+++ b/test/syscalls/linux/signalfd.cc
@@ -126,7 +126,7 @@ TEST_P(SignalfdTest, Blocking) {
// Shared tid variable.
absl::Mutex mu;
- bool has_tid;
+ bool has_tid = false;
pid_t tid;
// Start a thread reading.
diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc
index 796546224..a28ee2233 100644
--- a/test/syscalls/linux/socket_generic.cc
+++ b/test/syscalls/linux/socket_generic.cc
@@ -818,32 +818,55 @@ TEST_P(AllSocketPairTest, GetSockoptProtocol) {
}
}
-TEST_P(AllSocketPairTest, GetSockoptBroadcast) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
- int opt = -1;
- socklen_t optlen = sizeof(opt);
- EXPECT_THAT(
- getsockopt(sockets->first_fd(), SOL_SOCKET, SO_BROADCAST, &opt, &optlen),
- SyscallSucceeds());
- ASSERT_EQ(optlen, sizeof(opt));
- EXPECT_EQ(opt, 0);
+TEST_P(AllSocketPairTest, SetAndGetBooleanSocketOptions) {
+ int sock_opts[] = {SO_BROADCAST, SO_PASSCRED, SO_NO_CHECK,
+ SO_REUSEADDR, SO_REUSEPORT, SO_KEEPALIVE};
+ for (int sock_opt : sock_opts) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ int enable = -1;
+ socklen_t enableLen = sizeof(enable);
+
+ // Test that the option is initially set to false.
+ ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, sock_opt, &enable,
+ &enableLen),
+ SyscallSucceeds());
+ ASSERT_EQ(enableLen, sizeof(enable));
+ EXPECT_EQ(enable, 0) << absl::StrFormat(
+ "getsockopt(fd, SOL_SOCKET, %d, &enable, &enableLen) => enable=%d",
+ sock_opt, enable);
+
+ // Test that setting the option to true is reflected in the subsequent
+ // call to getsockopt(2).
+ enable = 1;
+ ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, sock_opt, &enable,
+ sizeof(enable)),
+ SyscallSucceeds());
+ enable = -1;
+ enableLen = sizeof(enable);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, sock_opt, &enable,
+ &enableLen),
+ SyscallSucceeds());
+ ASSERT_EQ(enableLen, sizeof(enable));
+ EXPECT_EQ(enable, 1) << absl::StrFormat(
+ "getsockopt(fd, SOL_SOCKET, %d, &enable, &enableLen) => enable=%d",
+ sock_opt, enable);
+ }
}
-TEST_P(AllSocketPairTest, SetAndGetSocketBroadcastOption) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
- int kSockOptOn = 1;
- ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_BROADCAST,
- &kSockOptOn, sizeof(kSockOptOn)),
- SyscallSucceedsWithValue(0));
+TEST_P(AllSocketPairTest, GetSocketOutOfBandInlineOption) {
+ // We do not support disabling this option. It is always enabled.
+ SKIP_IF(!IsRunningOnGvisor());
- int got = -1;
- socklen_t length = sizeof(got);
- ASSERT_THAT(
- getsockopt(sockets->first_fd(), SOL_SOCKET, SO_BROADCAST, &got, &length),
- SyscallSucceedsWithValue(0));
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ int enable = -1;
+ socklen_t enableLen = sizeof(enable);
- ASSERT_EQ(length, sizeof(got));
- EXPECT_EQ(got, kSockOptOn);
+ int want = 1;
+ ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_OOBINLINE, &enable,
+ &enableLen),
+ SyscallSucceeds());
+ ASSERT_EQ(enableLen, sizeof(enable));
+ EXPECT_EQ(enable, want);
}
} // namespace testing
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index e19a83413..51b77ad85 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -1185,19 +1185,44 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) {
listen_fd.get(), reinterpret_cast<sockaddr*>(&accept_addr), &addrlen));
ASSERT_EQ(addrlen, listener.addr_len);
- // TODO(gvisor.dev/issue/3812): Remove after SO_ERROR is fixed.
- if (IsRunningOnGvisor()) {
- char buf[10];
- ASSERT_THAT(ReadFd(accept_fd.get(), buf, sizeof(buf)),
- SyscallFailsWithErrno(ECONNRESET));
- } else {
+ // Wait for accept_fd to process the RST.
+ const int kTimeout = 10000;
+ struct pollfd pfd = {
+ .fd = accept_fd.get(),
+ .events = POLLIN,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ ASSERT_EQ(pfd.revents, POLLIN | POLLHUP | POLLERR);
+
+ {
int err;
socklen_t optlen = sizeof(err);
ASSERT_THAT(
getsockopt(accept_fd.get(), SOL_SOCKET, SO_ERROR, &err, &optlen),
SyscallSucceeds());
+ // This should return ECONNRESET as the socket just received a RST packet
+ // from the peer.
+ ASSERT_EQ(optlen, sizeof(err));
ASSERT_EQ(err, ECONNRESET);
+ }
+ {
+ int err;
+ socklen_t optlen = sizeof(err);
+ ASSERT_THAT(
+ getsockopt(accept_fd.get(), SOL_SOCKET, SO_ERROR, &err, &optlen),
+ SyscallSucceeds());
+ // This should return no error as the previous getsockopt call would have
+ // cleared the socket error.
ASSERT_EQ(optlen, sizeof(err));
+ ASSERT_EQ(err, 0);
+ }
+ {
+ sockaddr_storage peer_addr;
+ socklen_t addrlen = sizeof(peer_addr);
+ // The socket is not connected anymore and should return ENOTCONN.
+ ASSERT_THAT(getpeername(accept_fd.get(),
+ reinterpret_cast<sockaddr*>(&peer_addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
}
}
@@ -2805,5 +2830,28 @@ INSTANTIATE_TEST_SUITE_P(
} // namespace
+// Check that loopback receives connections from any address in the range:
+// 127.0.0.1 to 127.254.255.255. This behavior is exclusive to IPv4.
+TEST_F(SocketInetLoopbackTest, LoopbackAddressRangeConnect) {
+ TestAddress const& listener = V4Any();
+
+ in_addr_t addresses[] = {
+ INADDR_LOOPBACK,
+ INADDR_LOOPBACK + 1, // 127.0.0.2
+ (in_addr_t)0x7f000101, // 127.0.1.1
+ (in_addr_t)0x7f010101, // 127.1.1.1
+ (in_addr_t)0x7ffeffff, // 127.254.255.255
+ };
+ for (const auto& address : addresses) {
+ TestAddress connector("V4Loopback");
+ connector.addr.ss_family = AF_INET;
+ connector.addr_len = sizeof(sockaddr_in);
+ reinterpret_cast<sockaddr_in*>(&connector.addr)->sin_addr.s_addr =
+ htonl(address);
+
+ tcpSimpleConnectTest(listener, connector, true);
+ }
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc
index f69f8f99f..2fcd08112 100644
--- a/test/syscalls/linux/socket_ip_udp_generic.cc
+++ b/test/syscalls/linux/socket_ip_udp_generic.cc
@@ -15,6 +15,9 @@
#include "test/syscalls/linux/socket_ip_udp_generic.h"
#include <errno.h>
+#ifdef __linux__
+#include <linux/in6.h>
+#endif // __linux__
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <poll.h>
@@ -356,6 +359,58 @@ TEST_P(UDPSocketPairTest, SetAndGetIPPKTINFO) {
EXPECT_EQ(get_len, sizeof(get));
}
+// Test getsockopt for a socket which is not set with IP_RECVORIGDSTADDR option.
+TEST_P(UDPSocketPairTest, ReceiveOrigDstAddrDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ int level = SOL_IP;
+ int type = IP_RECVORIGDSTADDR;
+ if (sockets->first_addr()->sa_family == AF_INET6) {
+ level = SOL_IPV6;
+ type = IPV6_RECVORIGDSTADDR;
+ }
+ ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+// Test setsockopt and getsockopt for a socket with IP_RECVORIGDSTADDR option.
+TEST_P(UDPSocketPairTest, SetAndGetReceiveOrigDstAddr) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int level = SOL_IP;
+ int type = IP_RECVORIGDSTADDR;
+ if (sockets->first_addr()->sa_family == AF_INET6) {
+ level = SOL_IPV6;
+ type = IPV6_RECVORIGDSTADDR;
+ }
+
+ // Check getsockopt before IP_PKTINFO is set.
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOn);
+ EXPECT_EQ(get_len, sizeof(get));
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOff,
+ sizeof(kSockOptOff)),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOff);
+ EXPECT_EQ(get_len, sizeof(get));
+}
+
// Holds TOS or TClass information for IPv4 or IPv6 respectively.
struct RecvTosOption {
int level;
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
index b3f54e7f6..e557572a7 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
@@ -2222,6 +2222,90 @@ TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) {
EXPECT_EQ(received_pktinfo.ipi_addr.s_addr, htonl(INADDR_LOOPBACK));
}
+// Test that socket will receive IP_RECVORIGDSTADDR control message.
+TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPReceiveOrigDstAddr) {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver_addr = V4Loopback();
+ int level = SOL_IP;
+ int type = IP_RECVORIGDSTADDR;
+
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+
+ // Retrieve the port bound by the receiver.
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ ASSERT_THAT(
+ connect(sender->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+
+ // Get address and port bound by the sender.
+ sockaddr_storage sender_addr_storage;
+ socklen_t sender_addr_len = sizeof(sender_addr_storage);
+ ASSERT_THAT(getsockname(sender->get(),
+ reinterpret_cast<sockaddr*>(&sender_addr_storage),
+ &sender_addr_len),
+ SyscallSucceeds());
+ ASSERT_EQ(sender_addr_len, sizeof(struct sockaddr_in));
+
+ // Enable IP_RECVORIGDSTADDR on socket so that we get the original destination
+ // address of the datagram as auxiliary information in the control message.
+ ASSERT_THAT(
+ setsockopt(receiver->get(), level, type, &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Prepare message to send.
+ constexpr size_t kDataLength = 1024;
+ msghdr sent_msg = {};
+ iovec sent_iov = {};
+ char sent_data[kDataLength];
+ sent_iov.iov_base = sent_data;
+ sent_iov.iov_len = kDataLength;
+ sent_msg.msg_iov = &sent_iov;
+ sent_msg.msg_iovlen = 1;
+ sent_msg.msg_flags = 0;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(sender->get(), &sent_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ msghdr received_msg = {};
+ iovec received_iov = {};
+ char received_data[kDataLength];
+ char received_cmsg_buf[CMSG_SPACE(sizeof(sockaddr_in))] = {};
+ size_t cmsg_data_len = sizeof(sockaddr_in);
+ received_iov.iov_base = received_data;
+ received_iov.iov_len = kDataLength;
+ received_msg.msg_iov = &received_iov;
+ received_msg.msg_iovlen = 1;
+ received_msg.msg_controllen = CMSG_LEN(cmsg_data_len);
+ received_msg.msg_control = received_cmsg_buf;
+
+ ASSERT_THAT(RecvMsgTimeout(receiver->get(), &received_msg, 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(kDataLength));
+
+ cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len));
+ EXPECT_EQ(cmsg->cmsg_level, level);
+ EXPECT_EQ(cmsg->cmsg_type, type);
+
+ // Check the data
+ sockaddr_in received_addr = {};
+ memcpy(&received_addr, CMSG_DATA(cmsg), sizeof(received_addr));
+ auto orig_receiver_addr = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr);
+ EXPECT_EQ(received_addr.sin_addr.s_addr, orig_receiver_addr->sin_addr.s_addr);
+ EXPECT_EQ(received_addr.sin_port, orig_receiver_addr->sin_port);
+}
+
// Check that setting SO_RCVBUF below min is clamped to the minimum
// receive buffer size.
TEST_P(IPv4UDPUnboundSocketTest, SetSocketRecvBufBelowMin) {
diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound.cc b/test/syscalls/linux/socket_ipv6_udp_unbound.cc
new file mode 100644
index 000000000..08526468e
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv6_udp_unbound.cc
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_ipv6_udp_unbound.h"
+
+#include <arpa/inet.h>
+#include <netinet/in.h>
+#ifdef __linux__
+#include <linux/in6.h>
+#endif // __linux__
+#include <net/if.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <cstdio>
+#include <cstring>
+
+#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test that socket will receive IP_RECVORIGDSTADDR control message.
+TEST_P(IPv6UDPUnboundSocketTest, SetAndReceiveIPReceiveOrigDstAddr) {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver_addr = V6Loopback();
+ int level = SOL_IPV6;
+ int type = IPV6_RECVORIGDSTADDR;
+
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+
+ // Retrieve the port bound by the receiver.
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ ASSERT_THAT(
+ connect(sender->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+
+ // Get address and port bound by the sender.
+ sockaddr_storage sender_addr_storage;
+ socklen_t sender_addr_len = sizeof(sender_addr_storage);
+ ASSERT_THAT(getsockname(sender->get(),
+ reinterpret_cast<sockaddr*>(&sender_addr_storage),
+ &sender_addr_len),
+ SyscallSucceeds());
+ ASSERT_EQ(sender_addr_len, sizeof(struct sockaddr_in6));
+
+ // Enable IP_RECVORIGDSTADDR on socket so that we get the original destination
+ // address of the datagram as auxiliary information in the control message.
+ ASSERT_THAT(
+ setsockopt(receiver->get(), level, type, &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Prepare message to send.
+ constexpr size_t kDataLength = 1024;
+ msghdr sent_msg = {};
+ iovec sent_iov = {};
+ char sent_data[kDataLength];
+ sent_iov.iov_base = sent_data;
+ sent_iov.iov_len = kDataLength;
+ sent_msg.msg_iov = &sent_iov;
+ sent_msg.msg_iovlen = 1;
+ sent_msg.msg_flags = 0;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(sender->get(), &sent_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ msghdr received_msg = {};
+ iovec received_iov = {};
+ char received_data[kDataLength];
+ char received_cmsg_buf[CMSG_SPACE(sizeof(sockaddr_in6))] = {};
+ size_t cmsg_data_len = sizeof(sockaddr_in6);
+ received_iov.iov_base = received_data;
+ received_iov.iov_len = kDataLength;
+ received_msg.msg_iov = &received_iov;
+ received_msg.msg_iovlen = 1;
+ received_msg.msg_controllen = CMSG_LEN(cmsg_data_len);
+ received_msg.msg_control = received_cmsg_buf;
+
+ ASSERT_THAT(RecvMsgTimeout(receiver->get(), &received_msg, 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(kDataLength));
+
+ cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len));
+ EXPECT_EQ(cmsg->cmsg_level, level);
+ EXPECT_EQ(cmsg->cmsg_type, type);
+
+ // Check that the received address in the control message matches the expected
+ // receiver's address.
+ sockaddr_in6 received_addr = {};
+ memcpy(&received_addr, CMSG_DATA(cmsg), sizeof(received_addr));
+ auto orig_receiver_addr =
+ reinterpret_cast<sockaddr_in6*>(&receiver_addr.addr);
+ EXPECT_EQ(memcmp(&received_addr.sin6_addr, &orig_receiver_addr->sin6_addr,
+ sizeof(in6_addr)),
+ 0);
+ EXPECT_EQ(received_addr.sin6_port, orig_receiver_addr->sin6_port);
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound.h b/test/syscalls/linux/socket_ipv6_udp_unbound.h
new file mode 100644
index 000000000..71e160f73
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv6_udp_unbound.h
@@ -0,0 +1,29 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_H_
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to IPv6 UDP sockets.
+using IPv6UDPUnboundSocketTest = SimpleSocketTest;
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_H_
diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_loopback.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_loopback.cc
new file mode 100644
index 000000000..058336ecc
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv6_udp_unbound_loopback.cc
@@ -0,0 +1,32 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_ipv6_udp_unbound.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+INSTANTIATE_TEST_SUITE_P(
+ IPv6UDPSockets, IPv6UDPUnboundSocketTest,
+ ::testing::ValuesIn(ApplyVec<SocketKind>(IPv6UDPUnboundSocket,
+ AllBitwiseCombinations(List<int>{
+ 0, SOCK_NONBLOCK}))));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index bc2c8278c..714848b8e 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -964,37 +964,156 @@ TEST_P(TcpSocketTest, PollAfterShutdown) {
SyscallSucceedsWithValue(1));
}
-TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListener) {
+TEST_P(SimpleTcpSocketTest, NonBlockingConnectRetry) {
+ const FileDescriptor listener =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
// Initialize address to the loopback one.
sockaddr_storage addr =
ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
socklen_t addrlen = sizeof(addr);
- const FileDescriptor s =
+ // Bind to some port but don't listen yet.
+ ASSERT_THAT(
+ bind(listener.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ // Get the address we're bound to, then connect to it. We need to do this
+ // because we're allowing the stack to pick a port for us.
+ ASSERT_THAT(getsockname(listener.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ FileDescriptor connector =
ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
- // Set the FD to O_NONBLOCK.
- int opts;
- ASSERT_THAT(opts = fcntl(s.get(), F_GETFL), SyscallSucceeds());
- opts |= O_NONBLOCK;
- ASSERT_THAT(fcntl(s.get(), F_SETFL, opts), SyscallSucceeds());
+ // Verify that connect fails.
+ ASSERT_THAT(
+ RetryEINTR(connect)(connector.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallFailsWithErrno(ECONNREFUSED));
- ASSERT_THAT(RetryEINTR(connect)(
+ // Now start listening
+ ASSERT_THAT(listen(listener.get(), SOMAXCONN), SyscallSucceeds());
+
+ // TODO(gvisor.dev/issue/3828): Issuing connect() again on a socket that
+ // failed first connect should succeed.
+ if (IsRunningOnGvisor()) {
+ ASSERT_THAT(
+ RetryEINTR(connect)(connector.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallFailsWithErrno(ECONNABORTED));
+ return;
+ }
+
+ // Verify that connect now succeeds.
+ ASSERT_THAT(
+ RetryEINTR(connect)(connector.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ const FileDescriptor accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listener.get(), nullptr, nullptr));
+}
+
+// nonBlockingConnectNoListener returns a socket on which a connect that is
+// expected to fail has been issued.
+PosixErrorOr<FileDescriptor> nonBlockingConnectNoListener(const int family,
+ sockaddr_storage addr,
+ socklen_t addrlen) {
+ // We will first create a socket and bind to ensure we bind a port but will
+ // not call listen on this socket.
+ // Then we will create a new socket that will connect to the port bound by
+ // the first socket and that shoud fail.
+ constexpr int sock_type = SOCK_STREAM | SOCK_NONBLOCK;
+ int b_sock;
+ RETURN_ERROR_IF_SYSCALL_FAIL(b_sock = socket(family, sock_type, IPPROTO_TCP));
+ FileDescriptor b(b_sock);
+ EXPECT_THAT(bind(b.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ // Get the address bound by the listening socket.
+ EXPECT_THAT(
+ getsockname(b.get(), reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ // Now create another socket and issue a connect on this one. This connect
+ // should fail as there is no listener.
+ int c_sock;
+ RETURN_ERROR_IF_SYSCALL_FAIL(c_sock = socket(family, sock_type, IPPROTO_TCP));
+ FileDescriptor s(c_sock);
+
+ // Now connect to the bound address and this should fail as nothing
+ // is listening on the bound address.
+ EXPECT_THAT(RetryEINTR(connect)(
s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
SyscallFailsWithErrno(EINPROGRESS));
- // Now polling on the FD with a timeout should return 0 corresponding to no
- // FDs ready.
- struct pollfd poll_fd = {s.get(), POLLOUT, 0};
- EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000),
- SyscallSucceedsWithValue(1));
+ // Wait for the connect to fail.
+ struct pollfd poll_fd = {s.get(), POLLERR, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 1000), SyscallSucceedsWithValue(1));
+ return std::move(s);
+}
+
+TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListener) {
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ const FileDescriptor s =
+ nonBlockingConnectNoListener(GetParam(), addr, addrlen).ValueOrDie();
int err;
socklen_t optlen = sizeof(err);
ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_ERROR, &err, &optlen),
SyscallSucceeds());
-
+ ASSERT_THAT(optlen, sizeof(err));
EXPECT_EQ(err, ECONNREFUSED);
+
+ unsigned char c;
+ ASSERT_THAT(read(s.get(), &c, sizeof(c)), SyscallSucceedsWithValue(0));
+ int opts;
+ EXPECT_THAT(opts = fcntl(s.get(), F_GETFL), SyscallSucceeds());
+ opts &= ~O_NONBLOCK;
+ EXPECT_THAT(fcntl(s.get(), F_SETFL, opts), SyscallSucceeds());
+ // Try connecting again.
+ ASSERT_THAT(RetryEINTR(connect)(
+ s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallFailsWithErrno(ECONNABORTED));
+}
+
+TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListenerRead) {
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ const FileDescriptor s =
+ nonBlockingConnectNoListener(GetParam(), addr, addrlen).ValueOrDie();
+
+ unsigned char c;
+ ASSERT_THAT(read(s.get(), &c, 1), SyscallFailsWithErrno(ECONNREFUSED));
+ ASSERT_THAT(read(s.get(), &c, 1), SyscallSucceedsWithValue(0));
+ ASSERT_THAT(RetryEINTR(connect)(
+ s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallFailsWithErrno(ECONNABORTED));
+}
+
+TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListenerPeek) {
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ const FileDescriptor s =
+ nonBlockingConnectNoListener(GetParam(), addr, addrlen).ValueOrDie();
+
+ unsigned char c;
+ ASSERT_THAT(recv(s.get(), &c, 1, MSG_PEEK),
+ SyscallFailsWithErrno(ECONNREFUSED));
+ ASSERT_THAT(recv(s.get(), &c, 1, MSG_PEEK), SyscallSucceedsWithValue(0));
+ ASSERT_THAT(RetryEINTR(connect)(
+ s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallFailsWithErrno(ECONNABORTED));
}
TEST_P(SimpleTcpSocketTest, SelfConnectSendRecv_NoRandomSave) {
@@ -1235,6 +1354,19 @@ TEST_P(SimpleTcpSocketTest, CleanupOnConnectionRefused) {
// Attempt #2, with the new socket and reused addr our connect should fail in
// the same way as before, not with an EADDRINUSE.
+ //
+ // TODO(gvisor.dev/issue/3828): 2nd connect on a socket which failed connect
+ // first time should succeed.
+ // gVisor never issues the second connect and returns ECONNABORTED instead.
+ // Linux actually sends a SYN again and gets a RST and correctly returns
+ // ECONNREFUSED.
+ if (IsRunningOnGvisor()) {
+ ASSERT_THAT(connect(client_s.get(),
+ reinterpret_cast<const struct sockaddr*>(&bound_addr),
+ bound_addrlen),
+ SyscallFailsWithErrno(ECONNABORTED));
+ return;
+ }
ASSERT_THAT(connect(client_s.get(),
reinterpret_cast<const struct sockaddr*>(&bound_addr),
bound_addrlen),
diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc
index d65275fd3..90ef8bf21 100644
--- a/test/syscalls/linux/udp_socket.cc
+++ b/test/syscalls/linux/udp_socket.cc
@@ -374,6 +374,69 @@ TEST_P(UdpSocketTest, BindInUse) {
SyscallFailsWithErrno(EADDRINUSE));
}
+TEST_P(UdpSocketTest, ConnectWriteToInvalidPort) {
+ // Discover a free unused port by creating a new UDP socket, binding it
+ // recording the just bound port and closing it. This is not guaranteed as it
+ // can still race with other port UDP sockets trying to bind a port at the
+ // same time.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ socklen_t addrlen = sizeof(addr_storage);
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP));
+ ASSERT_THAT(bind(s.get(), addr, addrlen), SyscallSucceeds());
+ ASSERT_THAT(getsockname(s.get(), addr, &addrlen), SyscallSucceeds());
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_NE(*Port(&addr_storage), 0);
+ ASSERT_THAT(close(s.release()), SyscallSucceeds());
+
+ // Now connect to the port that we just released. This should generate an
+ // ECONNREFUSED error.
+ ASSERT_THAT(connect(sock_.get(), addr, addrlen_), SyscallSucceeds());
+ char buf[512];
+ RandomizeBuffer(buf, sizeof(buf));
+ // Send from sock_ to an unbound port.
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, addr, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Now verify that we got an ICMP error back of ECONNREFUSED.
+ int err;
+ socklen_t optlen = sizeof(err);
+ ASSERT_THAT(getsockopt(sock_.get(), SOL_SOCKET, SO_ERROR, &err, &optlen),
+ SyscallSucceeds());
+ ASSERT_EQ(err, ECONNREFUSED);
+ ASSERT_EQ(optlen, sizeof(err));
+}
+
+TEST_P(UdpSocketTest, ConnectSimultaneousWriteToInvalidPort) {
+ // Discover a free unused port by creating a new UDP socket, binding it
+ // recording the just bound port and closing it. This is not guaranteed as it
+ // can still race with other port UDP sockets trying to bind a port at the
+ // same time.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ socklen_t addrlen = sizeof(addr_storage);
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP));
+ ASSERT_THAT(bind(s.get(), addr, addrlen), SyscallSucceeds());
+ ASSERT_THAT(getsockname(s.get(), addr, &addrlen), SyscallSucceeds());
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_NE(*Port(&addr_storage), 0);
+ ASSERT_THAT(close(s.release()), SyscallSucceeds());
+
+ // Now connect to the port that we just released.
+ ScopedThread t([&] {
+ ASSERT_THAT(connect(sock_.get(), addr, addrlen_), SyscallSucceeds());
+ });
+
+ char buf[512];
+ RandomizeBuffer(buf, sizeof(buf));
+ // Send from sock_ to an unbound port.
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, addr, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ t.Join();
+}
+
TEST_P(UdpSocketTest, ReceiveAfterConnect) {
ASSERT_NO_ERRNO(BindLoopback());
ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
diff --git a/tools/bazel.mk b/tools/bazel.mk
index 3a7de427f..396785e16 100644
--- a/tools/bazel.mk
+++ b/tools/bazel.mk
@@ -14,49 +14,81 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# Make hacks.
-EMPTY :=
-SPACE := $(EMPTY) $(EMPTY)
+##
+## Docker options.
+##
+## This file supports targets that wrap bazel in a running Docker
+## container to simplify development. Some options are available to
+## control the behavior of this container:
+##
+## USER - The in-container user.
+## DOCKER_RUN_OPTIONS - Options for the container (default: --privileged, required for tests).
+## DOCKER_NAME - The container name (default: gvisor-bazel-HASH).
+## DOCKER_PRIVILEGED - Docker privileged flags (default: --privileged).
+## BAZEL_CACHE - The bazel cache directory (default: detected).
+## GCLOUD_CONFIG - The gcloud config directory (detect: detected).
+## DOCKER_SOCKET - The Docker socket (default: detected).
+##
+## To opt out of these wrappers, set DOCKER_BUILD=false.
+DOCKER_BUILD := true
+ifeq ($(DOCKER_BUILD),true)
+-include bazel-server
+endif
# See base Makefile.
-SHELL=/bin/bash -o pipefail
BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \
- git rev-parse --abbrev-ref HEAD 2>/dev/null) | \
- xargs -n 1 basename 2>/dev/null)
+ git rev-parse --abbrev-ref HEAD 2>/dev/null) | \
+ xargs -n 1 basename 2>/dev/null)
BUILD_ROOTS := bazel-bin/ bazel-out/
# Bazel container configuration (see below).
USER := $(shell whoami)
HASH := $(shell readlink -m $(CURDIR) | md5sum | cut -c1-8)
-BUILDER_BASE := gvisor.dev/images/default
-BUILDER_IMAGE := gvisor.dev/images/builder
-BUILDER_NAME := gvisor-builder-$(HASH)
-DOCKER_NAME := gvisor-bazel-$(HASH)
+BUILDER_NAME := gvisor-builder-$(HASH)-$(ARCH)
+DOCKER_NAME := gvisor-bazel-$(HASH)-$(ARCH)
DOCKER_PRIVILEGED := --privileged
BAZEL_CACHE := $(shell readlink -m ~/.cache/bazel/)
GCLOUD_CONFIG := $(shell readlink -m ~/.config/gcloud/)
DOCKER_SOCKET := /var/run/docker.sock
-DOCKER_CONFIG := /etc/docker/daemon.json
+DOCKER_CONFIG := /etc/docker
-# Bazel flags.
-BAZEL := bazel $(STARTUP_OPTIONS)
-OPTIONS += --color=no --curses=no
+##
+## Bazel helpers.
+##
+## Bazel will be run with standard flags. You can specify the following flags
+## to control which flags are passed:
+##
+## STARTUP_OPTIONS - Startup options passed to Bazel.
+## BAZEL_CONFIG - A bazel config file.
+##
+STARTUP_OPTIONS :=
+BAZEL_CONFIG :=
+BAZEL := bazel $(STARTUP_OPTIONS)
+BASE_OPTIONS := --color=no --curses=no
+ifneq (,$(BAZEL_CONFIG))
+BASE_OPTIONS += --config=$(BAZEL_CONFIG)
+endif
+TEST_OPTIONS := $(BASE_OPTIONS) \
+ --test_output=errors \
+ --keep_going \
+ --verbose_failures=true \
+ --build_event_json_file=.build_events.json
# Basic options.
UID := $(shell id -u ${USER})
GID := $(shell id -g ${USER})
USERADD_OPTIONS :=
-FULL_DOCKER_RUN_OPTIONS := $(DOCKER_RUN_OPTIONS)
-FULL_DOCKER_RUN_OPTIONS += --user $(UID):$(GID)
-FULL_DOCKER_RUN_OPTIONS += --entrypoint ""
-FULL_DOCKER_RUN_OPTIONS += --init
-FULL_DOCKER_RUN_OPTIONS += -v "$(BAZEL_CACHE):$(BAZEL_CACHE)"
-FULL_DOCKER_RUN_OPTIONS += -v "$(GCLOUD_CONFIG):$(GCLOUD_CONFIG)"
-FULL_DOCKER_RUN_OPTIONS += -v "/tmp:/tmp"
-FULL_DOCKER_EXEC_OPTIONS := --user $(UID):$(GID)
-FULL_DOCKER_EXEC_OPTIONS += --interactive
-ifeq (true,$(shell [[ -t 0 ]] && echo true))
-FULL_DOCKER_EXEC_OPTIONS += --tty
+DOCKER_RUN_OPTIONS :=
+DOCKER_RUN_OPTIONS += --user $(UID):$(GID)
+DOCKER_RUN_OPTIONS += --entrypoint ""
+DOCKER_RUN_OPTIONS += --init
+DOCKER_RUN_OPTIONS += -v "$(BAZEL_CACHE):$(BAZEL_CACHE)"
+DOCKER_RUN_OPTIONS += -v "$(GCLOUD_CONFIG):$(GCLOUD_CONFIG)"
+DOCKER_RUN_OPTIONS += -v "/tmp:/tmp"
+DOCKER_EXEC_OPTIONS := --user $(UID):$(GID)
+DOCKER_EXEC_OPTIONS += --interactive
+ifeq (true,$(shell test -t 0 && echo true))
+DOCKER_EXEC_OPTIONS += --tty
endif
# Add basic UID/GID options.
@@ -72,7 +104,7 @@ endif
# out of disk space.
ifneq ($(UID),0)
USERADD_DOCKER += useradd -l --uid $(UID) --non-unique --no-create-home \
- --gid $(GID) $(USERADD_OPTIONS) -d $(HOME) $(USER) &&
+ --gid $(GID) $(USERADD_OPTIONS) -d $(HOME) $(USER) &&
endif
ifneq ($(GID),0)
GROUPADD_DOCKER += groupadd --gid $(GID) --non-unique $(USER) &&
@@ -80,126 +112,110 @@ endif
# Add docker passthrough options.
ifneq ($(DOCKER_PRIVILEGED),)
-FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_SOCKET):$(DOCKER_SOCKET)"
-# TODO(gvisor.dev/issue/1624): Remove docker config volume. This is required
-# temporarily for checking VFS1 vs VFS2 by some tests.
-FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_CONFIG):$(DOCKER_CONFIG)"
-FULL_DOCKER_RUN_OPTIONS += $(DOCKER_PRIVILEGED)
-FULL_DOCKER_EXEC_OPTIONS += $(DOCKER_PRIVILEGED)
+DOCKER_RUN_OPTIONS += -v "$(DOCKER_SOCKET):$(DOCKER_SOCKET)"
+DOCKER_RUN_OPTIONS += -v "$(DOCKER_CONFIG):$(DOCKER_CONFIG)"
+DOCKER_RUN_OPTIONS += $(DOCKER_PRIVILEGED)
+DOCKER_EXEC_OPTIONS += $(DOCKER_PRIVILEGED)
DOCKER_GROUP := $(shell stat -c '%g' $(DOCKER_SOCKET))
ifneq ($(GID),$(DOCKER_GROUP))
USERADD_OPTIONS += --groups $(DOCKER_GROUP)
GROUPADD_DOCKER += groupadd --gid $(DOCKER_GROUP) --non-unique docker-$(HASH) &&
-FULL_DOCKER_RUN_OPTIONS += --group-add $(DOCKER_GROUP)
+DOCKER_RUN_OPTIONS += --group-add $(DOCKER_GROUP)
endif
endif
# Add KVM passthrough options.
ifneq (,$(wildcard /dev/kvm))
-FULL_DOCKER_RUN_OPTIONS += --device=/dev/kvm
+DOCKER_RUN_OPTIONS += --device=/dev/kvm
KVM_GROUP := $(shell stat -c '%g' /dev/kvm)
ifneq ($(GID),$(KVM_GROUP))
USERADD_OPTIONS += --groups $(KVM_GROUP)
GROUPADD_DOCKER += groupadd --gid $(KVM_GROUP) --non-unique kvm-$(HASH) &&
-FULL_DOCKER_RUN_OPTIONS += --group-add $(KVM_GROUP)
+DOCKER_RUN_OPTIONS += --group-add $(KVM_GROUP)
endif
endif
-# Load the appropriate config.
-ifneq (,$(BAZEL_CONFIG))
-OPTIONS += --config=$(BAZEL_CONFIG)
+# Top-level functions.
+#
+# This command runs a bazel server, and the container sticks around
+# until the bazel server exits. This should ensure that it does not
+# exit in the middle of running a build, but also it won't stick around
+# forever. The build commands wrap around an appropriate exec into the
+# container in order to perform work via the bazel client.
+ifeq ($(DOCKER_BUILD),true)
+wrapper = docker exec $(DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(1)
+else
+wrapper = $(1)
endif
-bazel-image: load-default
- @if docker ps --all | grep $(BUILDER_NAME); then docker rm -f $(BUILDER_NAME); fi
- docker run --user 0:0 --entrypoint "" --name $(BUILDER_NAME) \
- $(BUILDER_BASE) \
- sh -c "$(GROUPADD_DOCKER) \
- $(USERADD_DOCKER) \
- if [[ -e /dev/kvm ]]; then chmod a+rw /dev/kvm; fi"
- docker commit $(BUILDER_NAME) $(BUILDER_IMAGE)
- @docker rm -f $(BUILDER_NAME)
-.PHONY: bazel-image
-
-##
-## Bazel helpers.
-##
-## This file supports targets that wrap bazel in a running Docker
-## container to simplify development. Some options are available to
-## control the behavior of this container:
-## USER - The in-container user.
-## DOCKER_RUN_OPTIONS - Options for the container (default: --privileged, required for tests).
-## DOCKER_NAME - The container name (default: gvisor-bazel-HASH).
-## BAZEL_CACHE - The bazel cache directory (default: detected).
-## GCLOUD_CONFIG - The gcloud config directory (detect: detected).
-## DOCKER_SOCKET - The Docker socket (default: detected).
-##
-bazel-server-start: bazel-image ## Starts the bazel server.
- @mkdir -p $(BAZEL_CACHE)
- @mkdir -p $(GCLOUD_CONFIG)
- @if docker ps --all | grep $(DOCKER_NAME); then docker rm -f $(DOCKER_NAME); fi
- # This command runs a bazel server, and the container sticks around
- # until the bazel server exits. This should ensure that it does not
- # exit in the middle of running a build, but also it won't stick around
- # forever. The build commands wrap around an appropriate exec into the
- # container in order to perform work via the bazel client.
- docker run -d --rm --name $(DOCKER_NAME) \
- -v "$(CURDIR):$(CURDIR)" \
- --workdir "$(CURDIR)" \
- $(FULL_DOCKER_RUN_OPTIONS) \
- $(BUILDER_IMAGE) \
- sh -c "tail -f --pid=\$$($(BAZEL) info server_pid) /dev/null"
-.PHONY: bazel-server-start
-
bazel-shutdown: ## Shuts down a running bazel server.
- @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(BAZEL) shutdown; \
- rc=$$?; docker kill $(DOCKER_NAME) || [[ $$rc -ne 0 ]]
+ @$(call wrapper,$(BAZEL) shutdown)
.PHONY: bazel-shutdown
bazel-alias: ## Emits an alias that can be used within the shell.
- @echo "alias bazel='docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) bazel'"
+ @echo "alias bazel='$(call wrapper,$(BAZEL))'"
.PHONY: bazel-alias
-bazel-server: ## Ensures that the server exists. Used as an internal target.
- @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) true || $(MAKE) bazel-server-start
-.PHONY: bazel-server
+bazel-image: load-default ## Ensures that the local builder exists.
+ @$(call header,DOCKER BUILD)
+ @docker rm -f $(BUILDER_NAME) 2>/dev/null || true
+ @docker run --user 0:0 --entrypoint "" --name $(BUILDER_NAME) gvisor.dev/images/default \
+ sh -c "$(GROUPADD_DOCKER) $(USERADD_DOCKER) if test -e /dev/kvm; then chmod a+rw /dev/kvm; fi"
+ @docker commit $(BUILDER_NAME) gvisor.dev/images/builder
+.PHONY: bazel-image
-build_cmd = docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) sh -o pipefail -c '$(BAZEL) build $(OPTIONS) "$(TARGETS)"'
-
-build_paths = $(build_cmd) 2>&1 \
- | tee /proc/self/fd/2 \
- | grep -A1 -E '^Target' \
- | grep -E '^ ($(subst $(SPACE),|,$(BUILD_ROOTS)))' \
- | sed "s/ /\n/g" \
- | strings -n 10 \
- | awk '{$$1=$$1};1' \
- | xargs -n 1 -I {} readlink -f "{}" \
- | xargs -n 1 -I {} sh -c "$(1)"
-
-build: bazel-server
- @$(call build_cmd)
-.PHONY: build
-
-copy: bazel-server
-ifeq (,$(DESTINATION))
- $(error Destination not provided.)
+ifneq (true,$(shell $(wrapper echo true)))
+bazel-server: bazel-image ## Ensures that the server exists.
+ @$(call header,DOCKER RUN)
+ @docker rm -f $(DOCKER_NAME) 2>/dev/null || true
+ @mkdir -p $(GCLOUD_CONFIG)
+ @mkdir -p $(BAZEL_CACHE)
+ @docker run -d --rm --name $(DOCKER_NAME) \
+ -v "$(CURDIR):$(CURDIR)" \
+ --workdir "$(CURDIR)" \
+ $(DOCKER_RUN_OPTIONS) \
+ gvisor.dev/images/builder \
+ sh -c "set -x; tail -f --pid=\$$($(BAZEL) info server_pid) /dev/null"
+else
+bazel-server:
+ @
endif
- @$(call build_paths,cp -fa {} $(DESTINATION))
-
-run: bazel-server
- @$(call build_paths,{} $(ARGS))
-.PHONY: run
-
-sudo: bazel-server
- @$(call build_paths,sudo -E {} $(ARGS))
-.PHONY: sudo
-
-test: OPTIONS += --test_output=errors --keep_going --verbose_failures=true
-test: bazel-server
- @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(BAZEL) test $(OPTIONS) $(TARGETS)
-.PHONY: test
+.PHONY: bazel-server
-query:
- @$(MAKE) bazel-server >&2 # If we need to start, ensure stdout is not polluted.
- @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) sh -o pipefail -c '$(BAZEL) query $(OPTIONS) "$(TARGETS)" 2>/dev/null'
-.PHONY: query
+# build_paths extracts the built binary from the bazel stderr output.
+#
+# This could be alternately done by parsing the bazel build event stream, but
+# this is a complex schema, and begs the question: what will build the thing
+# that parses the output? Bazel? Do we need a separate bootstrapping build
+# command here? Yikes, let's just stick with the ugly shell pipeline.
+#
+# The last line is used to prevent terminal shenanigans.
+build_paths = \
+ $(call wrapper,$(BAZEL) build $(BASE_OPTIONS) $(1)) 2>&1 \
+ | tee /proc/self/fd/2 \
+ | grep -A1 -E '^Target' \
+ | grep -E '^ ($(subst $(SPACE),|,$(BUILD_ROOTS)))' \
+ | sed "s/ /\n/g" \
+ | strings -n 10 \
+ | awk '{$$1=$$1};1' \
+ | xargs -n 1 -I {} readlink -f "{}" \
+ | xargs -n 1 -I {} bash -c 'set -xeuo pipefail; $(2)'
+
+clean = $(call header,CLEAN) && $(call wrapper,$(BAZEL) clean)
+build = $(call header,BUILD $(1)) && $(call build_paths,$(1),echo {})
+copy = $(call header,COPY $(1) $(2)) && $(call build_paths,$(1),cp -fa {} $(2))
+run = $(call header,RUN $(1) $(2)) && $(call build_paths,$(1),{} $(2))
+sudo = $(call header,SUDO $(1) $(2)) && $(call build_paths,$(1),sudo -E {} $(2))
+test = $(call header,TEST $(1)) && $(call wrapper,$(BAZEL) test $(TEST_OPTIONS) $(1))
+
+clean: ## Cleans the bazel cache.
+ @$(call clean)
+.PHONY: clean
+
+testlogs: ## Returns the most recent set of test logs.
+ @if test -f .build_events.json; then \
+ cat .build_events.json | jq -r \
+ 'select(.testSummary?.overallStatus? | tostring | test("(FAILED|FLAKY|TIMEOUT)")) | .testSummary.failed | .[] | .uri' | \
+ awk -Ffile:// '{print $$2;}'; \
+ fi
+.PHONY: testlogs
diff --git a/tools/bazel_gazelle.patch b/tools/bazel_gazelle.patch
new file mode 100644
index 000000000..e35f38933
--- /dev/null
+++ b/tools/bazel_gazelle.patch
@@ -0,0 +1,24 @@
+diff -r -u2 a/language/go/resolve.go b/language/go/resolve.go
+--- a/language/go/resolve.go 2020-10-02 14:22:18.000000000 -0700
++++ b/language/go/resolve.go 2020-11-17 19:40:59.770648029 -0800
+@@ -20,5 +20,4 @@
+ "fmt"
+ "go/build"
+- "log"
+ "path"
+ "regexp"
+@@ -80,5 +79,5 @@
+ resolve = ResolveGo
+ }
+- deps, errs := imports.Map(func(imp string) (string, error) {
++ deps, _ := imports.Map(func(imp string) (string, error) {
+ l, err := resolve(c, ix, rc, imp, from)
+ if err == skipImportError {
+@@ -95,7 +94,4 @@
+ return l.String(), nil
+ })
+- for _, err := range errs {
+- log.Print(err)
+- }
+ if !deps.IsEmpty() {
+ if r.Kind() == "go_proto_library" {
diff --git a/tools/bazeldefs/BUILD b/tools/bazeldefs/BUILD
index 27e85a75e..97c7cb45f 100644
--- a/tools/bazeldefs/BUILD
+++ b/tools/bazeldefs/BUILD
@@ -58,3 +58,21 @@ bzl_library(
srcs = ["defs.bzl"],
visibility = ["//visibility:private"],
)
+
+config_setting(
+ name = "linux_arm64_cross",
+ values = {
+ "cpu": "aarch64",
+ "host_cpu": "k8",
+ },
+ visibility = ["//visibility:private"],
+)
+
+config_setting(
+ name = "linux_amd64_cross",
+ values = {
+ "cpu": "k8",
+ "host_cpu": "aarch64",
+ },
+ visibility = ["//visibility:private"],
+)
diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl
index c2f94bb9c..279a38fed 100644
--- a/tools/bazeldefs/defs.bzl
+++ b/tools/bazeldefs/defs.bzl
@@ -7,6 +7,8 @@ build_test = _build_test
bzl_library = _bzl_library
rbe_platform = native.platform
rbe_toolchain = native.toolchain
+more_shards = 4
+most_shards = 8
def short_path(path):
return path
@@ -37,3 +39,44 @@ def default_net_util():
def coreutil():
return [] # Nothing needed.
+
+def select_native_vs_cross(native = [], amd64 = [], arm64 = [], cross = []):
+ values = {
+ "//tools/bazeldefs:linux_arm64_cross": arm64 + cross,
+ "//tools/bazeldefs:linux_amd64_cross": amd64 + cross,
+ "//conditions:default": native,
+ }
+ return select(values)
+
+def arch_genrule(name, srcs, outs, cmd, tools):
+ """Runs a gen command on the target architecture.
+
+ If the target architecture isn't match the host architecture, it will build
+ a command for the target architecture and run it via qemu.
+
+ The native genrule runs the command on the host architecture.
+
+ Args:
+ name: name of generated target.
+ srcs: A list of inputs for this rule.
+ cmd: The command to run. It has to contain " QEMU " before executed binaries.
+ outs: A list of files generated by this rule.
+ tools: A list of tool dependencies for this rule.
+ """
+ qemu_arm64 = "qemu-aarch64-static"
+ qemu_amd64 = "qemu-x86_64-static"
+ srcs = select_native_vs_cross(
+ cross = srcs + tools,
+ native = srcs,
+ )
+ tools = select_native_vs_cross(
+ cross = [],
+ native = tools,
+ )
+ cmd = select_native_vs_cross(
+ arm64 = cmd.replace("QEMU", qemu_arm64),
+ amd64 = cmd.replace("QEMU", qemu_amd64),
+ native = cmd.replace("QEMU", ""),
+ cross = "",
+ )
+ native.genrule(name = name, srcs = srcs, outs = outs, cmd = cmd, tools = tools)
diff --git a/tools/bazeldefs/go.bzl b/tools/bazeldefs/go.bzl
index 661c9727e..bcd8cffe7 100644
--- a/tools/bazeldefs/go.bzl
+++ b/tools/bazeldefs/go.bzl
@@ -28,7 +28,7 @@ def go_proto_library(name, **kwargs):
def go_grpc_and_proto_libraries(name, **kwargs):
_go_proto_or_grpc_library(_go_grpc_library, name, **kwargs)
-def go_binary(name, static = False, pure = False, x_defs = None, **kwargs):
+def go_binary(name, static = False, pure = False, x_defs = None, system_malloc = False, **kwargs):
"""Build a go binary.
Args:
@@ -52,7 +52,7 @@ def go_importpath(target):
"""Returns the importpath for the target."""
return target[GoLibrary].importpath
-def go_library(name, **kwargs):
+def go_library(name, arch_deps = [], **kwargs):
_go_library(
name = name,
importpath = "gvisor.dev/gvisor/" + native.package_name(),
diff --git a/tools/bigquery/bigquery.go b/tools/bigquery/bigquery.go
index 544af3876..a4ca93ec2 100644
--- a/tools/bigquery/bigquery.go
+++ b/tools/bigquery/bigquery.go
@@ -21,6 +21,7 @@ package bigquery
import (
"context"
"fmt"
+ "strconv"
"strings"
"time"
@@ -109,6 +110,12 @@ func NewBenchmark(name string, iters int) *Benchmark {
return &Benchmark{
Name: name,
Metric: make([]*Metric, 0),
+ Condition: []*Condition{
+ {
+ Name: "iterations",
+ Value: strconv.Itoa(iters),
+ },
+ },
}
}
diff --git a/tools/checkescape/test1/test1.go b/tools/checkescape/test1/test1.go
index 27991649f..f46eba39b 100644
--- a/tools/checkescape/test1/test1.go
+++ b/tools/checkescape/test1/test1.go
@@ -36,17 +36,20 @@ func (t Type) Foo() {
fmt.Printf("%v", t) // Never executed.
}
+// InterfaceFunction is passed an interface argument.
// +checkescape:all,hard
//go:nosplit
func InterfaceFunction(i Interface) {
// Do nothing; exported for tests.
}
+// TypeFunction is passed a concrete pointer argument.
// +checkesacape:all,hard
//go:nosplit
func TypeFunction(t *Type) {
}
+// BuiltinMap creates a new map.
// +mustescape:local,builtin
//go:noinline
//go:nosplit
@@ -61,7 +64,8 @@ func builtinMapRec(x int) map[string]bool {
return BuiltinMap(x)
}
-// +temustescapestescape:local,builtin
+// BuiltinClosure returns a closure around x.
+// +mustescape:local,builtin
//go:noinline
//go:nosplit
func BuiltinClosure(x int) func() {
@@ -77,6 +81,7 @@ func builtinClosureRec(x int) func() {
return BuiltinClosure(x)
}
+// BuiltinMakeSlice makes a new slice.
// +mustescape:local,builtin
//go:noinline
//go:nosplit
@@ -91,6 +96,7 @@ func builtinMakeSliceRec(x int) []byte {
return BuiltinMakeSlice(x)
}
+// BuiltinAppend calls append on a slice.
// +mustescape:local,builtin
//go:noinline
//go:nosplit
@@ -105,6 +111,7 @@ func builtinAppendRec() []byte {
return BuiltinAppend(nil)
}
+// BuiltinChan makes a channel.
// +mustescape:local,builtin
//go:noinline
//go:nosplit
@@ -119,6 +126,7 @@ func builtinChanRec() chan int {
return BuiltinChan()
}
+// Heap performs an explicit heap allocation.
// +mustescape:local,heap
//go:noinline
//go:nosplit
@@ -134,6 +142,7 @@ func heapRec() *Type {
return Heap()
}
+// Dispatch dispatches via an interface.
// +mustescape:local,interface
//go:noinline
//go:nosplit
@@ -148,6 +157,7 @@ func dispatchRec(i Interface) {
Dispatch(i)
}
+// Dynamic invokes a dynamic function.
// +mustescape:local,dynamic
//go:noinline
//go:nosplit
@@ -167,6 +177,7 @@ func dynamicRec(f func()) {
func internalFunc() {
}
+// Split includes a guaranteed stack split.
// +mustescape:local,stack
//go:noinline
func Split() {
diff --git a/tools/defs.bzl b/tools/defs.bzl
index 2c8129e7e..54d756e55 100644
--- a/tools/defs.bzl
+++ b/tools/defs.bzl
@@ -8,7 +8,7 @@ change for Google-internal and bazel-compatible rules.
load("//tools/go_stateify:defs.bzl", "go_stateify")
load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps")
load("//tools/nogo:defs.bzl", "nogo_test")
-load("//tools/bazeldefs:defs.bzl", _build_test = "build_test", _bzl_library = "bzl_library", _coreutil = "coreutil", _default_installer = "default_installer", _default_net_util = "default_net_util", _proto_library = "proto_library", _rbe_platform = "rbe_platform", _rbe_toolchain = "rbe_toolchain", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path")
+load("//tools/bazeldefs:defs.bzl", _arch_genrule = "arch_genrule", _build_test = "build_test", _bzl_library = "bzl_library", _coreutil = "coreutil", _default_installer = "default_installer", _default_net_util = "default_net_util", _more_shards = "more_shards", _most_shards = "most_shards", _proto_library = "proto_library", _rbe_platform = "rbe_platform", _rbe_toolchain = "rbe_toolchain", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path")
load("//tools/bazeldefs:cc.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _gbenchmark = "gbenchmark", _grpcpp = "grpcpp", _gtest = "gtest", _vdso_linker_option = "vdso_linker_option")
load("//tools/bazeldefs:go.bzl", _gazelle = "gazelle", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_test = "go_test", _select_goarch = "select_goarch", _select_goos = "select_goos")
load("//tools/bazeldefs:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar")
@@ -16,6 +16,7 @@ load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform",
load("//tools/bazeldefs:tags.bzl", "go_suffixes")
# Core rules.
+arch_genrule = _arch_genrule
build_test = _build_test
bzl_library = _bzl_library
default_installer = _default_installer
@@ -26,6 +27,8 @@ short_path = _short_path
rbe_platform = _rbe_platform
rbe_toolchain = _rbe_toolchain
coreutil = _coreutil
+more_shards = _more_shards
+most_shards = _most_shards
# C++ rules.
cc_binary = _cc_binary
@@ -182,6 +185,7 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F
name + suffix + "_state_autogen.go"
for suffix in state_sets.keys()
]
+
if "//pkg/state" not in all_deps:
all_deps = all_deps + ["//pkg/state"]
diff --git a/tools/go_branch.sh b/tools/go_branch.sh
index 71d036b12..ca07246a6 100755
--- a/tools/go_branch.sh
+++ b/tools/go_branch.sh
@@ -39,7 +39,7 @@ declare tmp_dir
tmp_dir=$(mktemp -d)
readonly tmp_dir
finish() {
- cd # Leave tmp_dir.
+ cd / # Leave tmp_dir.
rm -rf "${tmp_dir}"
}
trap finish EXIT
@@ -90,7 +90,7 @@ find . -type f -exec chmod 0644 {} \;
find . -type d -exec chmod 0755 {} \;
# Sync the entire gopath_dir.
-rsync --recursive --verbose --delete --exclude .git -L "${gopath_dir}/" .
+rsync --recursive --delete --exclude .git -L "${gopath_dir}/" .
# Add additional files.
for file in "${othersrc[@]}"; do
diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl
index ad97208a8..50e2546bf 100644
--- a/tools/go_generics/defs.bzl
+++ b/tools/go_generics/defs.bzl
@@ -67,7 +67,7 @@ def _go_template_instance_impl(ctx):
# Check that all defined types are expected by the template.
for t in ctx.attr.types:
if (t not in info.types) and (t not in info.opt_types):
- fail("Type %s it not a parameter to %s" % (t, ctx.attr.template.label))
+ fail("Type %s is not a parameter to %s" % (t, ctx.attr.template.label))
# Check that all required consts are defined.
for t in info.consts:
@@ -77,7 +77,7 @@ def _go_template_instance_impl(ctx):
# Check that all defined consts are expected by the template.
for t in ctx.attr.consts:
if (t not in info.consts) and (t not in info.opt_consts):
- fail("Const %s it not a parameter to %s" % (t, ctx.attr.template.label))
+ fail("Const %s is not a parameter to %s" % (t, ctx.attr.template.label))
# Build the argument list.
args = ["-i=%s" % info.template.path, "-o=%s" % output.path]
diff --git a/tools/go_generics/generics.go b/tools/go_generics/generics.go
index 0860ca9db..30584006c 100644
--- a/tools/go_generics/generics.go
+++ b/tools/go_generics/generics.go
@@ -223,7 +223,7 @@ func main() {
} else {
switch kind {
case globals.KindType, globals.KindVar, globals.KindConst, globals.KindFunction:
- if ident.Name != "_" {
+ if ident.Name != "_" && !(ident.Name == "init" && kind == globals.KindFunction) {
ident.Name = *prefix + ident.Name + *suffix
}
case globals.KindTag:
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
index 4a53d25be..6f41b1b79 100644
--- a/tools/go_marshal/gomarshal/generator.go
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -213,10 +213,11 @@ type sliceAPI struct {
type marshallableType struct {
spec *ast.TypeSpec
slice *sliceAPI
+ recv string
}
-func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) marshallableType {
- mt := marshallableType{
+func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) *marshallableType {
+ mt := &marshallableType{
spec: spec,
slice: nil,
}
@@ -261,12 +262,31 @@ func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.Ty
// collectMarshallableTypes walks the parsed AST and collects a list of type
// declarations for which we need to generate the Marshallable interface.
-func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []marshallableType {
- var types []marshallableType
+func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) map[*ast.TypeSpec]*marshallableType {
+ recv := make(map[string]string) // Type name to recevier name.
+ types := make(map[*ast.TypeSpec]*marshallableType)
for _, decl := range a.Decls {
gdecl, ok := decl.(*ast.GenDecl)
// Type declaration?
if !ok || gdecl.Tok != token.TYPE {
+ // Is this a function declaration? We remember receiver names.
+ d, ok := decl.(*ast.FuncDecl)
+ if ok && d.Recv != nil && len(d.Recv.List) == 1 {
+ // Accept concrete methods & pointer methods.
+ ident, ok := d.Recv.List[0].Type.(*ast.Ident)
+ if !ok {
+ var st *ast.StarExpr
+ st, ok = d.Recv.List[0].Type.(*ast.StarExpr)
+ if ok {
+ ident, ok = st.X.(*ast.Ident)
+ }
+ }
+ // The receiver name may be not present.
+ if ok && len(d.Recv.List[0].Names) == 1 {
+ // Recover the type receiver name in this case.
+ recv[ident.Name] = d.Recv.List[0].Names[0].Name
+ }
+ }
debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n")
continue
}
@@ -305,9 +325,19 @@ func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []ma
// don't support it.
abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name))
}
- types = append(types, newMarshallableType(f, tagLine, t))
-
+ types[t] = newMarshallableType(f, tagLine, t)
+ }
+ }
+ // Update the types with the last seen receiver. As long as the
+ // receiver name is consistent for the type, then we will generate
+ // code that is still consistent with itself.
+ for t, mt := range types {
+ r, ok := recv[t.Name.Name]
+ if !ok {
+ mt.recv = receiverName(t) // Default.
+ continue
}
+ mt.recv = r // Last seen.
}
return types
}
@@ -345,8 +375,8 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp
}
-func (g *Generator) generateOne(t marshallableType, fset *token.FileSet) *interfaceGenerator {
- i := newInterfaceGenerator(t.spec, fset)
+func (g *Generator) generateOne(t *marshallableType, fset *token.FileSet) *interfaceGenerator {
+ i := newInterfaceGenerator(t.spec, t.recv, fset)
switch ty := t.spec.Type.(type) {
case *ast.StructType:
i.validateStruct(t.spec, ty)
@@ -376,8 +406,8 @@ func (g *Generator) generateOne(t marshallableType, fset *token.FileSet) *interf
// generateOneTestSuite generates a test suite for the automatically generated
// implementations type t.
-func (g *Generator) generateOneTestSuite(t marshallableType) *testGenerator {
- i := newTestGenerator(t.spec)
+func (g *Generator) generateOneTestSuite(t *marshallableType) *testGenerator {
+ i := newTestGenerator(t.spec, t.recv)
i.emitTests(t.slice)
return i
}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
index 36447b86b..65f5ea34d 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -54,10 +54,10 @@ func (g *interfaceGenerator) typeName() string {
}
// newinterfaceGenerator creates a new interface generator.
-func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
+func newInterfaceGenerator(t *ast.TypeSpec, r string, fset *token.FileSet) *interfaceGenerator {
g := &interfaceGenerator{
t: t,
- r: receiverName(t),
+ r: r,
f: fset,
is: make(map[string]struct{}),
ms: make(map[string]struct{}),
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
index 631295373..6cf00843f 100644
--- a/tools/go_marshal/gomarshal/generator_tests.go
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -53,10 +53,10 @@ type testGenerator struct {
decl *importStmt
}
-func newTestGenerator(t *ast.TypeSpec) *testGenerator {
+func newTestGenerator(t *ast.TypeSpec, r string) *testGenerator {
g := &testGenerator{
t: t,
- r: receiverName(t),
+ r: r,
imports: newImportTable(),
}
diff --git a/tools/images.mk b/tools/images.mk
new file mode 100644
index 000000000..46f56bb2c
--- /dev/null
+++ b/tools/images.mk
@@ -0,0 +1,169 @@
+#!/usr/bin/make -f
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+##
+## Docker image targets.
+##
+## Images used by the tests must also be built and available locally.
+## The canonical test targets defined below will automatically load
+## relevant images. These can be loaded or built manually via these
+## targets.
+##
+## (*) Note that you may provide an ARCH parameter in order to build
+## and load images from an alternate archiecture (using qemu). When
+## bazel is run as a server, this has the effect of running an full
+## cross-architecture chain, and can produce cross-compiled binaries.
+##
+
+# ARCH is the architecture used for the build. This may be overriden at the
+# command line in order to perform a cross-build (in a limited capacity).
+ARCH := $(shell uname -m)
+ifneq ($(ARCH),$(shell uname -m))
+DOCKER_PLATFORM_ARGS := --platform=$(ARCH)
+else
+DOCKER_PLATFORM_ARGS :=
+endif
+
+# Note that the image prefixes used here must match the image mangling in
+# runsc/testutil.MangleImage. Names are mangled in this way to ensure that all
+# tests are using locally-defined images (that are consistent and idempotent).
+REMOTE_IMAGE_PREFIX ?= gcr.io/gvisor-presubmit
+LOCAL_IMAGE_PREFIX ?= gvisor.dev/images
+ALL_IMAGES := $(subst /,_,$(subst images/,,$(shell find images/ -name Dockerfile -o -name Dockerfile.$(ARCH) | xargs -n 1 dirname | uniq)))
+SUB_IMAGES := $(foreach image,$(ALL_IMAGES),$(if $(findstring _,$(image)),$(image),))
+IMAGE_GROUPS := $(sort $(foreach image,$(SUB_IMAGES),$(firstword $(subst _, ,$(image)))))
+
+define expand_group =
+load-$(1): $$(patsubst $(1)_%, load-$(1)_%, $$(filter $(1)_%,$$(ALL_IMAGES)))
+ @
+.PHONY: load-$(1)
+push-$(1): $$(patsubst $(1)_%, push-$(1)_%, $$(filter $(1)_%,$$(ALL_IMAGES)))
+ @
+.PHONY: push-$(1)
+endef
+$(foreach group,$(IMAGE_GROUPS),$(eval $(call expand_group,$(group))))
+
+list-all-images: ## List all images.
+ @for image in $(ALL_IMAGES); do echo $${image}; done
+.PHONY: list-all-images
+
+load-all-images: ## Load all images.
+load-all-images: $(patsubst %,load-%,$(ALL_IMAGES))
+.PHONY: load-all-images
+
+push-all-images: ## Push all images.
+push-all-images: $(patsubst %,push-%,$(ALL_IMAGES))
+.PHONY: push-all-images
+
+# path and dockerfile are used to extract the relevant path and dockerfile
+# (depending on what's available for the given architecture).
+path = images/$(subst _,/,$(1))
+dockerfile = $$(if [ -f "$(call path,$(1))/Dockerfile.$(ARCH)" ]; then echo Dockerfile.$(ARCH); else echo Dockerfile; fi)
+
+# The tag construct is used to memoize the image generated (see README.md).
+# This scheme is used to enable aggressive caching in a central repository, but
+# ensuring that images will always be sourced using the local files.
+tag = $(shell cd images && find $(subst _,/,$(1)) -type f | sort | xargs -n 1 sha256sum | sha256sum - | cut -c 1-16)
+remote_image = $(REMOTE_IMAGE_PREFIX)/$(subst _,/,$(1))_$(ARCH)
+local_image = $(LOCAL_IMAGE_PREFIX)/$(subst _,/,$(1))
+
+# Include all existing images as targets here.
+#
+# Note that we use a _ for the tag separator, instead of :, as the latter is
+# interpreted by Make, unfortunately. tag_expand expands the generic rules to
+# tag-specific targets. These is needed to provide sensible targets for load
+# below, with caching. Basically, if there is a rule generated here, then the
+# load will be skipped. If there is no load generated here, then the default
+# rule for load will kick in.
+#
+# Note that if this rule does not successfully rule, we will simply have
+# additional Docker pull commands that run for all images that are already
+# pulled. No real harm done.
+EXISTING_IMAGES = $(shell docker images --format '{{.Repository}}_{{.Tag}}' | grep -v '<none>')
+define existing_image_rule =
+loaded0_$(1)=load-$$(1): tag-$$(1) # Already available.
+loaded1_$(1)=.PHONY: load-$$(1)
+endef
+$(foreach image, $(EXISTING_IMAGES), $(eval $(call existing_image_rule,$(image))))
+define tag_expand_rule =
+$(eval $(loaded0_$(call remote_image,$(1))_$(call tag,$(1))))
+$(eval $(loaded1_$(call remote_image,$(1))_$(call tag,$(1))))
+endef
+$(foreach image, $(ALL_IMAGES), $(eval $(call tag_expand_rule,$(image))))
+
+# tag tags a local image. This applies both the hash-based tag from above to
+# ensure that caching works as expected, as well as the "latest" tag that is
+# used by the tests.
+local_tag = \
+ docker tag $(call remote_image,$(1)):$(call tag,$(1)) $(call local_image,$(1)):$(call tag,$(1))
+latest_tag = \
+ docker tag $(call local_image,$(1)):$(call tag,$(1)) $(call local_image,$(1))
+tag-%: ## Tag a local image.
+ @$(call header,TAG $*)
+ @$(call local_tag,$*) && $(call latest_tag,$*)
+
+# pull forces the image to be pulled.
+pull = \
+ $(call header,PULL $(1)) && \
+ docker pull $(DOCKER_PLATFORM_ARGS) $(call remote_image,$(1)):$(call tag,$(1)) && \
+ $(call local_tag,$(1)) && \
+ $(call latest_tag,$(1))
+pull-%: register-cross ## Force a repull of the image.
+ @$(call pull,$*)
+
+# rebuild builds the image locally. Only the "remote" tag will be applied. Note
+# we need to explicitly repull the base layer in order to ensure that the
+# architecture is correct. Note that we use the term "rebuild" here to avoid
+# conflicting with the bazel "build" terminology, which is used elsewhere.
+rebuild = \
+ $(call header,REBUILD $(1)) && \
+ (T=$$(mktemp -d) && cp -a $(call path,$(1))/* $$T && \
+ $(foreach image,$(shell grep FROM "$(call path,$(1))/$(call dockerfile,$(1))" 2>/dev/null | cut -d' ' -f2),docker pull $(DOCKER_PLATFORM_ARGS) $(image) &&) \
+ docker build $(DOCKER_PLATFORM_ARGS) \
+ -f "$$T/$(call dockerfile,$(1))" \
+ -t "$(call remote_image,$(1)):$(call tag,$(1))" \
+ $$T && \
+ rm -rf $$T) && \
+ $(call local_tag,$(1)) && \
+ $(call latest_tag,$(1))
+rebuild-%: register-cross ## Force rebuild an image locally.
+ @$(call rebuild,$*)
+
+# load will either pull the "remote" or build it locally. This is the preferred
+# entrypoint, as it should never fail. The local tag should always be set after
+# this returns (either by the pull or the build).
+load-%: register-cross ## Pull or build an image locally.
+ @($(call pull,$*)) || ($(call rebuild,$*))
+
+# push pushes the remote image, after either pulling (to validate that the tag
+# already exists) or building manually. Note that this generic rule will match
+# the fully-expanded remote image tag.
+push-%: load-% ## Push a given image.
+ @docker push $(call remote_image,$*):$(call tag,$*)
+
+# register-cross registers the necessary qemu binaries for cross-compilation.
+# This may be used by any target that may execute containers that are not the
+# native format. Note that this will only apply on the first execution.
+register-cross:
+ifneq ($(ARCH),$(shell uname -m))
+ifeq (,$(wildcard /proc/sys/fs/binfmt_misc/qemu-*))
+ @docker run --rm --privileged multiarch/qemu-user-static --reset --persistent yes
+else
+ @
+endif
+else
+ @
+endif
diff --git a/tools/installers/BUILD b/tools/installers/BUILD
index 13d3cc5e0..bbf3c1f85 100644
--- a/tools/installers/BUILD
+++ b/tools/installers/BUILD
@@ -1,4 +1,4 @@
-# Installers for use by the tools/vm_test rules.
+# Installers for use by top-level scripts.
package(
default_visibility = ["//:sandbox"],
@@ -14,14 +14,6 @@ sh_binary(
)
sh_binary(
- name = "images",
- srcs = ["images.sh"],
- data = [
- "//images",
- ],
-)
-
-sh_binary(
name = "master",
srcs = ["master.sh"],
)
diff --git a/tools/installers/containerd.sh b/tools/installers/containerd.sh
index 6b7bb261c..d28549734 100755
--- a/tools/installers/containerd.sh
+++ b/tools/installers/containerd.sh
@@ -16,7 +16,7 @@
set -xeo pipefail
-declare -r CONTAINERD_VERSION=${CONTAINERD_VERSION:-1.3.0}
+declare -r CONTAINERD_VERSION=${1:-1.3.0}
declare -r CONTAINERD_MAJOR="$(echo ${CONTAINERD_VERSION} | awk -F '.' '{ print $1; }')"
declare -r CONTAINERD_MINOR="$(echo ${CONTAINERD_VERSION} | awk -F '.' '{ print $2; }')"
@@ -43,10 +43,23 @@ install_helper() {
make install)
}
+# Figure out were btrfs headers are.
+#
+# Ubuntu 16.04 has only btrfs-tools, while 18.04 has a transitional package,
+# and later versions no longer have the transitional package.
+source /etc/os-release
+declare BTRFS_DEV
+if [[ "${VERSION_ID%.*}" -le "18" ]]; then
+ BTRFS_DEV="btrfs-tools"
+else
+ BTRFS_DEV="libbtrfs-dev"
+fi
+readonly BTRFS_DEV
+
# Install dependencies for the crictl tests.
while true; do
if (apt-get update && apt-get install -y \
- btrfs-tools \
+ "${BTRFS_DEV}" \
libseccomp-dev); then
break
fi
diff --git a/tools/nogo/filter/main.go b/tools/nogo/filter/main.go
index 9cf41b3b0..8be38ca6d 100644
--- a/tools/nogo/filter/main.go
+++ b/tools/nogo/filter/main.go
@@ -16,6 +16,7 @@
package main
import (
+ "bytes"
"flag"
"fmt"
"io/ioutil"
@@ -76,12 +77,14 @@ func main() {
log.Fatalf("unable to read %s: %v", filename, err)
}
var newConfig nogo.Config // For current file.
- if err := yaml.Unmarshal(content, &newConfig); err != nil {
+ dec := yaml.NewDecoder(bytes.NewBuffer(content))
+ dec.SetStrict(true)
+ if err := dec.Decode(&newConfig); err != nil {
log.Fatalf("unable to decode %s: %v", filename, err)
}
config.Merge(&newConfig)
if showConfig {
- bytes, err := yaml.Marshal(&newConfig)
+ content, err := yaml.Marshal(&newConfig)
if err != nil {
log.Fatalf("error marshalling config: %v", err)
}
@@ -89,7 +92,7 @@ func main() {
if err != nil {
log.Fatalf("error marshalling config: %v", err)
}
- fmt.Fprintf(os.Stdout, "Loaded configuration from %s:\n%s\n", filename, string(bytes))
+ fmt.Fprintf(os.Stdout, "Loaded configuration from %s:\n%s\n", filename, string(content))
fmt.Fprintf(os.Stdout, "Merged configuration:\n%s\n", string(mergedBytes))
}
}
diff --git a/tools/parsers/go_parser_test.go b/tools/parsers/go_parser_test.go
index f0737d46b..39a13b4af 100644
--- a/tools/parsers/go_parser_test.go
+++ b/tools/parsers/go_parser_test.go
@@ -34,6 +34,10 @@ func TestParseLine(t *testing.T) {
Name: "BenchmarkIperf",
Condition: []*bigquery.Condition{
{
+ Name: "iterations",
+ Value: "1",
+ },
+ {
Name: "GOMAXPROCS",
Value: "6",
},
@@ -63,6 +67,10 @@ func TestParseLine(t *testing.T) {
Name: "BenchmarkRuby",
Condition: []*bigquery.Condition{
{
+ Name: "iterations",
+ Value: "1",
+ },
+ {
Name: "GOMAXPROCS",
Value: "6",
},
@@ -100,12 +108,14 @@ func TestParseLine(t *testing.T) {
}
if !cmp.Equal(tc.want, got, nil) {
- for _, c := range got.Condition {
- t.Logf("Cond: %+v", c)
+ for i := range got.Condition {
+ t.Logf("Metric: want: %+v got:%+v", got.Condition[i], tc.want.Condition[i])
}
- for _, m := range got.Metric {
- t.Logf("Metric: %+v", m)
+
+ for i := range got.Metric {
+ t.Logf("Metric: want: %+v got:%+v", got.Metric[i], tc.want.Metric[i])
}
+
t.Fatalf("Compare failed want: %+v got: %+v", tc.want, got)
}
})
@@ -131,7 +141,7 @@ func TestParseOutput(t *testing.T) {
`,
numBenchmarks: 2,
numMetrics: 1,
- numConditions: 1,
+ numConditions: 2,
},
{
name: "Ruby",
@@ -142,7 +152,7 @@ BenchmarkRuby/server_threads.5
BenchmarkRuby/server_threads.5-6 1 1416003331 ns/op 0.00950 average_latency.s 465 requests_per_second.QPS`,
numBenchmarks: 2,
numMetrics: 3,
- numConditions: 2,
+ numConditions: 3,
},
}
diff --git a/tools/vm/BUILD b/tools/vm/BUILD
deleted file mode 100644
index d95ca6c63..000000000
--- a/tools/vm/BUILD
+++ /dev/null
@@ -1,63 +0,0 @@
-load("//tools:defs.bzl", "bzl_library", "cc_binary", "gtest")
-load("//tools/vm:defs.bzl", "vm_image", "vm_test")
-
-package(
- default_visibility = ["//:sandbox"],
- licenses = ["notice"],
-)
-
-sh_binary(
- name = "zone",
- srcs = ["zone.sh"],
-)
-
-sh_binary(
- name = "builder",
- srcs = ["build.sh"],
-)
-
-sh_binary(
- name = "executer",
- srcs = ["execute.sh"],
-)
-
-cc_binary(
- name = "test",
- testonly = 1,
- srcs = ["test.cc"],
- linkstatic = 1,
- deps = [
- gtest,
- "//test/util:test_main",
- ],
-)
-
-vm_image(
- name = "ubuntu1604",
- family = "ubuntu-1604-lts",
- project = "ubuntu-os-cloud",
- scripts = [
- "//tools/vm/ubuntu1604",
- ],
-)
-
-vm_image(
- name = "ubuntu1804",
- family = "ubuntu-1804-lts",
- project = "ubuntu-os-cloud",
- scripts = [
- "//tools/vm/ubuntu1804",
- ],
-)
-
-vm_test(
- name = "vm_test",
- shard_count = 2,
- targets = [":test"],
-)
-
-bzl_library(
- name = "defs_bzl",
- srcs = ["defs.bzl"],
- visibility = ["//visibility:private"],
-)
diff --git a/tools/vm/README.md b/tools/vm/README.md
deleted file mode 100644
index 1e9859e66..000000000
--- a/tools/vm/README.md
+++ /dev/null
@@ -1,48 +0,0 @@
-# VM Images & Tests
-
-All commands in this directory require the `gcloud` project to be set.
-
-For example: `gcloud config set project gvisor-kokoro-testing`.
-
-Images can be generated by using the `vm_image` rule. This rule will generate a
-binary target that builds an image in an idempotent way, and can be referenced
-from other rules.
-
-For example:
-
-```
-vm_image(
- name = "ubuntu",
- project = "ubuntu-1604-lts",
- family = "ubuntu-os-cloud",
- scripts = [
- "script.sh",
- "other.sh",
- ],
-)
-```
-
-These images can be built manually by executing the target. The output on
-`stdout` will be the image id (in the current project).
-
-For example:
-
-```
-$ bazel build :ubuntu
-```
-
-Images are always named per the hash of all the hermetic input scripts. This
-allows images to be memoized quickly and easily.
-
-The `vm_test` rule can be used to execute a command remotely. This is still
-under development however, and will likely change over time.
-
-For example:
-
-```
-vm_test(
- name = "mycommand",
- image = ":ubuntu",
- targets = [":test"],
-)
-```
diff --git a/tools/vm/build.sh b/tools/vm/build.sh
deleted file mode 100755
index 752b2b77b..000000000
--- a/tools/vm/build.sh
+++ /dev/null
@@ -1,117 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# This script is responsible for building a new GCP image that: 1) has nested
-# virtualization enabled, and 2) has been completely set up with the
-# image_setup.sh script. This script should be idempotent, as we memoize the
-# setup script with a hash and check for that name.
-
-set -eou pipefail
-
-# Parameters.
-declare -r USERNAME=${USERNAME:-test}
-declare -r IMAGE_PROJECT=${IMAGE_PROJECT:-ubuntu-os-cloud}
-declare -r IMAGE_FAMILY=${IMAGE_FAMILY:-ubuntu-1604-lts}
-declare -r ZONE=${ZONE:-us-central1-f}
-
-# Random names.
-declare -r DISK_NAME=$(mktemp -u disk-XXXXXX | tr A-Z a-z)
-declare -r SNAPSHOT_NAME=$(mktemp -u snapshot-XXXXXX | tr A-Z a-z)
-declare -r INSTANCE_NAME=$(mktemp -u build-XXXXXX | tr A-Z a-z)
-
-# Hash inputs in order to memoize the produced image.
-declare -r SETUP_HASH=$( (echo ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && cat "$@") | sha256sum - | cut -d' ' -f1 | cut -c 1-16)
-declare -r IMAGE_NAME=${IMAGE_FAMILY:-image}-${SETUP_HASH}
-
-# Does the image already exist? Skip the build.
-declare -r existing=$(set -x; gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)")
-if ! [[ -z "${existing}" ]]; then
- echo "${existing}"
- exit 0
-fi
-
-# Standard arguments (applies only on script execution).
-declare -ar SSH_ARGS=("-o" "ConnectTimeout=60" "--")
-
-# gcloud has path errors; is this a result of being a genrule?
-export PATH=${PATH:-/bin:/usr/bin:/usr/local/bin}
-
-# Start a unique instance. Note that this instance will have a unique persistent
-# disk as it's boot disk with the same name as the instance.
-(set -x; gcloud compute instances create \
- --quiet \
- --image-project "${IMAGE_PROJECT}" \
- --image-family "${IMAGE_FAMILY}" \
- --boot-disk-size "200GB" \
- --zone "${ZONE}" \
- "${INSTANCE_NAME}" >/dev/null)
-function cleanup {
- (set -x; gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}")
-}
-trap cleanup EXIT
-
-# Wait for the instance to become available (up to 5 minutes).
-echo -n "Waiting for ${INSTANCE_NAME}" >&2
-declare timeout=300
-declare success=0
-declare internal=""
-declare -r start=$(date +%s)
-declare -r end=$((${start}+${timeout}))
-while [[ "$(date +%s)" -lt "${end}" ]] && [[ "${success}" -lt 3 ]]; do
- echo -n "." >&2
- if gcloud compute ssh --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then
- success=$((${success}+1))
- elif gcloud compute ssh --internal-ip --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then
- success=$((${success}+1))
- internal="--internal-ip"
- fi
-done
-
-if [[ "${success}" -eq "0" ]]; then
- echo "connect timed out after ${timeout} seconds." >&2
- exit 1
-else
- echo "done." >&2
-fi
-
-# Run the install scripts provided.
-for arg; do
- (set -x; gcloud compute ssh ${internal} \
- --zone "${ZONE}" \
- "${USERNAME}"@"${INSTANCE_NAME}" -- \
- "${SSH_ARGS[@]}" \
- sudo bash - <"${arg}" >/dev/null)
-done
-
-# Stop the instance; required before creating an image.
-(set -x; gcloud compute instances stop --quiet --zone "${ZONE}" "${INSTANCE_NAME}" >/dev/null)
-
-# Create a snapshot of the instance disk.
-(set -x; gcloud compute disks snapshot \
- --quiet \
- --zone "${ZONE}" \
- --snapshot-names="${SNAPSHOT_NAME}" \
- "${INSTANCE_NAME}" >/dev/null)
-
-# Create the disk image.
-(set -x; gcloud compute images create \
- --quiet \
- --source-snapshot="${SNAPSHOT_NAME}" \
- --licenses="https://www.googleapis.com/compute/v1/projects/vm-options/global/licenses/enable-vmx" \
- "${IMAGE_NAME}" >/dev/null)
-
-# Finish up.
-echo "${IMAGE_NAME}"
diff --git a/tools/vm/defs.bzl b/tools/vm/defs.bzl
deleted file mode 100644
index 9af5ad3b4..000000000
--- a/tools/vm/defs.bzl
+++ /dev/null
@@ -1,202 +0,0 @@
-"""Image configuration. See README.md."""
-
-load("//tools:defs.bzl", "default_installer")
-
-# vm_image_builder is a rule that will construct a shell script that actually
-# generates a given VM image. Note that this does not _run_ the shell script
-# (although it can be run manually). It will be run manually during generation
-# of the vm_image target itself. This level of indirection is used so that the
-# build system itself only runs the builder once when multiple targets depend
-# on it, avoiding a set of races and conflicts.
-def _vm_image_builder_impl(ctx):
- # Generate a binary that actually builds the image.
- builder = ctx.actions.declare_file(ctx.label.name)
- script_paths = []
- for script in ctx.files.scripts:
- script_paths.append(script.short_path)
- builder_content = "\n".join([
- "#!/bin/bash",
- "export ZONE=$(%s)" % ctx.files.zone[0].short_path,
- "export USERNAME=%s" % ctx.attr.username,
- "export IMAGE_PROJECT=%s" % ctx.attr.project,
- "export IMAGE_FAMILY=%s" % ctx.attr.family,
- "%s %s" % (ctx.files._builder[0].short_path, " ".join(script_paths)),
- "",
- ])
- ctx.actions.write(builder, builder_content, is_executable = True)
-
- # Note that the scripts should only be files, and should not include any
- # indirect transitive dependencies. The build script wouldn't work.
- return [DefaultInfo(
- executable = builder,
- runfiles = ctx.runfiles(
- files = ctx.files.scripts + ctx.files._builder + ctx.files.zone,
- ),
- )]
-
-vm_image_builder = rule(
- attrs = {
- "_builder": attr.label(
- executable = True,
- default = "//tools/vm:builder",
- cfg = "host",
- ),
- "username": attr.string(default = "$(whoami)"),
- "zone": attr.label(
- executable = True,
- default = "//tools/vm:zone",
- cfg = "host",
- ),
- "family": attr.string(mandatory = True),
- "project": attr.string(mandatory = True),
- "scripts": attr.label_list(allow_files = True),
- },
- executable = True,
- implementation = _vm_image_builder_impl,
-)
-
-# See vm_image_builder above.
-def _vm_image_impl(ctx):
- # Run the builder to generate our output.
- echo = ctx.actions.declare_file(ctx.label.name)
- resolved_inputs, argv, runfiles_manifests = ctx.resolve_command(
- command = "\n".join([
- "set -e",
- "image=$(%s)" % ctx.files.builder[0].path,
- "echo -ne \"#!/bin/bash\\necho ${image}\\n\" > %s" % echo.path,
- "chmod 0755 %s" % echo.path,
- ]),
- tools = [ctx.attr.builder],
- )
- ctx.actions.run_shell(
- tools = resolved_inputs,
- outputs = [echo],
- progress_message = "Building image...",
- execution_requirements = {"local": "true"},
- command = argv,
- input_manifests = runfiles_manifests,
- )
-
- # Return just the echo command. All of the builder runfiles have been
- # resolved and consumed in the generation of the trivial echo script.
- return [DefaultInfo(executable = echo)]
-
-_vm_image_test = rule(
- attrs = {
- "builder": attr.label(
- executable = True,
- cfg = "host",
- ),
- },
- test = True,
- implementation = _vm_image_impl,
-)
-
-def vm_image(name, **kwargs):
- vm_image_builder(
- name = name + "_builder",
- **kwargs
- )
- _vm_image_test(
- name = name,
- builder = ":" + name + "_builder",
- tags = [
- "local",
- "manual",
- ],
- )
-
-def _vm_test_impl(ctx):
- runner = ctx.actions.declare_file("%s-executer" % ctx.label.name)
-
- # Note that the remote execution case must actually generate an
- # intermediate target in order to collect all the relevant runfiles so that
- # they can be copied over for remote execution.
- runner_content = "\n".join([
- "#!/bin/bash",
- "export ZONE=$(%s)" % ctx.files.zone[0].short_path,
- "export USERNAME=%s" % ctx.attr.username,
- "export IMAGE=$(%s)" % ctx.files.image[0].short_path,
- "export SUDO=%s" % "true" if ctx.attr.sudo else "false",
- "%s %s" % (
- ctx.executable.executer.short_path,
- " ".join([
- target.files_to_run.executable.short_path
- for target in ctx.attr.targets
- ]),
- ),
- "",
- ])
- ctx.actions.write(runner, runner_content, is_executable = True)
-
- # Return with all transitive files.
- runfiles = ctx.runfiles(
- transitive_files = depset(transitive = [
- depset(target.data_runfiles.files)
- for target in ctx.attr.targets
- if hasattr(target, "data_runfiles")
- ]),
- files = ctx.files.executer + ctx.files.zone + ctx.files.image +
- ctx.files.targets,
- collect_default = True,
- collect_data = True,
- )
- return [DefaultInfo(executable = runner, runfiles = runfiles)]
-
-_vm_test = rule(
- attrs = {
- "image": attr.label(
- executable = True,
- default = "//tools/vm:ubuntu1804",
- cfg = "host",
- ),
- "executer": attr.label(
- executable = True,
- default = "//tools/vm:executer",
- cfg = "host",
- ),
- "username": attr.string(default = "$(whoami)"),
- "zone": attr.label(
- executable = True,
- default = "//tools/vm:zone",
- cfg = "host",
- ),
- "sudo": attr.bool(default = True),
- "machine": attr.string(default = "n1-standard-1"),
- "targets": attr.label_list(
- mandatory = True,
- allow_empty = False,
- cfg = "target",
- ),
- },
- test = True,
- implementation = _vm_test_impl,
-)
-
-def vm_test(
- installers = None,
- **kwargs):
- """Runs the given targets as a remote test.
-
- Args:
- installer: Script to run before all targets.
- **kwargs: All test arguments. Should include targets and image.
- """
- targets = kwargs.pop("targets", [])
- if installers == None:
- installers = [
- "//tools/installers:head",
- "//tools/installers:images",
- ]
- targets = installers + targets
- if default_installer():
- targets = [default_installer()] + targets
- _vm_test(
- tags = [
- "local",
- "manual",
- ],
- targets = targets,
- local = 1,
- **kwargs
- )
diff --git a/tools/vm/execute.sh b/tools/vm/execute.sh
deleted file mode 100755
index 1f1f3ce01..000000000
--- a/tools/vm/execute.sh
+++ /dev/null
@@ -1,160 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-set -xeo pipefail
-
-# Required input.
-if ! [[ -v IMAGE ]]; then
- echo "no image provided: set IMAGE."
- exit 1
-fi
-
-# Parameters.
-declare -r USERNAME=${USERNAME:-test}
-declare -r KEYNAME=$(mktemp --tmpdir -u key-XXXXXX)
-declare -r SSHKEYS=$(mktemp --tmpdir -u sshkeys-XXXXXX)
-declare -r INSTANCE_NAME=$(mktemp -u test-XXXXXX | tr A-Z a-z)
-declare -r MACHINE=${MACHINE:-n1-standard-1}
-declare -r ZONE=${ZONE:-us-central1-f}
-declare -r SUDO=${SUDO:-false}
-
-# Standard arguments (applies only on script execution).
-declare -ar SSH_ARGS=("-o" "ConnectTimeout=60" "--")
-
-# This script is executed as a test rule, which will reset the value of HOME.
-# Unfortunately, it is needed to load the gconfig credentials. We will reset
-# HOME when we actually execute in the remote environment, defined below.
-export HOME=$(eval echo ~$(whoami))
-
-# Generate unique keys for this test.
-[[ -f "${KEYNAME}" ]] || ssh-keygen -t rsa -N "" -f "${KEYNAME}" -C "${USERNAME}"
-cat > "${SSHKEYS}" <<EOF
-${USERNAME}:$(cat ${KEYNAME}.pub)
-EOF
-
-# Start a unique instance. This means that we first generate a unique set of ssh
-# keys to ensure that only we have access to this instance. Note that we must
-# constrain ourselves to Haswell or greater in order to have nested
-# virtualization available.
-gcloud compute instances create \
- --min-cpu-platform "Intel Haswell" \
- --preemptible \
- --no-scopes \
- --metadata block-project-ssh-keys=TRUE \
- --metadata-from-file ssh-keys="${SSHKEYS}" \
- --machine-type "${MACHINE}" \
- --image "${IMAGE}" \
- --zone "${ZONE}" \
- "${INSTANCE_NAME}"
-function cleanup {
- gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}"
-}
-trap cleanup EXIT
-
-# Wait for the instance to become available (up to 5 minutes).
-declare timeout=300
-declare success=0
-declare -r start=$(date +%s)
-declare -r end=$((${start}+${timeout}))
-while [[ "$(date +%s)" -lt "${end}" ]] && [[ "${success}" -lt 3 ]]; do
- if gcloud compute ssh --ssh-key-file="${KEYNAME}" --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then
- success=$((${success}+1))
- fi
-done
-if [[ "${success}" -eq "0" ]]; then
- echo "connect timed out after ${timeout} seconds."
- exit 1
-fi
-
-# Copy the local directory over.
-tar czf - --dereference --exclude=.git . |
- gcloud compute ssh \
- --ssh-key-file="${KEYNAME}" \
- --zone "${ZONE}" \
- "${USERNAME}"@"${INSTANCE_NAME}" -- \
- "${SSH_ARGS[@]}" \
- tar xzf -
-
-# Execute the command remotely.
-for cmd; do
- # Setup relevant environment.
- #
- # N.B. This is not a complete test environment, but is complete enough to
- # provide rudimentary sharding and test output support.
- declare -a PREFIX=( "env" )
- if [[ -v TEST_SHARD_INDEX ]]; then
- PREFIX+=( "TEST_SHARD_INDEX=${TEST_SHARD_INDEX}" )
- fi
- if [[ -v TEST_SHARD_STATUS_FILE ]]; then
- SHARD_STATUS_FILE=$(mktemp -u test-shard-status-XXXXXX)
- PREFIX+=( "TEST_SHARD_STATUS_FILE=/tmp/${SHARD_STATUS_FILE}" )
- fi
- if [[ -v TEST_TOTAL_SHARDS ]]; then
- PREFIX+=( "TEST_TOTAL_SHARDS=${TEST_TOTAL_SHARDS}" )
- fi
- if [[ -v TEST_TMPDIR ]]; then
- REMOTE_TMPDIR=$(mktemp -u test-XXXXXX)
- PREFIX+=( "TEST_TMPDIR=/tmp/${REMOTE_TMPDIR}" )
- # Create remotely.
- gcloud compute ssh \
- --ssh-key-file="${KEYNAME}" \
- --zone "${ZONE}" \
- "${USERNAME}"@"${INSTANCE_NAME}" -- \
- "${SSH_ARGS[@]}" \
- mkdir -p "/tmp/${REMOTE_TMPDIR}"
- fi
- if [[ -v XML_OUTPUT_FILE ]]; then
- TEST_XML_OUTPUT=$(mktemp -u xml-output-XXXXXX)
- PREFIX+=( "XML_OUTPUT_FILE=/tmp/${TEST_XML_OUTPUT}" )
- fi
- if [[ "${SUDO}" == "true" ]]; then
- PREFIX+=( "sudo" "-E" )
- fi
-
- # Execute the command.
- gcloud compute ssh \
- --ssh-key-file="${KEYNAME}" \
- --zone "${ZONE}" \
- "${USERNAME}"@"${INSTANCE_NAME}" -- \
- "${SSH_ARGS[@]}" \
- "${PREFIX[@]}" "${cmd}"
-
- # Collect relevant results.
- if [[ -v TEST_SHARD_STATUS_FILE ]]; then
- gcloud compute scp \
- --ssh-key-file="${KEYNAME}" \
- --zone "${ZONE}" \
- "${USERNAME}"@"${INSTANCE_NAME}":/tmp/"${SHARD_STATUS_FILE}" \
- "${TEST_SHARD_STATUS_FILE}" 2>/dev/null || true # Allowed to fail.
- fi
- if [[ -v XML_OUTPUT_FILE ]]; then
- gcloud compute scp \
- --ssh-key-file="${KEYNAME}" \
- --zone "${ZONE}" \
- "${USERNAME}"@"${INSTANCE_NAME}":/tmp/"${TEST_XML_OUTPUT}" \
- "${XML_OUTPUT_FILE}" 2>/dev/null || true # Allowed to fail.
- fi
-
- # Clean up the temporary directory.
- if [[ -v TEST_TMPDIR ]]; then
- gcloud compute ssh \
- --ssh-key-file="${KEYNAME}" \
- --zone "${ZONE}" \
- "${USERNAME}"@"${INSTANCE_NAME}" -- \
- "${SSH_ARGS[@]}" \
- rm -rf "/tmp/${REMOTE_TMPDIR}"
- fi
-done
diff --git a/tools/vm/ubuntu1604/10_core.sh b/tools/vm/ubuntu1604/10_core.sh
deleted file mode 100755
index 629f7cf7a..000000000
--- a/tools/vm/ubuntu1604/10_core.sh
+++ /dev/null
@@ -1,43 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-set -xeo pipefail
-
-# Install all essential build tools.
-while true; do
- if (apt-get update && apt-get install -y \
- make \
- git-core \
- build-essential \
- linux-headers-$(uname -r) \
- pkg-config); then
- break
- fi
- result=$?
- if [[ $result -ne 100 ]]; then
- exit $result
- fi
-done
-
-# Install a recent go toolchain.
-if ! [[ -d /usr/local/go ]]; then
- wget https://dl.google.com/go/go1.13.5.linux-amd64.tar.gz
- tar -xvf go1.13.5.linux-amd64.tar.gz
- mv go /usr/local
-fi
-
-# Link the Go binary from /usr/bin; replacing anything there.
-(cd /usr/bin && rm -f go && ln -fs /usr/local/go/bin/go go)
diff --git a/tools/vm/ubuntu1604/15_gcloud.sh b/tools/vm/ubuntu1604/15_gcloud.sh
deleted file mode 100755
index bc2e5eccc..000000000
--- a/tools/vm/ubuntu1604/15_gcloud.sh
+++ /dev/null
@@ -1,50 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-set -xeo pipefail
-
-# Install all essential build tools.
-while true; do
- if (apt-get update && apt-get install -y \
- apt-transport-https \
- ca-certificates \
- gnupg); then
- break
- fi
- result=$?
- if [[ $result -ne 100 ]]; then
- exit $result
- fi
-done
-
-# Add gcloud repositories.
-echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | \
- tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
-
-# Add the appropriate key.
-curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | \
- apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
-
-# Install the gcloud SDK.
-while true; do
- if (apt-get update && apt-get install -y google-cloud-sdk); then
- break
- fi
- result=$?
- if [[ $result -ne 100 ]]; then
- exit $result
- fi
-done
diff --git a/tools/vm/ubuntu1604/20_bazel.sh b/tools/vm/ubuntu1604/20_bazel.sh
deleted file mode 100755
index bb7afa676..000000000
--- a/tools/vm/ubuntu1604/20_bazel.sh
+++ /dev/null
@@ -1,38 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-set -xeo pipefail
-
-declare -r BAZEL_VERSION=2.0.0
-
-# Install bazel dependencies.
-while true; do
- if (apt-get update && apt-get install -y \
- openjdk-8-jdk-headless \
- unzip); then
- break
- fi
- result=$?
- if [[ $result -ne 100 ]]; then
- exit $result
- fi
-done
-
-# Use the release installer.
-curl -L -o bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
-chmod a+x bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
-./bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
-rm -f bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
diff --git a/tools/vm/ubuntu1604/30_docker.sh b/tools/vm/ubuntu1604/30_docker.sh
deleted file mode 100755
index d393133e4..000000000
--- a/tools/vm/ubuntu1604/30_docker.sh
+++ /dev/null
@@ -1,64 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Add dependencies.
-while true; do
- if (apt-get update && apt-get install -y \
- apt-transport-https \
- ca-certificates \
- curl \
- gnupg-agent \
- software-properties-common); then
- break
- fi
- result=$?
- if [[ $result -ne 100 ]]; then
- exit $result
- fi
-done
-
-# Install the key.
-curl -fsSL https://download.docker.com/linux/ubuntu/gpg | apt-key add -
-
-# Add the repository.
-add-apt-repository \
- "deb [arch=amd64] https://download.docker.com/linux/ubuntu \
- $(lsb_release -cs) \
- stable"
-
-# Install docker.
-while true; do
- if (apt-get update && apt-get install -y \
- docker-ce \
- docker-ce-cli \
- containerd.io); then
- break
- fi
- result=$?
- if [[ $result -ne 100 ]]; then
- exit $result
- fi
-done
-
-# Enable experimental features, for cross-building aarch64 images.
-# Enable Docker IPv6.
-cat > /etc/docker/daemon.json <<EOF
-{
- "experimental": true,
- "fixed-cidr-v6": "2001:db8:1::/64",
- "ipv6": true
-}
-EOF
diff --git a/tools/vm/ubuntu1604/40_kokoro.sh b/tools/vm/ubuntu1604/40_kokoro.sh
deleted file mode 100755
index d3b96c9ad..000000000
--- a/tools/vm/ubuntu1604/40_kokoro.sh
+++ /dev/null
@@ -1,72 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-set -xeo pipefail
-
-# Declare kokoro's required public keys.
-declare -r ssh_public_keys=(
- "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDg7L/ZaEauETWrPklUTky3kvxqQfe2Ax/2CsSqhNIGNMnK/8d79CHlmY9+dE1FFQ/RzKNCaltgy7XcN/fCYiCZr5jm2ZtnLuGNOTzupMNhaYiPL419qmL+5rZXt4/dWTrsHbFRACxT8j51PcRMO5wgbL0Bg2XXimbx8kDFaurL2gqduQYqlu4lxWCaJqOL71WogcimeL63Nq/yeH5PJPWpqE4P9VUQSwAzBWFK/hLeds/AiP3MgVS65qHBnhq0JsHy8JQsqjZbG7Iidt/Ll0+gqzEbi62gDIcczG4KC0iOVzDDP/1BxDtt1lKeA23ll769Fcm3rJyoBMYxjvdw1TDx sabujp@trigger.mtv.corp.google.com"
- "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBNgGK/hCdjmulHfRE3hp4rZs38NCR8yAh0eDsztxqGcuXnuSnL7jOlRrbcQpremJ84omD4eKrIpwJUs+YokMdv4= sabujp@trigger.svl.corp.google.com"
-)
-
-# Install dependencies.
-while true; do
- if (apt-get update && apt-get install -y \
- rsync \
- coreutils \
- python-psutil \
- qemu-kvm \
- python-pip \
- python3-pip \
- zip); then
- break
- fi
- result=$?
- if [[ $result -ne 100 ]]; then
- exit $result
- fi
-done
-
-# junitparser is used to merge junit xml files.
-pip install --no-cache-dir junitparser
-
-# We need a kbuilder user, which may already exist.
-useradd -c "kbuilder user" -m -s /bin/bash kbuilder || true
-
-# We need to provision appropriate keys.
-mkdir -p ~kbuilder/.ssh
-(IFS=$'\n'; echo "${ssh_public_keys[*]}") > ~kbuilder/.ssh/authorized_keys
-chmod 0600 ~kbuilder/.ssh/authorized_keys
-chown -R kbuilder ~kbuilder/.ssh
-
-# Give passwordless sudo access.
-cat > /etc/sudoers.d/kokoro <<EOF
-kbuilder ALL=(ALL) NOPASSWD:ALL
-EOF
-
-# Ensure we can run Docker without sudo.
-usermod -aG docker kbuilder
-
-# Ensure that we can access kvm.
-usermod -aG kvm kbuilder
-
-# Ensure that /tmpfs exists and is writable by kokoro.
-#
-# Note that kokoro will typically attach a second disk (sdb) to the instance
-# that is used for the /tmpfs volume. In the future we could setup an init
-# script that formats and mounts this here; however, we don't expect our build
-# artifacts to be that large.
-mkdir -p /tmpfs && chmod 0777 /tmpfs && touch /tmpfs/READY
diff --git a/tools/vm/ubuntu1604/BUILD b/tools/vm/ubuntu1604/BUILD
deleted file mode 100644
index ab1df0c4c..000000000
--- a/tools/vm/ubuntu1604/BUILD
+++ /dev/null
@@ -1,7 +0,0 @@
-package(licenses = ["notice"])
-
-filegroup(
- name = "ubuntu1604",
- srcs = glob(["*.sh"]),
- visibility = ["//:sandbox"],
-)
diff --git a/tools/vm/ubuntu1804/BUILD b/tools/vm/ubuntu1804/BUILD
deleted file mode 100644
index 0c8856dde..000000000
--- a/tools/vm/ubuntu1804/BUILD
+++ /dev/null
@@ -1,7 +0,0 @@
-package(licenses = ["notice"])
-
-alias(
- name = "ubuntu1804",
- actual = "//tools/vm/ubuntu1604",
- visibility = ["//:sandbox"],
-)
diff --git a/website/blog/README.md b/website/blog/README.md
new file mode 100644
index 000000000..e1d685288
--- /dev/null
+++ b/website/blog/README.md
@@ -0,0 +1,62 @@
+# gVisor blog
+
+The gVisor blog is owned and run by the gVisor team.
+
+## Contact
+
+Reach out to us on [gitter](https://gitter.im/gvisor/community) or the
+[mailing list](https://groups.google.com/forum/#!forum/gvisor-users) if you
+would like to write a blog post.
+
+## Submit a Post
+
+Anyone can write a blog post and submit it for review. Purely commercial content
+or vendor pitches are not allowed. Please refer to the
+[blog guidelines](#blog-guidelines) for more guidance about content is that
+allowed.
+
+To submit a blog post, follow the steps below.
+
+1. [Sign the Contributor License Agreements](https://gvisor.dev/contributing/)
+ if you have not yet done so.
+1. Familiarize yourself with the Markdown format for the
+ [existing blog posts](https://github.com/google/gvisor/tree/master/website/blog).
+1. Write your blog post in a text editor of your choice.
+1. (Optional) If you need help with markdown, check out
+ [StakEdit](https://stackedit.io/app#) or read
+ [Jekyll's formatting reference](https://jekyllrb.com/docs/posts/#creating-posts)
+ for more information.
+1. Click **Add file** > **Create new file**.
+1. Paste your content into the editor and save it. Name the file in the
+ following way: *[BLOG] Your proposed title* , but don’t put the date in the
+ file name. The blog reviewers will work with you on the final file name, and
+ the date on which the blog will be published.
+1. When you save the file, GitHub will walk you through the pull request (PR)
+ process.
+1. Send us a message on [gitter](https://gitter.im/gvisor/community) with a
+ link to your recently created PR.
+1. A reviewer will be assigned to the pull request. They check your submission,
+ and work with you on feedback and final details. When the pull request is
+ approved, the blog will be scheduled for publication.
+
+### Blog Guidelines {#blog-guidelines}
+
+#### Suitable content:
+
+- **Original content only**
+- gVisor features or project updates
+- Tutorials and demos
+- Use cases
+- Content that is specific to a vendor or platform about gVisor installation
+ and use
+
+#### Unsuitable Content:
+
+- Blogs with no content relevant to gVisor
+- Vendor pitches
+
+## Review Process
+
+Each blog post should be approved by at least one person on the team. Once all
+of the review comments have been addressed and approved, a member of the team
+will schedule publication of the blog post.
diff --git a/website/blog/index.html b/website/blog/index.html
index 5c67c95fc..272917fc4 100644
--- a/website/blog/index.html
+++ b/website/blog/index.html
@@ -20,3 +20,8 @@ pagination:
{% if paginator.total_pages > 1 %}
{% include paginator.html %}
{% endif %}
+
+<hr>
+
+If you would like to contribute to the gVisor blog check out the
+<a href="https://github.com/google/gvisor/tree/master/website/blog">instructions</a>.