summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.bazelrc29
-rw-r--r--.buildkite/hooks/post-command60
-rw-r--r--.buildkite/hooks/pre-command33
-rw-r--r--.buildkite/pipeline.yaml204
-rw-r--r--.github/pull_request_template.md5
-rw-r--r--.github/workflows/build.yml12
-rw-r--r--.github/workflows/go.yml48
-rw-r--r--.github/workflows/issue_reviver.yml4
-rw-r--r--.github/workflows/labeler.yml3
-rw-r--r--.github/workflows/stale.yml4
-rw-r--r--.gitignore5
-rw-r--r--.travis.yml47
-rw-r--r--BUILD32
-rw-r--r--Makefile474
-rw-r--r--README.md2
-rw-r--r--WORKSPACE73
-rw-r--r--g3doc/architecture_guide/performance.md2
-rw-r--r--g3doc/proposals/runtime_dedicate_os_thread.md188
-rw-r--r--go.mod4
-rw-r--r--go.sum6
-rw-r--r--images/BUILD10
-rw-r--r--images/Makefile107
-rw-r--r--images/agent/Dockerfile12
-rw-r--r--images/agent/README.md7
-rw-r--r--images/arm-qemu/Dockerfile.x86_6412
-rwxr-xr-ximages/arm-qemu/initramfs/init39
-rwxr-xr-ximages/arm-qemu/test.sh (renamed from tools/vm/zone.sh)13
-rw-r--r--images/basic/ping4test/Dockerfile7
-rw-r--r--images/basic/ping4test/ping4.sh25
-rw-r--r--images/basic/ping6test/Dockerfile7
-rw-r--r--images/basic/ping6test/ping6.sh32
-rw-r--r--images/benchmarks/absl/Dockerfile.x86_64 (renamed from images/benchmarks/absl/Dockerfile)1
-rw-r--r--images/benchmarks/hey/Dockerfile13
-rw-r--r--images/benchmarks/runsc/Dockerfile.x86_64 (renamed from images/benchmarks/runsc/Dockerfile)1
-rw-r--r--images/default/Dockerfile37
-rw-r--r--images/runtimes/go1.12/Dockerfile.x86_64 (renamed from images/runtimes/go1.12/Dockerfile)0
-rw-r--r--nogo.yaml140
-rw-r--r--pkg/abi/linux/BUILD1
-rw-r--r--pkg/abi/linux/errqueue.go93
-rw-r--r--pkg/abi/linux/fcntl.go2
-rw-r--r--pkg/abi/linux/fuse.go8
-rw-r--r--pkg/abi/linux/sem.go35
-rw-r--r--pkg/abi/linux/socket.go20
-rw-r--r--pkg/control/server/server.go5
-rw-r--r--pkg/coverage/coverage.go130
-rw-r--r--pkg/cpuid/cpuid.go11
-rw-r--r--pkg/cpuid/cpuid_x86.go11
-rw-r--r--pkg/crypto/BUILD12
-rw-r--r--pkg/crypto/crypto.go (renamed from pkg/sleep/empty.s)5
-rw-r--r--pkg/crypto/crypto_stdlib.go32
-rw-r--r--pkg/flipcall/BUILD3
-rw-r--r--pkg/flipcall/packet_window_mmap_amd64.go (renamed from pkg/flipcall/packet_window_mmap.go)0
-rw-r--r--pkg/flipcall/packet_window_mmap_arm64.go (renamed from pkg/syncevent/waiter_asm_unsafe.go)13
-rw-r--r--pkg/log/json.go14
-rw-r--r--pkg/log/log.go2
-rw-r--r--pkg/merkletree/merkletree.go10
-rw-r--r--pkg/p9/handlers.go81
-rw-r--r--pkg/p9/p9test/client_test.go15
-rw-r--r--pkg/p9/server.go14
-rw-r--r--pkg/p9/transport_test.go16
-rw-r--r--pkg/pool/pool.go1
-rw-r--r--pkg/refsvfs2/BUILD2
-rw-r--r--pkg/refsvfs2/refs_template.go9
-rw-r--r--pkg/safemem/block_unsafe.go20
-rw-r--r--pkg/seccomp/seccomp.go2
-rw-r--r--pkg/segment/test/set_functions.go1
-rw-r--r--pkg/sentry/arch/arch.go15
-rw-r--r--pkg/sentry/arch/arch_state_x86.go17
-rw-r--r--pkg/sentry/arch/signal.go39
-rw-r--r--pkg/sentry/control/pprof.go287
-rw-r--r--pkg/sentry/control/state.go1
-rw-r--r--pkg/sentry/fdimport/fdimport.go1
-rw-r--r--pkg/sentry/fs/copy_up.go13
-rw-r--r--pkg/sentry/fs/copy_up_test.go2
-rw-r--r--pkg/sentry/fs/filetest/filetest.go4
-rw-r--r--pkg/sentry/fs/fs.go2
-rw-r--r--pkg/sentry/fs/gofer/attr.go2
-rw-r--r--pkg/sentry/fs/gofer/inode.go3
-rw-r--r--pkg/sentry/fs/host/inode.go4
-rw-r--r--pkg/sentry/fs/proc/sys.go1
-rw-r--r--pkg/sentry/fs/ramfs/socket.go3
-rw-r--r--pkg/sentry/fs/tmpfs/inode_file.go4
-rw-r--r--pkg/sentry/fsimpl/ext/inode.go4
-rw-r--r--pkg/sentry/fsimpl/fuse/connection_control.go6
-rw-r--r--pkg/sentry/fsimpl/fuse/connection_test.go10
-rw-r--r--pkg/sentry/fsimpl/fuse/dev_test.go5
-rw-r--r--pkg/sentry/fsimpl/fuse/directory.go6
-rw-r--r--pkg/sentry/fsimpl/fuse/file.go8
-rw-r--r--pkg/sentry/fsimpl/fuse/fusefs.go45
-rw-r--r--pkg/sentry/fsimpl/fuse/read_write.go12
-rw-r--r--pkg/sentry/fsimpl/fuse/request_response.go5
-rw-r--r--pkg/sentry/fsimpl/host/host.go5
-rw-r--r--pkg/sentry/fsimpl/overlay/copy_up.go23
-rw-r--r--pkg/sentry/fsimpl/overlay/regular_file.go4
-rw-r--r--pkg/sentry/fsimpl/pipefs/pipefs.go4
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go17
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go1
-rw-r--r--pkg/sentry/fsimpl/signalfd/signalfd.go5
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go6
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go94
-rw-r--r--pkg/sentry/fsimpl/verity/verity_test.go232
-rw-r--r--pkg/sentry/kernel/epoll/epoll.go7
-rw-r--r--pkg/sentry/kernel/fasync/BUILD2
-rw-r--r--pkg/sentry/kernel/fasync/fasync.go96
-rw-r--r--pkg/sentry/kernel/fd_table_unsafe.go11
-rw-r--r--pkg/sentry/kernel/kernel.go50
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go28
-rw-r--r--pkg/sentry/kernel/ptrace.go4
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore.go115
-rw-r--r--pkg/sentry/kernel/shm/BUILD3
-rw-r--r--pkg/sentry/kernel/signal.go4
-rw-r--r--pkg/sentry/kernel/signalfd/signalfd.go4
-rw-r--r--pkg/sentry/kernel/task_exit.go8
-rw-r--r--pkg/sentry/kernel/task_signals.go16
-rw-r--r--pkg/sentry/memmap/memmap.go2
-rw-r--r--pkg/sentry/mm/aio_context.go79
-rw-r--r--pkg/sentry/mm/aio_context_state.go4
-rw-r--r--pkg/sentry/mm/lifecycle.go2
-rw-r--r--pkg/sentry/mm/mm_test.go43
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go62
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go14
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.go7
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go40
-rw-r--r--pkg/sentry/platform/kvm/bluepill_unsafe.go5
-rw-r--r--pkg/sentry/platform/kvm/kvm_arm64.go9
-rw-r--r--pkg/sentry/platform/kvm/kvm_const.go1
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go2
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go4
-rw-r--r--pkg/sentry/platform/ptrace/ptrace.go4
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go2
-rw-r--r--pkg/sentry/platform/ring0/BUILD11
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s165
-rw-r--r--pkg/sentry/platform/ring0/gen_offsets/BUILD3
-rw-r--r--pkg/sentry/platform/ring0/kernel_arm64.go8
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.go16
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.s17
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64_unsafe.go108
-rw-r--r--pkg/sentry/socket/BUILD1
-rw-r--r--pkg/sentry/socket/control/BUILD17
-rw-r--r--pkg/sentry/socket/control/control.go125
-rw-r--r--pkg/sentry/socket/control/control_test.go59
-rw-r--r--pkg/sentry/socket/hostinet/socket.go179
-rw-r--r--pkg/sentry/socket/netstack/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go1036
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go6
-rw-r--r--pkg/sentry/socket/netstack/provider.go2
-rw-r--r--pkg/sentry/socket/netstack/provider_vfs2.go2
-rw-r--r--pkg/sentry/socket/netstack/stack.go30
-rw-r--r--pkg/sentry/socket/socket.go304
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go40
-rw-r--r--pkg/sentry/socket/unix/unix.go12
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go2
-rw-r--r--pkg/sentry/strace/BUILD2
-rw-r--r--pkg/sentry/strace/socket.go4
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go8
-rw-r--r--pkg/sentry/syscalls/linux/sys_aio.go5
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go28
-rw-r--r--pkg/sentry/syscalls/linux/sys_sem.go54
-rw-r--r--pkg/sentry/syscalls/linux/sys_signal.go16
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go7
-rw-r--r--pkg/sentry/syscalls/linux/sys_thread.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/aio.go5
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go18
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/ioctl.go2
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/pipe.go5
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/socket.go7
-rw-r--r--pkg/sentry/vfs/epoll.go10
-rw-r--r--pkg/sentry/vfs/file_description.go73
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go36
-rw-r--r--pkg/sentry/vfs/save_restore.go25
-rw-r--r--pkg/shim/v1/proc/process.go1
-rw-r--r--pkg/shim/v1/shim/BUILD1
-rw-r--r--pkg/shim/v1/shim/shim.go (renamed from pkg/sleep/commit_asm.go)13
-rw-r--r--pkg/shim/v1/utils/utils.go1
-rw-r--r--pkg/shim/v2/BUILD1
-rw-r--r--pkg/shim/v2/service.go86
-rw-r--r--pkg/sleep/BUILD4
-rw-r--r--pkg/sleep/commit_amd64.s35
-rw-r--r--pkg/sleep/commit_arm64.s38
-rw-r--r--pkg/sleep/commit_noasm.go33
-rw-r--r--pkg/sleep/sleep_unsafe.go27
-rw-r--r--pkg/state/tests/integer_test.go26
-rw-r--r--pkg/state/tests/register_test.go11
-rw-r--r--pkg/state/tests/struct_test.go11
-rw-r--r--pkg/state/types.go71
-rw-r--r--pkg/sync/BUILD31
-rw-r--r--pkg/sync/atomicptrmaptest/BUILD57
-rw-r--r--pkg/sync/atomicptrmaptest/atomicptrmap.go (renamed from tools/vm/test.cc)15
-rw-r--r--pkg/sync/atomicptrmaptest/atomicptrmap_test.go635
-rw-r--r--pkg/sync/generic_atomicptr_unsafe.go (renamed from pkg/sync/atomicptr_unsafe.go)4
-rw-r--r--pkg/sync/generic_atomicptrmap_unsafe.go503
-rw-r--r--pkg/sync/generic_seqatomic_unsafe.go (renamed from pkg/sync/seqatomic_unsafe.go)21
-rw-r--r--pkg/sync/goyield_go113_unsafe.go (renamed from pkg/sync/spin_legacy_unsafe.go)7
-rw-r--r--pkg/sync/goyield_unsafe.go (renamed from pkg/sync/spin_unsafe.go)6
-rw-r--r--pkg/sync/memmove_unsafe.go28
-rw-r--r--pkg/sync/norace_unsafe.go11
-rw-r--r--pkg/sync/race_amd64.s (renamed from pkg/syncevent/waiter_amd64.s)17
-rw-r--r--pkg/sync/race_arm64.s (renamed from pkg/syncevent/waiter_arm64.s)17
-rw-r--r--pkg/sync/race_unsafe.go6
-rw-r--r--pkg/sync/runtime_unsafe.go129
-rw-r--r--pkg/sync/rwmutex_test.go2
-rw-r--r--pkg/sync/rwmutex_unsafe.go145
-rw-r--r--pkg/sync/seqcount.go34
-rw-r--r--pkg/sync/seqcount_test.go53
-rw-r--r--pkg/syncevent/BUILD4
-rw-r--r--pkg/syncevent/waiter_noasm_unsafe.go39
-rw-r--r--pkg/syncevent/waiter_unsafe.go59
-rw-r--r--pkg/syserr/host_linux.go2
-rw-r--r--pkg/syserr/netstack.go95
-rw-r--r--pkg/tcpip/BUILD14
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go57
-rw-r--r--pkg/tcpip/buffer/view.go37
-rw-r--r--pkg/tcpip/buffer/view_test.go68
-rw-r--r--pkg/tcpip/checker/checker.go333
-rw-r--r--pkg/tcpip/header/BUILD1
-rw-r--r--pkg/tcpip/header/checksum_test.go94
-rw-r--r--pkg/tcpip/header/icmpv4.go30
-rw-r--r--pkg/tcpip/header/icmpv6.go40
-rw-r--r--pkg/tcpip/header/ipv4.go206
-rw-r--r--pkg/tcpip/header/ipv4_test.go179
-rw-r--r--pkg/tcpip/header/ipv6.go53
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers.go336
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers_test.go356
-rw-r--r--pkg/tcpip/header/ipv6_fragment.go42
-rw-r--r--pkg/tcpip/header/ipv6_test.go44
-rw-r--r--pkg/tcpip/header/mld.go7
-rw-r--r--pkg/tcpip/header/ndp_options.go2
-rw-r--r--pkg/tcpip/link/channel/channel.go14
-rw-r--r--pkg/tcpip/link/ethernet/BUILD16
-rw-r--r--pkg/tcpip/link/ethernet/ethernet.go10
-rw-r--r--pkg/tcpip/link/ethernet/ethernet_test.go71
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go13
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go16
-rw-r--r--pkg/tcpip/link/muxed/injectable_test.go6
-rw-r--r--pkg/tcpip/link/packetsocket/endpoint.go4
-rw-r--r--pkg/tcpip/link/pipe/pipe.go2
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go8
-rw-r--r--pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go1
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go2
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go32
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go2
-rw-r--r--pkg/tcpip/link/tun/device.go2
-rw-r--r--pkg/tcpip/network/BUILD2
-rw-r--r--pkg/tcpip/network/arp/arp_test.go7
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD2
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap.go77
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap_test.go126
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go10
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go148
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go225
-rw-r--r--pkg/tcpip/network/ip/BUILD1
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol.go598
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol_test.go724
-rw-r--r--pkg/tcpip/network/ip_test.go125
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go6
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go177
-rw-r--r--pkg/tcpip/network/ipv4/igmp_test.go417
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go168
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go179
-rw-r--r--pkg/tcpip/network/ipv6/BUILD18
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go47
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go179
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go347
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go299
-rw-r--r--pkg/tcpip/network/ipv6/mld.go262
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go297
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go319
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go124
-rw-r--r--pkg/tcpip/network/multicast_group_test.go1261
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go5
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go7
-rw-r--r--pkg/tcpip/socketops.go392
-rw-r--r--pkg/tcpip/stack/BUILD7
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go157
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go22
-rw-r--r--pkg/tcpip/stack/forwarding_test.go34
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go135
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go110
-rw-r--r--pkg/tcpip/stack/ndp_test.go237
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go100
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go491
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go137
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go457
-rw-r--r--pkg/tcpip/stack/nic.go126
-rw-r--r--pkg/tcpip/stack/nud.go21
-rw-r--r--pkg/tcpip/stack/nud_test.go53
-rw-r--r--pkg/tcpip/stack/pending_packets.go2
-rw-r--r--pkg/tcpip/stack/registration.go53
-rw-r--r--pkg/tcpip/stack/route.go235
-rw-r--r--pkg/tcpip/stack/stack.go73
-rw-r--r--pkg/tcpip/stack/stack_test.go321
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go18
-rw-r--r--pkg/tcpip/stack/transport_test.go27
-rw-r--r--pkg/tcpip/tcpip.go253
-rw-r--r--pkg/tcpip/tcpip_test.go40
-rw-r--r--pkg/tcpip/tests/integration/BUILD2
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go65
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go41
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go31
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go532
-rw-r--r--pkg/tcpip/tests/integration/route_test.go69
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go106
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go86
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go152
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go8
-rw-r--r--pkg/tcpip/transport/tcp/BUILD5
-rw-r--r--pkg/tcpip/transport/tcp/accept.go8
-rw-r--r--pkg/tcpip/transport/tcp/connect.go46
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go18
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go488
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go2
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go50
-rw-r--r--pkg/tcpip/transport/tcp/segment.go16
-rw-r--r--pkg/tcpip/transport/tcp/segment_state.go13
-rw-r--r--pkg/tcpip/transport/tcp/segment_unsafe.go3
-rw-r--r--pkg/tcpip/transport/tcp/snd.go40
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go42
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go397
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go23
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go14
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go445
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go7
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go209
-rw-r--r--pkg/test/criutil/criutil.go12
-rw-r--r--pkg/test/dockerutil/container.go106
-rw-r--r--pkg/test/dockerutil/dockerutil.go10
-rw-r--r--pkg/test/dockerutil/exec.go5
-rw-r--r--pkg/test/dockerutil/profile.go161
-rw-r--r--pkg/test/dockerutil/profile_test.go59
-rw-r--r--pkg/test/testutil/BUILD2
-rw-r--r--pkg/test/testutil/sh.go515
-rw-r--r--pkg/test/testutil/testutil.go29
-rw-r--r--pkg/urpc/urpc.go27
-rw-r--r--pkg/usermem/usermem.go2
-rw-r--r--pkg/waiter/waiter.go11
-rw-r--r--pkg/waiter/waiter_test.go20
-rw-r--r--runsc/boot/BUILD1
-rw-r--r--runsc/boot/compat.go2
-rw-r--r--runsc/boot/controller.go24
-rw-r--r--runsc/boot/filter/config.go94
-rw-r--r--runsc/boot/loader.go13
-rw-r--r--runsc/boot/vfs.go70
-rw-r--r--runsc/cli/main.go4
-rw-r--r--runsc/cmd/BUILD2
-rw-r--r--runsc/cmd/checkpoint.go2
-rw-r--r--runsc/cmd/debug.go247
-rw-r--r--runsc/cmd/delete.go2
-rw-r--r--runsc/cmd/do.go2
-rw-r--r--runsc/cmd/events.go2
-rw-r--r--runsc/cmd/exec.go2
-rw-r--r--runsc/cmd/kill.go2
-rw-r--r--runsc/cmd/list.go8
-rw-r--r--runsc/cmd/pause.go2
-rw-r--r--runsc/cmd/ps.go2
-rw-r--r--runsc/cmd/resume.go2
-rw-r--r--runsc/cmd/start.go2
-rw-r--r--runsc/cmd/state.go2
-rw-r--r--runsc/cmd/symbolize.go91
-rw-r--r--runsc/cmd/wait.go2
-rw-r--r--runsc/config/config.go2
-rw-r--r--runsc/container/BUILD4
-rw-r--r--runsc/container/container.go182
-rw-r--r--runsc/container/container_test.go56
-rw-r--r--runsc/container/multi_container_test.go6
-rw-r--r--runsc/container/state_file.go236
-rw-r--r--runsc/fsgofer/BUILD3
-rw-r--r--runsc/fsgofer/fsgofer.go79
-rw-r--r--runsc/fsgofer/fsgofer_test.go217
-rw-r--r--runsc/sandbox/network.go2
-rw-r--r--runsc/sandbox/sandbox.go108
-rw-r--r--test/benchmarks/BUILD11
-rw-r--r--test/benchmarks/README.md14
-rw-r--r--test/benchmarks/base/BUILD9
-rw-r--r--test/benchmarks/base/size_test.go10
-rw-r--r--test/benchmarks/base/startup_test.go12
-rw-r--r--test/benchmarks/base/sysbench_test.go23
-rw-r--r--test/benchmarks/database/BUILD13
-rw-r--r--test/benchmarks/database/database.go15
-rw-r--r--test/benchmarks/database/redis_test.go38
-rw-r--r--test/benchmarks/defs.bzl14
-rw-r--r--test/benchmarks/fs/BUILD6
-rw-r--r--test/benchmarks/fs/bazel_test.go33
-rw-r--r--test/benchmarks/fs/fio_test.go68
-rw-r--r--test/benchmarks/harness/harness.go18
-rw-r--r--test/benchmarks/harness/machine.go18
-rw-r--r--test/benchmarks/media/BUILD10
-rw-r--r--test/benchmarks/media/ffmpeg_test.go18
-rw-r--r--test/benchmarks/media/media.go15
-rw-r--r--test/benchmarks/ml/BUILD10
-rw-r--r--test/benchmarks/ml/ml.go15
-rw-r--r--test/benchmarks/ml/tensorflow_test.go19
-rw-r--r--test/benchmarks/network/BUILD74
-rw-r--r--test/benchmarks/network/httpd_test.go60
-rw-r--r--test/benchmarks/network/iperf_test.go34
-rw-r--r--test/benchmarks/network/network.go61
-rw-r--r--test/benchmarks/network/nginx_test.go66
-rw-r--r--test/benchmarks/network/node_test.go18
-rw-r--r--test/benchmarks/network/ruby_test.go19
-rw-r--r--test/benchmarks/network/static_server.go87
-rw-r--r--test/benchmarks/tools/fio.go17
-rw-r--r--test/benchmarks/tools/hey.go13
-rw-r--r--test/benchmarks/tools/iperf.go17
-rw-r--r--test/benchmarks/tools/redis.go23
-rw-r--r--test/benchmarks/tools/sysbench.go101
-rw-r--r--test/cmd/test_app/fds.go5
-rw-r--r--test/e2e/integration_test.go57
-rw-r--r--test/e2e/regression_test.go2
-rw-r--r--test/fuse/BUILD5
-rw-r--r--test/fuse/linux/BUILD13
-rw-r--r--test/fuse/linux/mount_test.cc83
-rw-r--r--test/iptables/filter_output.go2
-rw-r--r--test/packetdrill/BUILD9
-rw-r--r--test/packetdrill/defs.bzl6
-rwxr-xr-xtest/packetdrill/packetdrill_test.sh13
-rw-r--r--test/packetimpact/runner/BUILD1
-rw-r--r--test/packetimpact/runner/defs.bzl18
-rw-r--r--test/packetimpact/runner/dut.go449
-rw-r--r--test/packetimpact/testbench/BUILD1
-rw-r--r--test/packetimpact/testbench/connections.go17
-rw-r--r--test/packetimpact/testbench/layers.go49
-rw-r--r--test/packetimpact/testbench/testbench.go98
-rw-r--r--test/packetimpact/tests/BUILD20
-rw-r--r--test/packetimpact/tests/ipv4_fragment_reassembly_test.go65
-rw-r--r--test/packetimpact/tests/ipv6_fragment_icmp_error_test.go3
-rw-r--r--test/packetimpact/tests/ipv6_fragment_reassembly_test.go52
-rw-r--r--test/packetimpact/tests/tcp_zero_receive_window_test.go125
-rw-r--r--test/packetimpact/tests/udp_recv_mcast_bcast_test.go2
-rw-r--r--test/perf/BUILD3
-rw-r--r--test/root/BUILD14
-rw-r--r--test/root/cgroup_test.go5
-rw-r--r--test/root/crictl_test.go2
-rw-r--r--test/runner/defs.bzl45
-rw-r--r--test/runtimes/BUILD12
-rw-r--r--test/runtimes/runner/lib/lib.go27
-rw-r--r--test/runtimes/runner/main.go14
-rw-r--r--test/syscalls/BUILD76
-rw-r--r--test/syscalls/linux/BUILD102
-rw-r--r--test/syscalls/linux/chown.cc11
-rw-r--r--test/syscalls/linux/exceptions.cc10
-rw-r--r--test/syscalls/linux/fcntl.cc486
-rw-r--r--test/syscalls/linux/getdents.cc3
-rw-r--r--test/syscalls/linux/kill.cc6
-rw-r--r--test/syscalls/linux/mount.cc36
-rw-r--r--test/syscalls/linux/open.cc12
-rw-r--r--test/syscalls/linux/open_create.cc8
-rw-r--r--test/syscalls/linux/pipe.cc10
-rw-r--r--test/syscalls/linux/proc.cc50
-rw-r--r--test/syscalls/linux/proc_net.cc8
-rw-r--r--test/syscalls/linux/proc_net_unix.cc2
-rw-r--r--test/syscalls/linux/proc_pid_uid_gid_map.cc9
-rw-r--r--test/syscalls/linux/raw_socket.cc45
-rw-r--r--test/syscalls/linux/semaphore.cc161
-rw-r--r--test/syscalls/linux/signalfd.cc2
-rw-r--r--test/syscalls/linux/socket.cc6
-rw-r--r--test/syscalls/linux/socket_bind_to_device_distribution.cc24
-rw-r--r--test/syscalls/linux/socket_generic.cc25
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc16
-rw-r--r--test/syscalls/linux/socket_ip_udp_generic.cc57
-rw-r--r--test/syscalls/linux/socket_ip_udp_unbound_external_networking.cc59
-rw-r--r--test/syscalls/linux/socket_ip_udp_unbound_external_networking.h46
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound.cc84
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc52
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h21
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc4
-rw-r--r--test/syscalls/linux/socket_ipv6_udp_unbound.cc131
-rw-r--r--test/syscalls/linux/socket_ipv6_udp_unbound.h29
-rw-r--r--test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc90
-rw-r--r--test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.h31
-rw-r--r--test/syscalls/linux/socket_ipv6_udp_unbound_external_networking_test.cc39
-rw-r--r--test/syscalls/linux/socket_ipv6_udp_unbound_loopback.cc32
-rw-r--r--test/syscalls/linux/socket_test_util.cc11
-rw-r--r--test/syscalls/linux/socket_test_util.h1
-rw-r--r--test/syscalls/linux/socket_unix_unbound_filesystem.cc16
-rw-r--r--test/syscalls/linux/tuntap.cc19
-rw-r--r--test/syscalls/linux/udp_socket.cc126
-rw-r--r--tools/bazel.mk258
-rw-r--r--tools/bazeldefs/BUILD59
-rw-r--r--tools/bazeldefs/cc.bzl12
-rw-r--r--tools/bazeldefs/defs.bzl45
-rw-r--r--tools/bazeldefs/go.bzl4
-rw-r--r--tools/bigquery/bigquery.go7
-rw-r--r--tools/checkescape/checkescape.go12
-rw-r--r--tools/checkescape/test1/test1.go13
-rw-r--r--tools/defs.bzl8
-rwxr-xr-xtools/go_branch.sh10
-rw-r--r--tools/go_generics/defs.bzl4
-rw-r--r--tools/go_generics/generics.go2
-rw-r--r--tools/go_marshal/gomarshal/generator.go58
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go4
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go4
-rw-r--r--tools/images.mk169
-rw-r--r--tools/installers/BUILD10
-rwxr-xr-xtools/installers/containerd.sh17
-rw-r--r--tools/nogo/BUILD2
-rw-r--r--tools/nogo/config-schema.json97
-rw-r--r--tools/nogo/filter/main.go9
-rw-r--r--tools/parsers/go_parser_test.go22
-rw-r--r--tools/vm/BUILD63
-rw-r--r--tools/vm/README.md48
-rwxr-xr-xtools/vm/build.sh117
-rw-r--r--tools/vm/defs.bzl202
-rwxr-xr-xtools/vm/execute.sh160
-rwxr-xr-xtools/vm/ubuntu1604/10_core.sh43
-rwxr-xr-xtools/vm/ubuntu1604/15_gcloud.sh50
-rwxr-xr-xtools/vm/ubuntu1604/20_bazel.sh38
-rwxr-xr-xtools/vm/ubuntu1604/30_docker.sh64
-rwxr-xr-xtools/vm/ubuntu1604/40_kokoro.sh72
-rw-r--r--tools/vm/ubuntu1604/BUILD7
-rw-r--r--tools/vm/ubuntu1804/BUILD7
-rwxr-xr-xtools/workspace_status.sh2
-rw-r--r--tools/yamltest/BUILD13
-rw-r--r--tools/yamltest/defs.bzl41
-rw-r--r--tools/yamltest/main.go133
-rw-r--r--website/BUILD3
-rw-r--r--website/blog/README.md62
-rw-r--r--website/blog/index.html5
-rw-r--r--website/cmd/server/BUILD3
-rw-r--r--website/cmd/server/main.go222
520 files changed, 21171 insertions, 10789 deletions
diff --git a/.bazelrc b/.bazelrc
index e2848ef07..413cee3b0 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# RBE requires a strong hash function, such as SHA256.
+# Ensure a strong hash function.
startup --host_jvm_args=-Dbazel.DigestFunction=SHA256
# Build with C++17.
@@ -21,26 +21,7 @@ build --cxxopt=-std=c++17
# Display the current git revision in the info block.
build --stamp --workspace_status_command tools/workspace_status.sh
-# Enable remote execution so actions are performed on the remote systems.
-build:remote --remote_executor=grpcs://remotebuildexecution.googleapis.com
-build:remote --bes_backend=buildeventservice.googleapis.com
-build:remote --bes_results_url="https://source.cloud.google.com/results/invocations"
-build:remote --bes_timeout=600s
-build:remote --project_id=gvisor-rbe
-build:remote --remote_instance_name=projects/gvisor-rbe/instances/default_instance
-
-# Enable authentication. This will pick up application default credentials by
-# default. You can use --google_credentials=some_file.json to use a service
-# account credential instead.
-build:remote --google_default_credentials=true
-build:remote --auth_scope="https://www.googleapis.com/auth/cloud-source-tools"
-
-# Add a custom platform and toolchain that builds in a privileged docker
-# container, which is required by our syscall tests.
-build:remote --host_platform=//tools/bazeldefs:rbe_ubuntu1604
-build:remote --extra_toolchains=//tools/bazeldefs:cc-toolchain-clang-x86_64-default
-build:remote --extra_execution_platforms=//tools/bazeldefs:rbe_ubuntu1604
-build:remote --platforms=//tools/bazeldefs:rbe_ubuntu1604
-build:remote --crosstool_top=@rbe_default//cc:toolchain
-build:remote --jobs=300
-build:remote --remote_timeout=3600
+# Set flags for aarch64.
+build:cross-aarch64 --crosstool_top=@crosstool//:toolchains --compiler=gcc
+build:cross-aarch64 --cpu=aarch64
+build:cross-aarch64 --platforms=@io_bazel_rules_go//go/toolchain:linux_arm64
diff --git a/.buildkite/hooks/post-command b/.buildkite/hooks/post-command
new file mode 100644
index 000000000..8af1369a6
--- /dev/null
+++ b/.buildkite/hooks/post-command
@@ -0,0 +1,60 @@
+# Upload all relevant test failures.
+make -s testlogs 2>/dev/null | grep // | sort | uniq | (
+ declare log_count=0
+ while read target log; do
+ if test -z "${target}"; then
+ continue
+ fi
+
+ # N.B. If *all* tests fail due to some common cause, then we will
+ # end up spending way too much time uploading logs. Instead, we just
+ # upload the first 10 and stop. That is hopefully enough to debug.
+ #
+ # We include this test in the metadata, but note that we cannot
+ # upload the actual test logs. The user should rerun locally.
+ log_count=$((${log_count}+1))
+ if test "${log_count}" -ge 10; then
+ echo " * ${target} (no upload)" | \
+ buildkite-agent annotate --style error --context failures --append
+ else
+ buildkite-agent artifact upload "${log}"
+ echo " * [${target}](artifact://${log#/}) (${BUILDKITE_LABEL})" | \
+ buildkite-agent annotate --style error --context failures --append
+ fi
+ done
+)
+
+# Upload all profiles, and include in an annotation.
+declare profile_output=$(mktemp --tmpdir)
+for file in $(find /tmp/profile -name \*.pprof -print 2>/dev/null | sort); do
+ # Generate a link to the profile parsing function in gvisor.dev, which
+ # implicitly uses a prefix of https://storage.googleapis.com. Note that
+ # this relies on the specific BuildKite bucket location, and will break if
+ # this changes (although the artifacts will still exist and be just fine).
+ profile_name="${file#/tmp/profile/}"
+ profile_url="https://gvisor.dev/profile/gvisor-buildkite/${BUILDKITE_BUILD_ID}/${BUILDKITE_JOB_ID}/${file#/}/"
+ buildkite-agent artifact upload "${file}"
+ echo "<li><a href='${profile_url}'>${profile_name}</a></li>" >> "${profile_output}"
+done
+
+# Upload if we had outputs.
+if test -s "${profile_output}"; then
+ # Make the list a collapsible section in markdown.
+ sed -i "1s|^|<details><summary>${BUILDKITE_LABEL}</summary><ul>\n|" "${profile_output}"
+ echo "</ul></details>" >> "${profile_output}"
+ cat "${profile_output}" | buildkite-agent annotate --style info --context profiles --append
+fi
+rm -rf "${profile_output}"
+
+# Clean the bazel cache if there was a failure.
+if test "${BUILDKITE_COMMAND_EXIT_STATUS}" -ne "0"; then
+ # Attempt to clear the cache and shut down.
+ make clean || echo "make clean failed with code $?"
+ make bazel-shutdown || echo "make bazel-shutdown failed with code $?"
+fi
+
+# Kill any running containers (clear state).
+CONTAINERS="$(docker ps -q)"
+if ! test -z "${CONTAINERS}"; then
+ docker container kill ${CONTAINERS} 2>/dev/null || true
+fi
diff --git a/.buildkite/hooks/pre-command b/.buildkite/hooks/pre-command
new file mode 100644
index 000000000..ba688f9ac
--- /dev/null
+++ b/.buildkite/hooks/pre-command
@@ -0,0 +1,33 @@
+# Install packages we need. Docker must be installed and configured,
+# as should Go itself. We just install some extra bits and pieces.
+function install_pkgs() {
+ while true; do
+ if sudo apt-get update && sudo apt-get install -y "$@"; then
+ break
+ fi
+ done
+}
+install_pkgs graphviz jq curl binutils gnupg gnupg-agent linux-libc-dev \
+ apt-transport-https ca-certificates software-properties-common
+
+# Setup for parallelization with PARTITION and TOTAL_PARTITIONS.
+export PARTITION=${BUILDKITE_PARALLEL_JOB:-0}
+PARTITION=$((${PARTITION}+1)) # 1-indexed, but PARALLEL_JOB is 0-indexed.
+export TOTAL_PARTITIONS=${BUILDKITE_PARALLEL_JOB_COUNT:-1}
+
+# Ensure Docker has experimental enabled.
+EXPERIMENTAL=$(sudo docker version --format='{{.Server.Experimental}}')
+if test "${EXPERIMENTAL}" != "true"; then
+ make sudo TARGETS=//runsc:runsc ARGS="install --experimental=true"
+ sudo systemctl restart docker
+fi
+
+# Helper for benchmarks, based on the branch.
+if test "${BUILDKITE_BRANCH}" = "master"; then
+ export BENCHMARKS_OFFICIAL=true
+else
+ export BENCHMARKS_OFFICIAL=false
+fi
+
+# Clear existing profiles.
+sudo rm -rf /tmp/profile
diff --git a/.buildkite/pipeline.yaml b/.buildkite/pipeline.yaml
new file mode 100644
index 000000000..34670f58d
--- /dev/null
+++ b/.buildkite/pipeline.yaml
@@ -0,0 +1,204 @@
+_templates:
+ common: &common
+ timeout_in_minutes: 30
+ retry:
+ automatic:
+ - exit_status: -1
+ limit: 10
+ - exit_status: "*"
+ limit: 2
+ benchmarks: &benchmarks
+ timeout_in_minutes: 120
+ retry:
+ automatic: false
+ soft_fail: true
+ if: build.branch == "master"
+ env:
+ # BENCHMARKS_OFFICIAL is set from hooks/pre-command, based
+ # on whether this is executing on the master branch.
+ BENCHMARKS_DATASET: buildkite
+ BENCHMARKS_PLATFORMS: "ptrace kvm"
+ BENCHMARKS_PROJECT: gvisor-benchmarks
+ BENCHMARKS_TABLE: benchmarks
+ BENCHMARKS_UPLOAD: true
+
+steps:
+ # Run basic smoke tests before proceeding to other tests.
+ - <<: *common
+ label: ":fire: Smoke tests"
+ command: make smoke-tests
+ - wait
+
+ # Check that the Go branch builds.
+ - <<: *common
+ label: ":golang: Go branch"
+ commands:
+ - make go
+ - git checkout go && git clean -f
+ - go build ./...
+
+ # Release workflow.
+ - <<: *common
+ label: ":ship: Release tests"
+ commands: make release
+
+ # Basic unit tests.
+ - <<: *common
+ label: ":test_tube: Unit tests"
+ command: make unit-tests
+
+ # All system call tests.
+ - <<: *common
+ label: ":toolbox: System call tests"
+ command: make syscall-tests
+ parallelism: 20
+
+ # Integration tests.
+ - <<: *common
+ label: ":parachute: FUSE tests"
+ command: make fuse-tests
+ - <<: *common
+ label: ":docker: Docker tests"
+ command: make docker-tests
+ - <<: *common
+ label: ":goggles: Overlay tests"
+ command: make overlay-tests
+ - <<: *common
+ label: ":safety_pin: Host network tests"
+ command: make hostnet-tests
+ - <<: *common
+ label: ":satellite: SWGSO tests"
+ command: make swgso-tests
+ - <<: *common
+ label: ":coffee: Do tests"
+ command: make do-tests
+ - <<: *common
+ label: ":person_in_lotus_position: KVM tests"
+ command: make kvm-tests
+ - <<: *common
+ label: ":docker: Containerd 1.3.9 tests"
+ command: make containerd-test-1.3.9
+ - <<: *common
+ label: ":docker: Containerd 1.4.3 tests"
+ command: make containerd-test-1.4.3
+
+ # Check the website builds.
+ - <<: *common
+ label: ":earth_americas: Website tests"
+ command: make website-build
+
+ # Networking tests.
+ - <<: *common
+ label: ":table_tennis_paddle_and_ball: IPTables tests"
+ command: make iptables-tests
+ - <<: *common
+ label: ":construction_worker: Packetdrill tests"
+ command: make packetdrill-tests
+ - <<: *common
+ label: ":hammer: Packetimpact tests"
+ command: make packetimpact-tests
+
+ # Runtime tests.
+ - <<: *common
+ label: ":php: PHP runtime tests"
+ command: make php7.3.6-runtime-tests_vfs2
+ parallelism: 10
+ - <<: *common
+ label: ":java: Java runtime tests"
+ command: make java11-runtime-tests_vfs2
+ parallelism: 40
+ - <<: *common
+ label: ":golang: Go runtime tests"
+ command: make go1.12-runtime-tests_vfs2
+ parallelism: 10
+ - <<: *common
+ label: ":node: NodeJS runtime tests"
+ command: make nodejs12.4.0-runtime-tests_vfs2
+ parallelism: 10
+ - <<: *common
+ label: ":python: Python runtime tests"
+ command: make python3.7.3-runtime-tests_vfs2
+ parallelism: 10
+
+ # Runtime tests (VFS1).
+ - <<: *common
+ label: ":php: PHP runtime tests (VFS1)"
+ command: make php7.3.6-runtime-tests
+ parallelism: 10
+ if: build.message =~ /VFS1/ || build.branch == "master"
+ - <<: *common
+ label: ":java: Java runtime tests (VFS1)"
+ command: make java11-runtime-tests
+ parallelism: 40
+ if: build.message =~ /VFS1/ || build.branch == "master"
+ - <<: *common
+ label: ":golang: Go runtime tests (VFS1)"
+ command: make go1.12-runtime-tests
+ parallelism: 10
+ if: build.message =~ /VFS1/ || build.branch == "master"
+ - <<: *common
+ label: ":node: NodeJS runtime tests (VFS1)"
+ command: make nodejs12.4.0-runtime-tests
+ parallelism: 10
+ if: build.message =~ /VFS1/ || build.branch == "master"
+ - <<: *common
+ label: ":python: Python runtime tests (VFS1)"
+ command: make python3.7.3-runtime-tests
+ parallelism: 10
+ if: build.message =~ /VFS1/ || build.branch == "master"
+
+ # ARM tests.
+ - <<: *common
+ label: ":mechanical_arm: ARM"
+ command: make arm-qemu-smoke-test
+
+ # Run basic benchmarks smoke tests (no upload).
+ - <<: *common
+ label: ":fire: Benchmarks smoke test"
+ command: make benchmark-platforms
+ # Use the opposite of the benchmarks filter.
+ if: build.branch != "master"
+
+ # Run all benchmarks.
+ - <<: *benchmarks
+ label: ":bazel: ABSL build benchmarks"
+ command: make benchmark-platforms BENCHMARKS_FILTER="ABSL/page_cache.clean" BENCHMARKS_SUITE=absl BENCHMARKS_TARGETS=test/benchmarks/fs:bazel_test
+ - <<: *benchmarks
+ label: ":go: runsc build benchmarks"
+ command: make benchmark-platforms BENCHMARKS_FILTER="Runsc/page_cache.clean/filesystem.bind" BENCHMARKS_SUITE=runsc BENCHMARKS_TARGETS=test/benchmarks/fs:bazel_test
+ - <<: *benchmarks
+ label: ":metal: FFMPEG benchmarks"
+ command: make benchmark-platforms BENCHMARKS_SUITE=ffmpeg BENCHMARKS_TARGETS=test/benchmarks/media:ffmpeg_test
+ - <<: *benchmarks
+ label: ":floppy_disk: FIO benchmarks"
+ command: make benchmark-platforms BENCHMARKS_SUITE=fio BENCHMARKS_TARGETS=test/benchmarks/fs:fio_test
+ - <<: *benchmarks
+ label: ":globe_with_meridians: HTTPD benchmarks"
+ command: make benchmark-platforms BENCHMARKS_FILTER="Continuous" BENCHMARKS_SUITE=httpd BENCHMARKS_TARGETS=test/benchmarks/network:httpd_test
+ - <<: *benchmarks
+ label: ":piedpiper: iperf benchmarks"
+ command: make benchmark-platforms BENCHMARKS_SUITE=iperf BENCHMARKS_TARGETS=test/benchmarks/network:iperf_test
+ - <<: *benchmarks
+ label: ":nginx: nginx benchmarks"
+ command: make benchmark-platforms BENCHMARKS_FILTER="Continuous" BENCHMARKS_SUITE=nginx BENCHMARKS_TARGETS=test/benchmarks/network:nginx_test
+ - <<: *benchmarks
+ label: ":node: node benchmarks"
+ command: make benchmark-platforms BENCHMARKS_SUITE=node BENCHMARKS_TARGETS=test/benchmarks/network:node_test
+ - <<: *benchmarks
+ label: ":redis: Redis benchmarks"
+ command: make benchmark-platforms BENCHMARKS_SUITE=redis BENCHMARKS_TARGETS=test/benchmarks/database:redis_test
+ - <<: *benchmarks
+ label: ":ruby: Ruby benchmarks"
+ command: make benchmark-platforms BENCHMARKS_SUITE=ruby BENCHMARKS_TARGETS=test/benchmarks/network:ruby_test
+ - <<: *benchmarks
+ label: ":weight_lifter: Size benchmarks"
+ command: make benchmark-platforms BENCHMARKS_SUITE=size BENCHMARKS_TARGETS=test/benchmarks/base:size_test
+ - <<: *benchmarks
+ label: ":speedboat: Startup benchmarks"
+ command: make benchmark-platforms BENCHMARKS_SUITE=startup BENCHMARKS_TARGETS=test/benchmarks/base:startup_test
+ - <<: *benchmarks
+ label: ":computer: sysbench benchmarks"
+ command: make benchmark-platforms BENCHMARKS_SUITE=sysbench BENCHMARKS_TARGETS=test/benchmarks/base:sysbench_test
+ - <<: *benchmarks
+ label: ":tensorflow: TensorFlow benchmarks"
+ command: make benchmark-platforms BENCHMARKS_SUITE=tensorflow BENCHMARKS_TARGETS=test/benchmarks/ml:tensorflow_test
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
deleted file mode 100644
index 264b4e9fa..000000000
--- a/.github/pull_request_template.md
+++ /dev/null
@@ -1,5 +0,0 @@
-* [ ] Have you followed the guidelines in [CONTRIBUTING.md](../blob/master/CONTRIBUTING.md)?
-* [ ] Have you formatted and linted your code?
-* [ ] Have you added relevant tests?
-* [ ] Have you added appropriate Fixes & Updates references?
-* [ ] If yes, please erase all these lines!
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index e28e46352..270aaf034 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -1,13 +1,15 @@
+# This workflow builds the source code, extracts nogo annotations and
+# posts them to GitHub, if applicable. This leverages the fact that the
+# workflow token has appropriate permissions to do so, and attempts to
+# leverage the GitHub workflow caches.
name: "Build"
-on:
+"on":
push:
branches:
- master
- - feature/**
pull_request:
branches:
- - master
- - feature/**
+ - "**"
jobs:
default:
@@ -22,7 +24,7 @@ jobs:
${{ runner.os }}-bazel-
- run: make
- run: make build OPTIONS="--build_tag_filters nogo" TARGETS="//..."
- - run: make run TARGETS="//tools/github" ARGS="-path=bazel-bin/ nogo"
+ - run: make run TARGETS="//tools/github" ARGS="-path=bazel-bin/ -path=bazel-out/ nogo"
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_REPOSITORY: ${{ github.repository }}
diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml
index 3a6a592d1..e62991691 100644
--- a/.github/workflows/go.yml
+++ b/.github/workflows/go.yml
@@ -1,12 +1,12 @@
+# This workflow generates the Go branch. Note that this does not test the Go
+# branch, as this is rolled into the main continuous integration pipeline. This
+# workflow simply generates and pushes the branch, as long as appropriate
+# permissions are available.
name: "Go"
-on:
+"on":
push:
branches:
- master
- pull_request:
- branches:
- - master
- - feature/**
jobs:
generate:
@@ -19,20 +19,13 @@ jobs:
else
echo ::set-output name=has_token::false
fi
- - run: |
- jq -nc '{"state": "pending", "context": "go tests"}' | \
- curl -sL -X POST -d @- \
- -H "Content-Type: application/json" \
- -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
- "${{ github.event.pull_request.statuses_url }}"
- if: github.event_name == 'pull_request'
- uses: actions/checkout@v2
- if: github.event_name == 'push' && steps.setup.outputs.has_token == 'true'
+ if: steps.setup.outputs.has_token == 'true'
with:
fetch-depth: 0
token: '${{ secrets.GO_TOKEN }}'
- uses: actions/checkout@v2
- if: github.event_name == 'pull_request' || steps.setup.outputs.has_token != 'true'
+ if: steps.setup.outputs.has_token != 'true'
with:
fetch-depth: 0
- uses: actions/setup-go@v2
@@ -50,32 +43,7 @@ jobs:
key: ${{ runner.os }}-bazel-${{ hashFiles('WORKSPACE') }}
restore-keys: |
${{ runner.os }}-bazel-
- # Create gopath to merge the changes. The first execution will create
- # symlinks to the cache, e.g. bazel-bin. Once the cache is setup, delete
- # old gopath files that may exist from previous runs (and could contain
- # files that are now deleted). Then run gopath again for good.
+ - run: make go
- run: |
- make build TARGETS="//:gopath"
- rm -rf bazel-bin/gopath
- make build TARGETS="//:gopath"
- - run: tools/go_branch.sh
- - run: git checkout go && git clean -f
- - run: go build ./...
- - if: github.event_name == 'push'
- run: |
git remote add upstream "https://github.com/${{ github.repository }}"
git push upstream go:go
- - if: ${{ success() && github.event_name == 'pull_request' }}
- run: |
- jq -nc '{"state": "success", "context": "go tests"}' | \
- curl -sL -X POST -d @- \
- -H "Content-Type: application/json" \
- -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
- "${{ github.event.pull_request.statuses_url }}"
- - if: ${{ failure() && github.event_name == 'pull_request' }}
- run: |
- jq -nc '{"state": "failure", "context": "go tests"}' | \
- curl -sL -X POST -d @- \
- -H "Content-Type: application/json" \
- -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
- "${{ github.event.pull_request.statuses_url }}"
diff --git a/.github/workflows/issue_reviver.yml b/.github/workflows/issue_reviver.yml
index c53185620..3bd883035 100644
--- a/.github/workflows/issue_reviver.yml
+++ b/.github/workflows/issue_reviver.yml
@@ -1,5 +1,7 @@
+# This workflow revives issues that are still referenced in the code, and may
+# have been accidentally closed or marked stale.
name: "Issue reviver"
-on:
+"on":
schedule:
- cron: '0 0 * * *'
diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml
index c09f7eb36..3a19065e1 100644
--- a/.github/workflows/labeler.yml
+++ b/.github/workflows/labeler.yml
@@ -1,5 +1,6 @@
+# Labeler labels incoming pull requests.
name: "Labeler"
-on:
+"on":
- pull_request
jobs:
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
index 0b31fecf5..3a4aa22e2 100644
--- a/.github/workflows/stale.yml
+++ b/.github/workflows/stale.yml
@@ -1,5 +1,7 @@
+# The stale workflow closes stale issues and pull requests, unless specific
+# tags have been applied in order to keep them open.
name: "Close stale issues"
-on:
+"on":
schedule:
- cron: "0 0 * * *"
diff --git a/.gitignore b/.gitignore
index 95fe857dd..a2a3fd508 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,7 @@
# Generated bazel symlinks.
/bazel-*
# Generated build event file.
-/.build_events.json \ No newline at end of file
+/.build_events.json
+# Generated repository.
+/repo
+/repo.key \ No newline at end of file
diff --git a/.travis.yml b/.travis.yml
deleted file mode 100644
index 2d9fa80a1..000000000
--- a/.travis.yml
+++ /dev/null
@@ -1,47 +0,0 @@
-language: shell
-dist: xenial
-git:
- clone: false # Clone manually in before_install
-before_install:
- - set -e -o pipefail
- - |
- if [ "${TRAVIS_PULL_REQUEST}" = false ]; then
- # This is not a PR build, fetch and checkout the commit being tested
- git clone -q --depth 1 "https://github.com/${TRAVIS_REPO_SLUG}.git" "${TRAVIS_REPO_SLUG}"
- cd "${TRAVIS_REPO_SLUG}"
- git fetch origin "${TRAVIS_COMMIT}" --depth 1
- git checkout -qf "${TRAVIS_COMMIT}"
- else
- # This is a PR build, simulate +refs/pull/{num}/merge.
- # We can do that by fetching +refs/pull/{num}/head and cherry picking it
- # onto the target branch.
- git clone -q --branch "${TRAVIS_BRANCH}" --depth 1 "https://github.com/${TRAVIS_REPO_SLUG}.git" "${TRAVIS_REPO_SLUG}"
- cd "${TRAVIS_REPO_SLUG}"
- git fetch origin "+refs/pull/${TRAVIS_PULL_REQUEST}/head" --depth 1
- git config --global user.email "$(git log -1 FETCH_HEAD --pretty="%cE")"
- git config --global user.name "$(git log -1 FETCH_HEAD --pretty="%aN")"
- git cherry-pick --strategy=recursive -X theirs --keep-redundant-commits FETCH_HEAD
- fi
-cache:
- directories:
- - /home/travis/.cache/bazel/
-os: linux
-services:
- - docker
-jobs:
- include:
- # AMD64 builds are tested on kokoro, so don't run them in travis to save
- # capacity for arm64 builds.
- # - os: linux
- # arch: amd64
- - os: linux
- arch: arm64
-script:
- # On arm64, we need to create our own pipes for stderr and stdout,
- # otherwise we will not be able to open /dev/stderr. This is probably
- # due to AppArmor rules.
- - bash -xeo pipefail -c 'uname -a && make smoke-tests 2>&1 | cat'
-branches:
- except:
- # Skip copybara branches.
- - /^test\/cl.*$/
diff --git a/BUILD b/BUILD
index 0791f9fb4..d19d19866 100644
--- a/BUILD
+++ b/BUILD
@@ -1,5 +1,6 @@
load("//tools:defs.bzl", "build_test", "gazelle", "go_path")
load("//tools/nogo:defs.bzl", "nogo_config")
+load("//tools/yamltest:defs.bzl", "yaml_test")
load("//website:defs.bzl", "doc")
package(licenses = ["notice"])
@@ -50,6 +51,24 @@ doc(
weight = "99",
)
+yaml_test(
+ name = "nogo_config_test",
+ srcs = glob(["nogo*.yaml"]),
+ schema = "//tools/nogo:config-schema.json",
+)
+
+yaml_test(
+ name = "github_workflows_test",
+ srcs = glob([".github/workflows/*.yml"]),
+ schema = "@github_workflow_schema//file",
+)
+
+yaml_test(
+ name = "buildkite_pipelines_test",
+ srcs = glob([".buildkite/*.yaml"]),
+ schema = "@buildkite_pipeline_schema//file",
+)
+
# The sandbox filegroup is used for sandbox-internal dependencies.
package_group(
name = "sandbox",
@@ -67,12 +86,15 @@ build_test(
"//test/benchmarks/base:startup_test",
"//test/benchmarks/base:size_test",
"//test/benchmarks/base:sysbench_test",
- "//test/benchmarks/database:database_test",
+ "//test/benchmarks/database:redis_test",
"//test/benchmarks/fs:bazel_test",
"//test/benchmarks/fs:fio_test",
- "//test/benchmarks/media:media_test",
- "//test/benchmarks/ml:ml_test",
- "//test/benchmarks/network:network_test",
+ "//test/benchmarks/media:ffmpeg_test",
+ "//test/benchmarks/ml:tensorflow_test",
+ "//test/benchmarks/network:httpd_test",
+ "//test/benchmarks/network:nginx_test",
+ "//test/benchmarks/network:node_test",
+ "//test/benchmarks/network:ruby_test",
],
)
@@ -102,7 +124,9 @@ go_path(
"//pkg/sentry/kernel/memevent",
"//pkg/tcpip/adapters/gonet",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/ethernet",
"//pkg/tcpip/link/muxed",
+ "//pkg/tcpip/link/pipe",
"//pkg/tcpip/link/sharedmem",
"//pkg/tcpip/link/sharedmem/pipe",
"//pkg/tcpip/link/sharedmem/queue",
diff --git a/Makefile b/Makefile
index 47d89c438..4eb85f5af 100644
--- a/Makefile
+++ b/Makefile
@@ -14,27 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# Helpful pretty-printer.
-ifeq (0,$(MAKELEVEL))
-OPENLAST := || (rc=$$?; echo '^^^ +++' >&2; exit $$rc)
-else
-OPENLAST :=
-endif
-CMDLINE := $(shell cut -d '' -f2- /proc/$$PPID/cmdline | sed 's|\x00| |g')
-submake = echo '--- make $1' >&2 && \
- $(MAKE) -s $1 && \
- echo '--- make $(CMDLINE) (resume)' >&2 \
- $(OPENLAST)
-
-# Described below.
-OPTIONS :=
-STARTUP_OPTIONS :=
-TARGETS := //runsc
-ARGS :=
-
default: runsc
.PHONY: default
+# Header for debugging (used by other macros).
+header = echo --- $(1) >&2
+
+# Make hacks.
+EMPTY :=
+SPACE := $(EMPTY) $(EMPTY)
+SHELL = /bin/bash
+
## usage: make <target>
## or
## make <build|test|copy|run|sudo> STARTUP_OPTIONS="..." OPTIONS="..." TARGETS="..." ARGS="..."
@@ -46,7 +36,6 @@ default: runsc
## requirements.
##
## There are common arguments that may be passed to targets. These are:
-## STARTUP_OPTIONS - Bazel startup options.
## OPTIONS - Build or test options.
## TARGETS - The bazel targets.
## ARGS - Arguments for run or sudo.
@@ -57,7 +46,7 @@ default: runsc
## make build OPTIONS="" TARGETS="//runsc"'
##
help: ## Shows all targets and help from the Makefile (this message).
- @grep --no-filename -E '^([a-z.A-Z_-]+:.*?|)##' $(MAKEFILE_LIST) | \
+ @grep --no-filename -E '^([a-z.A-Z_%-]+:.*?|)##' $(MAKEFILE_LIST) | \
awk 'BEGIN {FS = "(:.*?|)## ?"}; { \
if (length($$1) > 0) { \
printf " \033[36m%-20s\033[0m %s\n", $$1, $$2; \
@@ -65,17 +54,34 @@ help: ## Shows all targets and help from the Makefile (this message).
printf "%s\n", $$2; \
} \
}'
+
build: ## Builds the given $(TARGETS) with the given $(OPTIONS). E.g. make build TARGETS=runsc
-test: ## Tests the given $(TARGETS) with the given $(OPTIONS). E.g. make test TARGETS=pkg/buffer:buffer_test
-copy: ## Copies the given $(TARGETS) to the given $(DESTINATION). E.g. make copy TARGETS=runsc DESTINATION=/tmp
-run: ## Runs the given $(TARGETS), built with $(OPTIONS), using $(ARGS). E.g. make run TARGETS=runsc ARGS=-version
-sudo: ## Runs the given $(TARGETS) as per run, but using "sudo -E". E.g. make sudo TARGETS=test/root:root_test ARGS=-test.v
-.PHONY: help build test copy run sudo
+ @$(call build,$(OPTIONS) $(TARGETS))
+.PHONY: build
+
+test: ## Tests the given $(TARGETS) with the given $(OPTIONS). E.g. make test TARGETS=pkg/buffer:buffer_test
+ @$(call test,$(OPTIONS) $(TARGETS))
+.PHONY: test
+
+copy: ## Copies the given $(TARGETS) to the given $(DESTINATION). E.g. make copy TARGETS=runsc DESTINATION=/tmp
+ @$(call copy,$(TARGETS),$(DESTINATION))
+.PHONY: copy
+
+run: ## Runs the given $(TARGETS), built with $(OPTIONS), using $(ARGS). E.g. make run TARGETS=runsc ARGS=-version
+ @$(call run,$(TARGETS),$(ARGS))
+.PHONY: run
+
+sudo: ## Runs the given $(TARGETS) as per run, but using "sudo -E". E.g. make sudo TARGETS=test/root:root_test ARGS=-test.v
+ @$(call sudo,$(TARGETS),$(ARGS))
+.PHONY: sudo
+
+# Load image helpers.
+include tools/images.mk
# Load all bazel wrappers.
#
# This file should define the basic "build", "test", "run" and "sudo" rules, in
-# addition to the $(BRANCH_NAME) variable.
+# addition to the $(BRANCH_NAME) and $(BUILD_ROOTS) variables.
ifneq (,$(wildcard tools/google.mk))
include tools/google.mk
else
@@ -83,32 +89,76 @@ include tools/bazel.mk
endif
##
-## Docker image targets.
-##
-## Images used by the tests must also be built and available locally.
-## The canonical test targets defined below will automatically load
-## relevant images. These can be loaded or built manually via these
-## targets.
+## Development helpers and tooling.
##
-## (*) Note that you may provide an ARCH parameter in order to build
-## and load images from an alternate archiecture (using qemu). When
-## bazel is run as a server, this has the effect of running an full
-## cross-architecture chain, and can produce cross-compiled binaries.
+## These targets facilitate local development by automatically
+## installing and configuring a runtime. Several variables may
+## be used here to tweak the installation:
+## RUNTIME - The name of the installed runtime (default: branch).
+## RUNTIME_DIR - Where the runtime will be installed (default: temporary directory with the $RUNTIME).
+## RUNTIME_BIN - The runtime binary (default: $RUNTIME_DIR/runsc).
+## RUNTIME_LOG_DIR - The logs directory (default: $RUNTIME_DIR/logs).
+## RUNTIME_LOGS - The log pattern (default: $RUNTIME_LOG_DIR/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%).
##
-define images
-$(1)-%: ## Image tool: $(1) a given image (also may use 'all-images').
- @$(call submake,-C images $$@)
-endef
-rebuild-...: ## Rebuild the given image. Also may use 'rebuild-all-images'.
-$(eval $(call images,rebuild))
-push-...: ## Push the given image. Also may use 'push-all-images'.
-$(eval $(call images,push))
-pull-...: ## Pull the given image. Also may use 'pull-all-images'.
-$(eval $(call images,pull))
-load-...: ## Load (pull or rebuild) the given image. Also may use 'load-all-images'.
-$(eval $(call images,load))
-list-images: ## List all available images.
- @$(call submake, -C images $$@)
+ifeq (,$(BRANCH_NAME))
+RUNTIME := runsc
+RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(RUNTIME)
+else
+RUNTIME := $(BRANCH_NAME)
+RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(RUNTIME)
+endif
+RUNTIME_BIN := $(RUNTIME_DIR)/runsc
+RUNTIME_LOG_DIR := $(RUNTIME_DIR)/logs
+RUNTIME_LOGS := $(RUNTIME_LOG_DIR)/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%
+
+$(RUNTIME_BIN): # See below.
+ @mkdir -p "$(RUNTIME_DIR)"
+ @$(call copy,//runsc,$(RUNTIME_BIN))
+.PHONY: $(RUNTIME_BIN) # Real file, but force rebuild.
+
+# Configure helpers for below.
+configure_noreload = \
+ $(call header,CONFIGURE $(1) → $(RUNTIME_BIN) $(2)); \
+ sudo $(RUNTIME_BIN) install --experimental=true --runtime="$(1)" -- --debug-log "$(RUNTIME_LOGS)" $(2) && \
+ sudo rm -rf "$(RUNTIME_LOG_DIR)" && mkdir -p "$(RUNTIME_LOG_DIR)"
+reload_docker = \
+ sudo systemctl reload docker && \
+ if test -f /etc/docker/daemon.json; then \
+ sudo chmod 0755 /etc/docker && \
+ sudo chmod 0644 /etc/docker/daemon.json; \
+ fi
+configure = $(call configure_noreload,$(1),$(2)) && $(reload_docker)
+
+# Helpers for above. Requires $(RUNTIME_BIN) dependency.
+install_runtime = $(call configure,$(1),$(2) --TESTONLY-test-name-env=RUNSC_TEST_NAME)
+# Don't use cached results, otherwise multiple runs using different runtimes
+# may be skipped, if all other inputs are the same.
+test_runtime = $(call test,--test_arg=--runtime=$(1) --nocache_test_results $(PARTITIONS) $(2))
+
+refresh: $(RUNTIME_BIN) ## Updates the runtime binary.
+.PHONY: refresh
+
+dev: $(RUNTIME_BIN) ## Installs a set of local runtimes. Requires sudo.
+ @$(call configure_noreload,$(RUNTIME),--net-raw)
+ @$(call configure_noreload,$(RUNTIME)-d,--net-raw --debug --strace --log-packets)
+ @$(call configure_noreload,$(RUNTIME)-p,--net-raw --profile)
+ @$(call configure_noreload,$(RUNTIME)-vfs2-d,--net-raw --debug --strace --log-packets --vfs2)
+ @$(call reload_docker)
+.PHONY: dev
+
+nogo: ## Surfaces all nogo findings.
+ @$(call build,--build_tag_filters nogo //...)
+ @$(call run,//tools/github $(foreach dir,$(BUILD_ROOTS),-path=$(CURDIR)/$(dir)) -dry-run nogo)
+.PHONY: nogo
+
+go: ## Builds the Go branch.
+ @$(call clean)
+ @$(call build,//:gopath)
+ @tools/go_branch.sh
+
+gazelle: ## Runs gazelle to update WORKSPACE.
+ @$(call run,//:gazelle update-repos -from_file=go.mod -prune)
+.PHONY: gazelle
##
## Canonical build and test targets.
@@ -123,25 +173,26 @@ list-images: ## List all available images.
##
PARTITION ?= 1
TOTAL_PARTITIONS ?= 1
+PARTITIONS := --test_arg=--partition=$(PARTITION) --test_arg=--total_partitions=$(TOTAL_PARTITIONS)
runsc: ## Builds the runsc binary.
- @$(call submake,build OPTIONS="-c opt" TARGETS="//runsc")
+ @$(call build,-c opt //runsc)
.PHONY: runsc
debian: ## Builds the debian packages.
- @$(call submake,build OPTIONS="-c opt" TARGETS="//debian:debian")
+ @$(call build,-c opt //debian:debian)
.PHONY: debian
smoke-tests: ## Runs a simple smoke test after build runsc.
- @$(call submake,run DOCKER_PRIVILEGED="" ARGS="--alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do true")
+ @$(call run,//runsc,--alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do true)
.PHONY: smoke-tests
fuse-tests:
- @$(call submake,test OPTIONS="--test_tag_filters fuse" TARGETS="test/fuse/...")
+ @$(call test,--test_tag_filters=fuse $(PARTITIONS) test/fuse/...)
.PHONY: fuse-tests
unit-tests: ## Local package unit tests in pkg/..., runsc/, tools/.., etc.
- @$(call submake,test TARGETS="pkg/... runsc/... tools/...")
+ @$(call test,//:all pkg/... runsc/... tools/...)
.PHONY: unit-tests
tests: ## Runs all unit tests and syscall tests.
@@ -150,114 +201,107 @@ tests: unit-tests syscall-tests
integration-tests: ## Run all standard integration tests.
integration-tests: docker-tests overlay-tests hostnet-tests swgso-tests
-integration-tests: do-tests kvm-tests containerd-test-1.3.4
+integration-tests: do-tests kvm-tests containerd-test-1.3.9
.PHONY: integration-tests
network-tests: ## Run all networking integration tests.
network-tests: iptables-tests packetdrill-tests packetimpact-tests
.PHONY: network-tests
-# Standard integration targets.
-INTEGRATION_TARGETS := //test/image:image_test //test/e2e:integration_test
-
syscall-%-tests:
- @$(call submake,test OPTIONS="--test_tag_filters runsc_$*" TARGETS="test/syscalls/...")
+ @$(call test,--test_tag_filters=runsc_$* $(PARTITIONS) test/syscalls/...)
syscall-native-tests:
- @$(call submake,test OPTIONS="--test_tag_filters native" TARGETS="test/syscalls/...")
+ @$(call test,--test_tag_filters=native $(PARTITIONS) test/syscalls/...)
.PHONY: syscall-native-tests
syscall-tests: ## Run all system call tests.
- @$(call submake,test TARGETS="test/syscalls/...")
+ @$(call test,$(PARTITIONS) test/syscalls/...)
-%-runtime-tests: load-runtimes_%
- @$(call submake,install-runtime)
- @$(call submake,test-runtime OPTIONS="--test_timeout=10800 --test_arg=--partition=$(PARTITION) --test_arg=--total_partitions=$(TOTAL_PARTITIONS)" TARGETS="//test/runtimes:$*")
+%-runtime-tests: load-runtimes_% $(RUNTIME_BIN)
+ @$(call install_runtime,$(RUNTIME),) # Ensure flags are cleared.
+ @$(call test_runtime,$(RUNTIME),--test_timeout=10800 //test/runtimes:$*)
-%-runtime-tests_vfs2: load-runtimes_%
-ifeq ($(PARTITION),)
- @$(eval PARTITION := 1)
-endif
-ifeq ($(TOTAL_PARTITIONS),)
- @$(eval TOTAL_PARTITIONS := 1)
-endif
- @$(call submake,install-runtime RUNTIME="vfs2" ARGS="--vfs2")
- @$(call submake,test-runtime RUNTIME="vfs2" OPTIONS="--test_timeout=10800 --test_arg=--partition=$(PARTITION) --test_arg=--total_partitions=$(TOTAL_PARTITIONS)" TARGETS="//test/runtimes:$*")
+%-runtime-tests_vfs2: load-runtimes_% $(RUNTIME_BIN)
+ @$(call install_runtime,$(RUNTIME),--vfs2)
+ @$(call test_runtime,$(RUNTIME),--test_timeout=10800 //test/runtimes:$*)
-do-tests: runsc
- @$(call submake,run TARGETS="//runsc" ARGS="--rootless do true")
- @$(call submake,run TARGETS="//runsc" ARGS="--rootless -network=none do true")
- @$(call submake,sudo TARGETS="//runsc" ARGS="do true")
+do-tests:
+ @$(call run,//runsc,--rootless do true)
+ @$(call run,//runsc,--rootless -network=none do true)
+ @$(call sudo,//runsc,do true)
.PHONY: do-tests
+arm-qemu-smoke-test: BAZEL_OPTIONS=--config=cross-aarch64
+arm-qemu-smoke-test: load-arm-qemu
+ export T=$$(mktemp -d --tmpdir release.XXXXXX); \
+ mkdir -p $$T/bin/arm64/ && \
+ $(call copy,//runsc:runsc,$$T/bin/arm64) && \
+ docker run --rm -v $$T/bin/arm64/runsc:/workdir/initramfs/runsc gvisor.dev/images/arm-qemu
+.PHONY: arm-qemu-smoke-test
+
simple-tests: unit-tests # Compatibility target.
.PHONY: simple-tests
-docker-tests: load-basic-images
- @$(call submake,install-runtime RUNTIME="vfs1")
- @$(call submake,test-runtime RUNTIME="vfs1" TARGETS="$(INTEGRATION_TARGETS)")
- @$(call submake,install-runtime RUNTIME="vfs2" ARGS="--vfs2")
- @$(call submake,test-runtime RUNTIME="vfs2" TARGETS="$(INTEGRATION_TARGETS)")
+# Standard integration targets.
+INTEGRATION_TARGETS := //test/image:image_test //test/e2e:integration_test
+
+docker-tests: load-basic $(RUNTIME_BIN)
+ @$(call install_runtime,$(RUNTIME),) # Clear flags.
+ @$(call test_runtime,$(RUNTIME),$(INTEGRATION_TARGETS))
+ @$(call install_runtime,$(RUNTIME),--vfs2)
+ @$(call test_runtime,$(RUNTIME),$(INTEGRATION_TARGETS))
.PHONY: docker-tests
-overlay-tests: load-basic-images
- @$(call submake,install-runtime RUNTIME="overlay" ARGS="--overlay")
- @$(call submake,test-runtime RUNTIME="overlay" TARGETS="$(INTEGRATION_TARGETS)")
+overlay-tests: load-basic $(RUNTIME_BIN)
+ @$(call install_runtime,$(RUNTIME),--overlay)
+ @$(call test_runtime,$(RUNTIME),$(INTEGRATION_TARGETS))
.PHONY: overlay-tests
-swgso-tests: load-basic-images
- @$(call submake,install-runtime RUNTIME="swgso" ARGS="--software-gso=true --gso=false")
- @$(call submake,test-runtime RUNTIME="swgso" TARGETS="$(INTEGRATION_TARGETS)")
+swgso-tests: load-basic $(RUNTIME_BIN)
+ @$(call install_runtime,$(RUNTIME),--software-gso=true --gso=false)
+ @$(call test_runtime,$(RUNTIME),$(INTEGRATION_TARGETS))
.PHONY: swgso-tests
-hostnet-tests: load-basic-images
- @$(call submake,install-runtime RUNTIME="hostnet" ARGS="--network=host")
- @$(call submake,test-runtime RUNTIME="hostnet" OPTIONS="--test_arg=-checkpoint=false" TARGETS="$(INTEGRATION_TARGETS)")
+hostnet-tests: load-basic $(RUNTIME_BIN)
+ @$(call install_runtime,$(RUNTIME),--network=host)
+ @$(call test_runtime,$(RUNTIME),--test_arg=-checkpoint=false --test_arg=-hostnet=true $(INTEGRATION_TARGETS))
.PHONY: hostnet-tests
-kvm-tests: load-basic-images
+kvm-tests: load-basic $(RUNTIME_BIN)
@(lsmod | grep -E '^(kvm_intel|kvm_amd)') || sudo modprobe kvm
- @if ! [[ -w /dev/kvm ]]; then sudo chmod a+rw /dev/kvm; fi
- @$(call submake,test TARGETS="//pkg/sentry/platform/kvm:kvm_test")
- @$(call submake,install-runtime RUNTIME="kvm" ARGS="--platform=kvm")
- @$(call submake,test-runtime RUNTIME="kvm" TARGETS="$(INTEGRATION_TARGETS)")
+ @if ! test -w /dev/kvm; then sudo chmod a+rw /dev/kvm; fi
+ @$(call test,//pkg/sentry/platform/kvm:kvm_test)
+ @$(call install_runtime,$(RUNTIME),--platform=kvm)
+ @$(call test_runtime,$(RUNTIME),$(INTEGRATION_TARGETS))
.PHONY: kvm-tests
-iptables-tests: load-iptables
+iptables-tests: load-iptables $(RUNTIME_BIN)
@sudo modprobe iptable_filter
@sudo modprobe ip6table_filter
- @$(call submake,test-runtime RUNTIME="runc" TARGETS="//test/iptables:iptables_test")
- @$(call submake,install-runtime RUNTIME="iptables" ARGS="--net-raw")
- @$(call submake,test-runtime RUNTIME="iptables" TARGETS="//test/iptables:iptables_test")
+ @$(call test,--test_arg=-runtime=runc $(PARTITIONS) //test/iptables:iptables_test)
+ @$(call install_runtime,$(RUNTIME),--net-raw)
+ @$(call test_runtime,$(RUNTIME),//test/iptables:iptables_test)
.PHONY: iptables-tests
-# Run the iptables tests with runsc only. Useful for developing to skip runc
-# testing.
-iptables-runsc-tests: load-iptables
- @sudo modprobe iptable_filter
- @sudo modprobe ip6table_filter
- @$(call submake,install-runtime RUNTIME="iptables" ARGS="--net-raw")
- @$(call submake,test-runtime RUNTIME="iptables" TARGETS="//test/iptables:iptables_test")
-.PHONY: iptables-runsc-tests
-
-packetdrill-tests: load-packetdrill
- @$(call submake,install-runtime RUNTIME="packetdrill")
- @$(call submake,test-runtime RUNTIME="packetdrill" TARGETS="$(shell $(MAKE) -s query TARGETS='attr(tags, packetdrill, tests(//...))')")
+packetdrill-tests: load-packetdrill $(RUNTIME_BIN)
+ @$(call install_runtime,$(RUNTIME),) # Clear flags.
+ @$(call test_runtime,$(RUNTIME),//test/packetdrill:all_tests)
.PHONY: packetdrill-tests
-packetimpact-tests: load-packetimpact
+packetimpact-tests: load-packetimpact $(RUNTIME_BIN)
@sudo modprobe iptable_filter
@sudo modprobe ip6table_filter
- @$(call submake,install-runtime RUNTIME="packetimpact")
- @$(call submake,test-runtime OPTIONS="--jobs=HOST_CPUS*3 --local_test_jobs=HOST_CPUS*3" RUNTIME="packetimpact" TARGETS="$(shell $(MAKE) -s query TARGETS='attr(tags, packetimpact, tests(//...))')")
+ @$(call install_runtime,$(RUNTIME),) # Clear flags.
+ @$(call test_runtime,$(RUNTIME),--jobs=HOST_CPUS*3 --local_test_jobs=HOST_CPUS*3 //test/packetimpact/tests:all_tests)
.PHONY: packetimpact-tests
# Specific containerd version tests.
-containerd-test-%: load-basic_alpine load-basic_python load-basic_busybox load-basic_resolv load-basic_httpd load-basic_ubuntu
- @$(call submake,install-runtime RUNTIME="root")
- @CONTAINERD_VERSION=$* $(MAKE) -s sudo TARGETS="tools/installers:containerd"
- @$(MAKE) -s sudo TARGETS="tools/installers:shim"
- @$(MAKE) -s sudo TARGETS="test/root:root_test" ARGS="--runtime=root -test.v"
+containerd-test-%: load-basic_alpine load-basic_python load-basic_busybox load-basic_resolv load-basic_httpd load-basic_ubuntu $(RUNTIME_BIN)
+ @$(call install_runtime,$(RUNTIME),) # Clear flags.
+ @$(call sudo,tools/installers:containerd,$*)
+ @$(call sudo,tools/installers:shim)
+ @$(call sudo,test/root:root_test,--runtime=$(RUNTIME) -test.v)
# Note that we can't run containerd-test-1.1.8 tests here.
#
@@ -266,8 +310,8 @@ containerd-test-%: load-basic_alpine load-basic_python load-basic_busybox load-b
# actually drive the tests. The v1 API is tested exclusively through 1.2.13.
containerd-tests: ## Runs all supported containerd version tests.
containerd-tests: containerd-test-1.2.13
-containerd-tests: containerd-test-1.3.4
-containerd-tests: containerd-test-1.4.0-beta.0
+containerd-tests: containerd-test-1.3.9
+containerd-tests: containerd-test-1.4.3
##
## Benchmarks.
@@ -275,55 +319,54 @@ containerd-tests: containerd-test-1.4.0-beta.0
## Targets to run benchmarks. See //test/benchmarks for details.
##
## common arguments:
-## RUNTIME_ARGS - arguments to runsc placed in /etc/docker/daemon.json
-## e.g. "--platform=ptrace"
-## BENCHMARKS_PROJECT - BigQuery project to which to send data.
-## BENCHMARKS_DATASET - BigQuery dataset to which to send data.
-## BENCHMARKS_TABLE - BigQuery table to which to send data.
-## BENCHMARKS_SUITE - name of the benchmark suite. See //tools/bigquery/bigquery.go.
-## BENCHMARKS_UPLOAD - if true, upload benchmark data from the run.
-## BENCHMARKS_OFFICIAL - marks the data as official.
+## BENCHMARKS_PROJECT - BigQuery project to which to send data.
+## BENCHMARKS_DATASET - BigQuery dataset to which to send data.
+## BENCHMARKS_TABLE - BigQuery table to which to send data.
+## BENCHMARKS_SUITE - name of the benchmark suite. See //tools/bigquery/bigquery.go.
+## BENCHMARKS_UPLOAD - if true, upload benchmark data from the run.
+## BENCHMARKS_OFFICIAL - marks the data as official.
## BENCHMARKS_PLATFORMS - platforms to run benchmarks (e.g. ptrace kvm).
+## BENCHMARKS_FILTER - filter to be applied to the test suite.
+## BENCHMARKS_OPTIONS - options to be passed to the test.
##
-BENCHMARKS_PROJECT := gvisor-benchmarks
-BENCHMARKS_DATASET := kokoro
-BENCHMARKS_TABLE := benchmarks
-BENCHMARKS_SUITE := start
-BENCHMARKS_UPLOAD := false
-BENCHMARKS_OFFICIAL := false
-BENCHMARKS_PLATFORMS := ptrace
-BENCHMARKS_TARGETS := //test/benchmarks/base:startup_test
-BENCHMARKS_ARGS := -test.bench=. -pprof-cpu -pprof-heap -pprof-heap -pprof-block
-
-init-benchmark-table: ## Initializes a BigQuery table with the benchmark schema
-## (see //tools/bigquery/bigquery.go). If the table alread exists, this is a noop.
- $(call submake, run TARGETS=//tools/parsers:parser ARGS="init --project=$(BENCHMARKS_PROJECT) \
- --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE)")
+BENCHMARKS_PROJECT ?= gvisor-benchmarks
+BENCHMARKS_DATASET ?= kokoro
+BENCHMARKS_TABLE ?= benchmarks
+BENCHMARKS_SUITE ?= ffmpeg
+BENCHMARKS_UPLOAD ?= false
+BENCHMARKS_OFFICIAL ?= false
+BENCHMARKS_PLATFORMS ?= ptrace
+BENCHMARKS_TARGETS := //test/benchmarks/media:ffmpeg_test
+BENCHMARKS_FILTER := .
+BENCHMARKS_OPTIONS := -test.benchtime=30s
+BENCHMARKS_ARGS := -test.v -test.bench=$(BENCHMARKS_FILTER) -pprof-dir=/tmp/profile -pprof-cpu -pprof-heap -pprof-block -pprof-mutex $(BENCHMARKS_OPTIONS)
+
+init-benchmark-table: ## Initializes a BigQuery table with the benchmark schema.
+ @$(call run,//tools/parsers:parser,init --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE))
.PHONY: init-benchmark-table
-benchmark-platforms: load-benchmarks-images ## Runs benchmarks for runc and all given platforms in BENCHMARK_PLATFORMS.
- $(foreach PLATFORM,$(BENCHMARKS_PLATFORMS), \
- $(call submake,run-benchmark RUNTIME="$(PLATFORM)" ARGS="--platform=$(PLATFORM) --vfs2") && \
- $(call submake,run-benchmark RUNTIME="$(PLATFORM)_vfs1" ARGS="--platform=$(PLATFORM)") && \
- ) \
- $(call submake, run-benchmark RUNTIME="runc")
+# $(1) is the runtime name, $(2) are the arguments.
+run_benchmark = \
+ ($(call header,BENCHMARK $(1) $(2)); \
+ set -euo pipefail; \
+ if test "$(1)" != "runc"; then $(call install_runtime,$(1),--profile $(2)); fi; \
+ export T=$$(mktemp --tmpdir logs.$(1).XXXXXX); \
+ $(call sudo,$(BENCHMARKS_TARGETS),-runtime=$(1) $(BENCHMARKS_ARGS)) | tee $$T; \
+ if test "$(BENCHMARKS_UPLOAD)" = "true"; then \
+ $(call run,tools/parsers:parser,parse --debug --file=$$T --runtime=$(1) --suite_name=$(BENCHMARKS_SUITE) --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE) --official=$(BENCHMARKS_OFFICIAL)); \
+ fi; \
+ rm -rf $$T)
+
+benchmark-platforms: load-benchmarks $(RUNTIME_BIN) ## Runs benchmarks for runc and all given platforms in BENCHMARKS_PLATFORMS.
+ @$(foreach PLATFORM,$(BENCHMARKS_PLATFORMS), \
+ $(call run_benchmark,$(PLATFORM),--platform=$(PLATFORM) --vfs2) && \
+ ) true
+ @$(call run_benchmark,runc)
.PHONY: benchmark-platforms
-run-benchmark: load-benchmarks-images ## Runs single benchmark and optionally sends data to BigQuery.
- @if [[ "$(RUNTIME)" != "runc" ]]; then $(call submake,install-runtime ARGS="$(ARGS) --profile"); fi
- @T=$$(mktemp --tmpdir logs.$(RUNTIME).XXXXXX); \
- $(call submake,sudo TARGETS="$(BENCHMARKS_TARGETS)" ARGS="--runtime=$(RUNTIME) $(BENCHMARKS_ARGS) | tee $$T"); \
- rc=$$?; \
- if [[ $$rc -eq 0 ]] && [[ "$(BENCHMARKS_UPLOAD)" == "true" ]]; then \
- $(call submake,run TARGETS="tools/parsers:parser" ARGS="parse --debug --file=$$T \
- --runtime=$(RUNTIME) --suite_name=$(BENCHMARKS_SUITE) \
- --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) \
- --table=$(BENCHMARKS_TABLE) --official=$(BENCHMARKS_OFFICIAL)"); \
- fi; \
- rm -rf $$T; \
- exit $$rc
+run-benchmark: load-benchmarks $(RUNTIME_BIN) ## Runs single benchmark and optionally sends data to BigQuery.
+ @$(call run_benchmark,$(RUNTIME),)
.PHONY: run-benchmark
-.PHONY: load-benchmarks-images
##
## Website & documentation helpers.
@@ -342,7 +385,7 @@ WEBSITE_PROJECT := gvisordev
WEBSITE_REGION := us-central1
website-build: load-jekyll ## Build the site image locally.
- @$(call submake,run TARGETS="//website:website" ARGS="$(WEBSITE_IMAGE)")
+ @$(call run,//website:website,$(WEBSITE_IMAGE))
.PHONY: website-build
website-server: website-build ## Run a local server for development.
@@ -354,7 +397,7 @@ website-push: website-build ## Push a new image and update the service.
.PHONY: website-push
website-deploy: website-push ## Deploy a new version of the website.
- @gcloud run deploy $(WEBSITE_SERVICE) --platform=managed --region=$(WEBSITE_REGION) --project=$(WEBSITE_PROJECT) --image=$(WEBSITE_IMAGE)
+ @gcloud run deploy $(WEBSITE_SERVICE) --platform=managed --region=$(WEBSITE_REGION) --project=$(WEBSITE_PROJECT) --image=$(WEBSITE_IMAGE) --memory 1Gi
.PHONY: website-deploy
##
@@ -368,17 +411,17 @@ website-deploy: website-push ## Deploy a new version of the website.
## RELEASE_NAME - The name of the release in the proper format (needed for tag).
## RELEASE_NOTES - The file containing release notes (needed for tag).
##
-RELEASE_ROOT := $(CURDIR)/repo
-RELEASE_KEY := repo.key
-RELEASE_NIGHTLY := false
-RELEASE_COMMIT :=
-RELEASE_NAME :=
-RELEASE_NOTES :=
-
+RELEASE_ROOT := $(CURDIR)/repo
+RELEASE_KEY := repo.key
+RELEASE_NIGHTLY := false
+RELEASE_COMMIT :=
+RELEASE_NAME :=
+RELEASE_NOTES :=
GPG_TEST_OPTIONS := $(shell if gpg --pinentry-mode loopback --version >/dev/null 2>&1; then echo --pinentry-mode loopback; fi)
+
$(RELEASE_KEY):
@echo "WARNING: Generating a key for testing ($@); don't use this."
- T=$$(mktemp --tmpdir keyring.XXXXXX); \
+ @T=$$(mktemp --tmpdir keyring.XXXXXX); \
C=$$(mktemp --tmpdir config.XXXXXX); \
echo Key-Type: DSA >> $$C && \
echo Key-Length: 1024 >> $$C && \
@@ -392,11 +435,11 @@ $(RELEASE_KEY):
release: $(RELEASE_KEY) ## Builds a release.
@mkdir -p $(RELEASE_ROOT)
- @T=$$(mktemp -d --tmpdir release.XXXXXX); \
- $(call submake,copy TARGETS="//runsc:runsc" DESTINATION=$$T) && \
- $(call submake,copy TARGETS="//shim/v1:gvisor-containerd-shim" DESTINATION=$$T) && \
- $(call submake,copy TARGETS="//shim/v2:containerd-shim-runsc-v1" DESTINATION=$$T) && \
- $(call submake,copy TARGETS="//debian:debian" DESTINATION=$$T) && \
+ @export T=$$(mktemp -d --tmpdir release.XXXXXX); \
+ $(call copy,//runsc:runsc,$$T) && \
+ $(call copy,//shim/v1:gvisor-containerd-shim,$$T) && \
+ $(call copy,//shim/v2:containerd-shim-runsc-v1,$$T) && \
+ $(call copy,//debian:debian,$$T) && \
NIGHTLY=$(RELEASE_NIGHTLY) tools/make_release.sh $(RELEASE_KEY) $(RELEASE_ROOT) $$T/*; \
rc=$$?; rm -rf $$T; exit $$rc
.PHONY: release
@@ -404,74 +447,3 @@ release: $(RELEASE_KEY) ## Builds a release.
tag: ## Creates and pushes a release tag.
@tools/tag_release.sh "$(RELEASE_COMMIT)" "$(RELEASE_NAME)" "$(RELEASE_NOTES)"
.PHONY: tag
-
-##
-## Development helpers and tooling.
-##
-## These targets faciliate local development by automatically
-## installing and configuring a runtime. Several variables may
-## be used here to tweak the installation:
-## RUNTIME - The name of the installed runtime (default: branch).
-## RUNTIME_DIR - Where the runtime will be installed (default: temporary directory with the $RUNTIME).
-## RUNTIME_BIN - The runtime binary (default: $RUNTIME_DIR/runsc).
-## RUNTIME_LOG_DIR - The logs directory (default: $RUNTIME_DIR/logs).
-## RUNTIME_LOGS - The log pattern (default: $RUNTIME_LOG_DIR/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%).
-##
-ifeq (,$(BRANCH_NAME))
-RUNTIME := runsc
-RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(RUNTIME)
-else
-RUNTIME := $(BRANCH_NAME)
-RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(RUNTIME)
-endif
-RUNTIME_BIN := $(RUNTIME_DIR)/runsc
-RUNTIME_LOG_DIR := $(RUNTIME_DIR)/logs
-RUNTIME_LOGS := $(RUNTIME_LOG_DIR)/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%
-
-dev: ## Installs a set of local runtimes. Requires sudo.
- @$(call submake,refresh)
- @$(call submake,configure RUNTIME_NAME="$(RUNTIME)" ARGS="--net-raw")
- @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-d" ARGS="--net-raw --debug --strace --log-packets")
- @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-p" ARGS="--net-raw --profile")
- @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-vfs2-d" ARGS="--net-raw --debug --strace --log-packets --vfs2")
- @sudo systemctl restart docker
-.PHONY: dev
-
-refresh: ## Refreshes the runtime binary (for development only). Must have called 'dev' or 'install-runtime' first.
- @mkdir -p "$(RUNTIME_DIR)"
- @$(call submake,copy TARGETS=runsc DESTINATION="$(RUNTIME_BIN)")
-.PHONY: refresh
-
-install-runtime: ## Installs the runtime for testing. Requires sudo.
- @$(call submake,refresh)
- @$(call submake,configure RUNTIME_NAME="$(RUNTIME)" ARGS="$(ARGS) --TESTONLY-test-name-env=RUNSC_TEST_NAME")
- @sudo systemctl restart docker
- @if [[ -f /etc/docker/daemon.json ]]; then \
- sudo chmod 0755 /etc/docker && \
- sudo chmod 0644 /etc/docker/daemon.json; \
- fi
-.PHONY: install-runtime
-
-install-debug-runtime: ## Installs the runtime for debugging. Requires sudo.
- @$(call submake,install-runtime ARGS="--debug --strace --log-packets $(ARGS)")
-.PHONY: install-debug-runtime
-
-configure: ## Configures a single runtime. Requires sudo. Typically called from dev or install-runtime.
- @sudo sudo "$(RUNTIME_BIN)" install --experimental=true --runtime="$(RUNTIME_NAME)" -- --debug-log "$(RUNTIME_LOGS)" $(ARGS)
- @echo -e "$(INFO) Installed runtime \"$(RUNTIME)\" @ $(RUNTIME_BIN)"
- @echo -e "$(INFO) Logs are in: $(RUNTIME_LOG_DIR)"
- @sudo rm -rf "$(RUNTIME_LOG_DIR)" && mkdir -p "$(RUNTIME_LOG_DIR)"
-.PHONY: configure
-
-test-runtime: ## A convenient wrapper around test that provides the runtime argument. Target must still be provided.
- @$(call submake,test OPTIONS="$(OPTIONS) --test_arg=--runtime=$(RUNTIME)")
-.PHONY: test-runtime
-
-nogo: ## Surfaces all nogo findings.
- @$(call submake,build OPTIONS="--build_tag_filters nogo" TARGETS="//...")
- @$(call submake,run TARGETS="//tools/github" ARGS="$(foreach dir,$(BUILD_ROOTS),-path=$(CURDIR)/$(dir)) -dry-run nogo")
-.PHONY: nogo
-
-gazelle: ## Runs gazelle to update WORKSPACE.
- @$(call submake,run TARGETS="//:gazelle" ARGS="update-repos -from_file=go.mod -prune")
-.PHONY: gazelle
diff --git a/README.md b/README.md
index 0a79e2cff..866a6a248 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
![gVisor](g3doc/logo.png)
-![](https://github.com/google/gvisor/workflows/Build/badge.svg)
+[![Build status](https://badge.buildkite.com/3b159f20b9830461a71112566c4171c0bdfd2f980a8e4c0ae6.svg?branch=master)](https://buildkite.com/gvisor/pipeline)
[![gVisor chat](https://badges.gitter.im/gvisor/community.png)](https://gitter.im/gvisor/community)
[![code search](https://img.shields.io/badge/code-search-blue)](https://cs.opensource.google/gvisor/gvisor)
diff --git a/WORKSPACE b/WORKSPACE
index 2405bfd80..f48f10e94 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -1,6 +1,17 @@
-load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file")
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
+# Root certificates.
+#
+# Note that the sha256 hash is omitted here intentionally. This should not be
+# used in any part of the build other than as certificates present in images.
+http_file(
+ name = "google_root_pem",
+ urls = [
+ "https://pki.goog/roots.pem"
+ ],
+)
+
# Bazel/starlark utilities.
http_archive(
name = "bazel_skylib",
@@ -83,6 +94,20 @@ http_archive(
],
)
+# Load C++ cross-compilation toolchains.
+http_archive(
+ name = "coral_crosstool",
+ sha256 = "088ef98b19a45d7224be13636487e3af57b1564880b67df7be8b3b7eee4a1bfc",
+ strip_prefix = "crosstool-142e930ac6bf1295ff3ba7ba2b5b6324dfb42839",
+ urls = [
+ "https://github.com/google-coral/crosstool/archive/142e930ac6bf1295ff3ba7ba2b5b6324dfb42839.tar.gz",
+ ],
+)
+
+load("@coral_crosstool//:configure.bzl", "cc_crosstool")
+
+cc_crosstool(name = "crosstool")
+
# Load protobuf dependencies.
http_archive(
name = "rules_proto",
@@ -176,6 +201,19 @@ http_archive(
],
)
+# Schemas for testing.
+http_file(
+ name = "buildkite_pipeline_schema",
+ sha256 = "3369c58038b4d55c08928affafb653716eb1e7b3cabb4a391aef979dd921f4e1",
+ urls = ["https://raw.githubusercontent.com/buildkite/pipeline-schema/f7a0894074d194bcf19eec5411fec0528f7f4180/schema.json"],
+)
+
+http_file(
+ name = "github_workflow_schema",
+ sha256 = "60603d1095b11d136e04a8b95be83a23ad8044169e46f82f925c320c1cf47a49",
+ urls = ["https://raw.githubusercontent.com/SchemaStore/schemastore/27612065234778feaac216ce14dd47846fe0a2dd/src/schemas/json/github-workflow.json"],
+)
+
# External Go repositories.
#
# Unfortunately, gazelle will automatically parse go modules in the
@@ -193,8 +231,8 @@ go_repository(
name = "com_github_containerd_containerd",
build_file_proto_mode = "disable",
importpath = "github.com/containerd/containerd",
- sum = "h1:3o0smo5SKY7H6AJCmJhsnCjR2/V2T8VmiHt7seN2/kI=",
- version = "v1.3.4",
+ sum = "h1:K2U/F4jGAMBqeUssfgJRbFuomLcS2Fxo1vR3UM/Mbh8=",
+ version = "v1.3.9",
)
go_repository(
@@ -524,8 +562,8 @@ go_repository(
name = "com_github_containerd_cgroups",
build_file_proto_mode = "disable",
importpath = "github.com/containerd/cgroups",
- sum = "h1:5yg0k8gqOssNLsjjCtXIADoPbAtUtQZJfC8hQ4r2oFY=",
- version = "v0.0.0-20181219155423-39b18af02c41",
+ sum = "h1:7grrpcfCtbZLsjtB0DgMuzs1umsJmpzaHMZ6cO6iAWw=",
+ version = "v0.0.0-20201119153540-4cbc285b3327",
)
go_repository(
@@ -853,8 +891,8 @@ go_repository(
go_repository(
name = "com_github_google_pprof",
importpath = "github.com/google/pprof",
- sum = "h1:DLpL8pWq0v4JYoRpEhDfsJhhJyGKCcQM2WPW2TJs31c=",
- version = "v0.0.0-20191218002539-d4f498aebedc",
+ sum = "h1:LR89qFljJ48s990kEKGsk213yIJDPI4205OKOzbURK8=",
+ version = "v0.0.0-20201218002935-b9804c9f04c2",
)
go_repository(
@@ -1391,3 +1429,24 @@ go_repository(
sum = "h1:+ySTxfHnfzZb9ys375PXNlLhkJPLKgHajBU0N62BDvE=",
version = "v0.0.0-20190801114015-581e00157fb1",
)
+
+go_repository(
+ name = "com_github_xeipuuv_gojsonpointer",
+ importpath = "github.com/xeipuuv/gojsonpointer",
+ sum = "h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo=",
+ version = "v0.0.0-20190905194746-02993c407bfb",
+)
+
+go_repository(
+ name = "com_github_xeipuuv_gojsonreference",
+ importpath = "github.com/xeipuuv/gojsonreference",
+ sum = "h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0=",
+ version = "v0.0.0-20180127040603-bd5ef7bd5415",
+)
+
+go_repository(
+ name = "com_github_xeipuuv_gojsonschema",
+ importpath = "github.com/xeipuuv/gojsonschema",
+ sum = "h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74=",
+ version = "v1.2.0",
+)
diff --git a/g3doc/architecture_guide/performance.md b/g3doc/architecture_guide/performance.md
index b981f0c01..b89facfd3 100644
--- a/g3doc/architecture_guide/performance.md
+++ b/g3doc/architecture_guide/performance.md
@@ -269,7 +269,7 @@ operations are less of an issue. The above figure shows the total time required
for an `ffmpeg` container to start, load and transcode a 27MB input video.
[ab]: https://en.wikipedia.org/wiki/ApacheBench
-[benchmark-tools]: https://github.com/google/gvisor/tree/master/benchmarks
+[benchmark-tools]: https://github.com/google/gvisor/tree/master/test/benchmarks
[gce]: https://cloud.google.com/compute/
[cnn]: https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/convolutional_network.py
[docker]: https://docker.io
diff --git a/g3doc/proposals/runtime_dedicate_os_thread.md b/g3doc/proposals/runtime_dedicate_os_thread.md
new file mode 100644
index 000000000..dc70055b0
--- /dev/null
+++ b/g3doc/proposals/runtime_dedicate_os_thread.md
@@ -0,0 +1,188 @@
+# `runtime.DedicateOSThread`
+
+Status as of 2020-09-18: Deprioritized; initial studies in #2180 suggest that
+this may be difficult to support in the Go runtime due to issues with GC.
+
+## Summary
+
+Allow goroutines to bind to kernel threads in a way that allows their scheduling
+to be kernel-managed rather than runtime-managed.
+
+## Objectives
+
+* Reduce Go runtime overhead in the gVisor sentry (#2184).
+
+* Minimize intrusiveness of changes to the Go runtime.
+
+## Background
+
+In Go, execution contexts are referred to as goroutines, which the runtime calls
+Gs. The Go runtime maintains a variably-sized pool of threads (called Ms by the
+runtime) on which Gs are executed, as well as a pool of "virtual processors"
+(called Ps by the runtime) of size equal to `runtime.GOMAXPROCS()`. Usually,
+each M requires a P in order to execute Gs, limiting the number of concurrently
+executing goroutines to `runtime.GOMAXPROCS()`.
+
+The `runtime.LockOSThread` function temporarily locks the invoking goroutine to
+its current thread. It is primarily useful for interacting with OS or non-Go
+library facilities that are per-thread. It does not reduce interactions with the
+Go runtime scheduler: locked Ms relinquish their P when they become blocked, and
+only continue execution after another M "chooses" their locked G to run and
+donates their P to the locked M instead.
+
+## Problems
+
+### Context Switch Overhead
+
+Most goroutines in the gVisor sentry are task goroutines, which back application
+threads. Task goroutines spend large amounts of time blocked on syscalls that
+execute untrusted application code. When invoking said syscall (which varies by
+gVisor platform), the task goroutine may interact with the Go runtime in one of
+three ways:
+
+* It can invoke the syscall without informing the runtime. In this case, the
+ task goroutine will continue to hold its P during the syscall, limiting the
+ number of application threads that can run concurrently to
+ `runtime.GOMAXPROCS()`. This is problematic because the Go runtime scheduler
+ is known to scale poorly with `GOMAXPROCS`; see #1942 and
+ https://github.com/golang/go/issues/28808. It also means that preemption of
+ application threads must be driven by sentry or runtime code, which is
+ strictly slower than kernel-driven preemption (since the sentry must invoke
+ another syscall to preempt the application thread).
+
+* It can call `runtime.entersyscallblock` before invoking the syscall, and
+ `runtime.exitsyscall` after the syscall returns. In this case, the task
+ goroutine will release its P while the syscall is executing. This allows the
+ number of threads concurrently executing application code to exceed
+ `GOMAXPROCS`. However, this incurs additional latency on syscall entry (to
+ hand off the released P to another M, often requiring a `futex(FUTEX_WAKE)`
+ syscall) and on syscall exit (to acquire a new P). It also drastically
+ increases the number of threads that concurrently interact with the runtime
+ scheduler, which is also problematic for performance (both in terms of CPU
+ utilization and in terms of context switch latency); see #205.
+
+* It can call `runtime.entersyscall` before invoking the syscall, and
+ `runtime.exitsyscall` after the syscall returns. In this case, the task
+ goroutine "lazily releases" its P, allowing the runtime's "sysmon" thread to
+ steal it on behalf of another M after a 20us delay. This mitigates the
+ context switch latency problem when there are few task goroutines and the
+ interval between switches to application code (i.e. the interval between
+ application syscalls, page faults, or signal delivery) is short. (Cynically,
+ this means that it's most effective in microbenchmarks). However, the delay
+ before a P is stolen can also be problematic for performance when there are
+ both many task goroutines switching to application code (lazily releasing
+ their Ps) *and* many task goroutines switching to sentry code (contending
+ for Ps), which is likely in larger heterogeneous workloads.
+
+### Blocking Overhead
+
+Task goroutines block on behalf of application syscalls like `futex` and
+`epoll_wait` by receiving from a Go channel. (Future work may convert task
+goroutine blocking to use the `syncevent` package to avoid overhead associated
+with channels and `select`, but this does not change how blocking interacts with
+the Go runtime scheduler.)
+
+If `runtime.LockOSThread()` is not in effect when a task goroutine blocks, then
+when the task goroutine is unblocked (by e.g. an application `FUTEX_WAKE`,
+signal delivery, or a timeout) by sending to the blocked channel,
+`runtime.ready` migrates the unblocked G to the unblocking P. In most cases,
+this implies that every application thread block/unblock cycle results in a
+migration of the thread between Ps, and therefore Ms, and therefore cores,
+resulting in reduced application performance due to loss of CPU caches.
+Furthermore, in most cases, the unblocking P cannot immediately switch to the
+unblocked G (instead resuming execution of its current application thread after
+completing the application's `futex(FUTEX_WAKE)`, `tgkill`, etc. syscall), often
+requiring that another P steal the unblocked G before it can resume execution.
+
+If `runtime.LockOSThread()` is in effect when a task goroutine blocks, then the
+G will remain locked to its M, avoiding the core migration described above;
+however, wakeup latency is significantly increased since, as described in
+"Background", the G still needs to be selected by the scheduler before it can
+run, and the M that selects the G then needs to transfer its P to the locked M,
+incurring an additional `FUTEX_WAKE` syscall and round of kernel scheduling.
+
+## Proposal
+
+We propose to add a function, tentatively called `DedicateOSThread`, to the Go
+`runtime` package, documented as follows:
+
+```go
+// DedicateOSThread wires the calling goroutine to its current operating system
+// thread, and exempts it from counting against GOMAXPROCS. The calling
+// goroutine will always execute in that thread, and no other goroutine will
+// execute in it, until the calling goroutine has made as many calls to
+// UndedicateOSThread as to DedicateOSThread. If the calling goroutine exits
+// without unlocking the thread, the thread will be terminated.
+//
+// DedicateOSThread should only be used by long-lived goroutines that usually
+// block due to blocking system calls, rather than interaction with other
+// goroutines.
+func DedicateOSThread()
+```
+
+Mechanically, `DedicateOSThread` implies `LockOSThread` (i.e. it locks the
+invoking G to a M), but additionally locks the invoking M to a P. Ps locked by
+`DedicateOSThread` are not counted against `GOMAXPROCS`; that is, the actual
+number of Ps in the system (`len(runtime.allp)`) is `GOMAXPROCS` plus the number
+of bound Ps (plus some slack to avoid frequent changes to `runtime.allp`).
+Corollaries:
+
+* If `runtime.ready` observes that a readied G is locked to a M locked to a P,
+ it immediately wakes the locked M without migrating the G to the readying P
+ or waiting for a future call to `runtime.schedule` to select the readied G
+ in `runtime.findrunnable`.
+
+* `runtime.stoplockedm` and `runtime.reentersyscall` skip the release of
+ locked Ps; the latter also skips sysmon wakeup. `runtime.stoplockedm` and
+ `runtime.exitsyscall` skip re-acquisition of Ps if one is locked.
+
+* sysmon does not attempt to preempt Gs that are locked to Ps, avoiding
+ fruitless overhead from `tgkill` syscalls and signal delivery.
+
+* `runtime.findrunnable`'s work stealing skips locked Ps (suggesting that
+ unlocked Ps be tracked in a separate array). `runtime.findrunnable` on
+ locked Ps skips the global run queue, work stealing, and possibly netpoll.
+
+* New goroutines created by goroutines with locked Ps are enqueued on the
+ global run queue rather than the invoking P's local run queue.
+
+While gVisor's use case does not strictly require that the association is
+reversible (with `runtime.UndedicateOSThread`), such a feature is required to
+allow reuse of locked Ms, which is likely to be critical for performance.
+
+## Alternatives Considered
+
+* Make the runtime scale well with `GOMAXPROCS`. While we are also
+ concurrently investigating this problem, this would not address the issues
+ of increased preemption cost or blocking overhead.
+
+* Make the runtime scale well with number of Ms. It is unclear if this is
+ actually feasible, and would not address blocking overhead.
+
+* Make P-locking part of `LockOSThread`'s behavior. This would likely
+ introduce performance regressions in existing uses of `LockOSThread` that do
+ not fit this usage pattern. In particular, since `DedicateOSThread`
+ transitions the invoker's P from "counted against `GOMAXPROCS`" to "not
+ counted against `GOMAXPROCS`", it may need to wake another M to run a new P
+ (that is counted against `GOMAXPROCS`), and the converse applies to
+ `UndedicateOSThread`.
+
+* Rewrite the gVisor sentry in a language that does not force userspace
+ scheduling. This is a last resort due to the amount of code involved.
+
+## Related Issues
+
+The proposed functionality is directly analogous to `spawn_blocking` in Rust
+async runtimes
+[`async_std`](https://docs.rs/async-std/1.8.0/async_std/task/fn.spawn_blocking.html)
+and [`tokio`](https://docs.rs/tokio/0.3.5/tokio/task/fn.spawn_blocking.html).
+
+Outside of gVisor:
+
+* https://github.com/golang/go/issues/21827#issuecomment-595152452 describes a
+ use case for this feature in go-delve, where the goroutine that would use
+ this feature spends much of its time blocked in `ptrace` syscalls.
+
+* This feature may improve performance in the use case described in
+ https://github.com/golang/go/issues/18237, given the prominence of
+ `syscall.Syscall` in the profile given in that bug report.
diff --git a/go.mod b/go.mod
index 144543169..823c3596d 100644
--- a/go.mod
+++ b/go.mod
@@ -10,8 +10,8 @@ require (
github.com/Microsoft/hcsshim v0.8.6 // indirect
github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422 // indirect
github.com/cilium/ebpf v0.0.0-20200110133405-4032b1d8aae3 // indirect
- github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41 // indirect
- github.com/containerd/containerd v1.3.4 // indirect
+ github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327
+ github.com/containerd/containerd v1.3.9 // indirect
github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a // indirect
github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00 // indirect
github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328 // indirect
diff --git a/go.sum b/go.sum
index 060d5596a..70514ea14 100644
--- a/go.sum
+++ b/go.sum
@@ -51,12 +51,12 @@ github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41 h1:5yg0k8gqOssN
github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41/go.mod h1:X9rLEHIqSf/wfK8NsPqxJmeZgW4pcfzdXITDrUSJ6uI=
github.com/containerd/cgroups v0.0.0-20200531161412-0dbf7f05ba59 h1:qWj4qVYZ95vLWwqyNJCQg7rDsG5wPdze0UaPolH7DUk=
github.com/containerd/cgroups v0.0.0-20200531161412-0dbf7f05ba59/go.mod h1:pA0z1pT8KYB3TCXK/ocprsh7MAkoW8bZVzPdih9snmM=
+github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327 h1:7grrpcfCtbZLsjtB0DgMuzs1umsJmpzaHMZ6cO6iAWw=
+github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327/go.mod h1:ZJeTFisyysqgcCdecO57Dj79RfL0LNeGiFUqLYQRYLE=
github.com/containerd/console v0.0.0-20180822173158-c12b1e7919c1/go.mod h1:Tj/on1eG8kiEhd0+fhSDzsPAFESxzBBvdyEgyryXffw=
github.com/containerd/console v0.0.0-20191206165004-02ecf6a7291e h1:GdiIYd8ZDOrT++e1NjhSD4rGt9zaJukHm4rt5F4mRQc=
github.com/containerd/console v0.0.0-20191206165004-02ecf6a7291e/go.mod h1:8Pf4gM6VEbTNRIT26AyyU7hxdQU3MvAvxVI0sc00XBE=
-github.com/containerd/containerd v1.3.2/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA=
-github.com/containerd/containerd v1.3.4 h1:3o0smo5SKY7H6AJCmJhsnCjR2/V2T8VmiHt7seN2/kI=
-github.com/containerd/containerd v1.3.4/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA=
+github.com/containerd/containerd v1.3.9/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA=
github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y=
github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a h1:jEIoR0aA5GogXZ8pP3DUzE+zrhaF6/1rYZy+7KkYEWM=
github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a/go.mod h1:W0qIOTD7mp2He++YVq+kgfXezRYqzP1uDuMVH1bITDY=
diff --git a/images/BUILD b/images/BUILD
index a50f388e9..34b950644 100644
--- a/images/BUILD
+++ b/images/BUILD
@@ -1,11 +1 @@
package(licenses = ["notice"])
-
-# The images filegroup is definitely not a hermetic target, and requires Make
-# to do anything meaningful with. However, this will be slurped up and used by
-# the tools/installer/images.sh installer, which will ensure that all required
-# images are available locally when running vm_tests.
-filegroup(
- name = "images",
- srcs = glob(["**"]),
- visibility = ["//tools/installers:__pkg__"],
-)
diff --git a/images/Makefile b/images/Makefile
deleted file mode 100644
index 66aac7802..000000000
--- a/images/Makefile
+++ /dev/null
@@ -1,107 +0,0 @@
-#!/usr/bin/make -f
-
-# Copyright 2018 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# ARCH is the architecture used for the build. This may be overriden at the
-# command line in order to perform a cross-build (in a limited capacity).
-ARCH := $(shell uname -m)
-
-# Note that the image prefixes used here must match the image mangling in
-# runsc/testutil.MangleImage. Names are mangled in this way to ensure that all
-# tests are using locally-defined images (that are consistent and idempotent).
-REMOTE_IMAGE_PREFIX ?= gcr.io/gvisor-presubmit
-LOCAL_IMAGE_PREFIX ?= gvisor.dev/images
-ALL_IMAGES := $(subst /,_,$(subst ./,,$(shell find . -name Dockerfile -o -name Dockerfile.$(ARCH) | xargs -n 1 dirname | uniq)))
-ifneq ($(ARCH),$(shell uname -m))
-DOCKER_PLATFORM_ARGS := --platform=$(ARCH)
-else
-DOCKER_PLATFORM_ARGS :=
-endif
-
-list-all-images:
- @for image in $(ALL_IMAGES); do echo $${image}; done
-.PHONY: list-build-images
-
-# Handy wrapper to allow load-all-images, push-all-images, etc.
-%-all-images:
- @$(MAKE) -s $(patsubst %,$*-%,$(ALL_IMAGES))
-load-all-images:
- @$(MAKE) -s $(patsubst %,load-%,$(ALL_IMAGES))
-
-# Handy wrapper to load specified "groups", e.g. load-basic-images, etc.
-load-%-images:
- @$(MAKE) -s $(patsubst %,load-%,$(subst /,_,$(subst ./,,$(shell find ./$* -name Dockerfile -exec dirname {} \;))))
-
-# tag is a function that returns the tag name, given an image.
-#
-# The tag constructed is used to memoize the image generated (see README.md).
-# This scheme is used to enable aggressive caching in a central repository, but
-# ensuring that images will always be sourced using the local files if there
-# are changes.
-path = $(subst _,/,$(1))
-dockerfile = $$(if [ -f "$(call path,$(1))/Dockerfile.$(ARCH)" ]; then echo Dockerfile.$(ARCH); else echo Dockerfile; fi)
-tag = $(shell find $(call path,$(1)) -type f -print | sort | xargs -n 1 sha256sum | sha256sum - | cut -c 1-16)
-remote_image = $(REMOTE_IMAGE_PREFIX)/$(subst _,/,$(1))_$(ARCH):$(call tag,$(1))
-local_image = $(LOCAL_IMAGE_PREFIX)/$(subst _,/,$(1))
-
-# rebuild builds the image locally. Only the "remote" tag will be applied. Note
-# we need to explicitly repull the base layer in order to ensure that the
-# architecture is correct. Note that we use the term "rebuild" here to avoid
-# conflicting with the bazel "build" terminology, which is used elsewhere.
-rebuild-%: FROM=$(shell grep FROM "$(call path,$*)/$(call dockerfile,$*)" | cut -d' ' -f2)
-rebuild-%: register-cross
- @if ! [ -f "$(call path,$*)/$(call dockerfile,$*)" ]; then \
- (echo "ERROR: Dockerfile for $* not found (is it available for $(ARCH)?)." >&2 && exit 1); \
- fi
- $(foreach IMAGE,$(FROM),docker pull $(DOCKER_PLATFORM_ARGS) $(IMAGE) &&) \
- T=$$(mktemp -d) && cp -a $(call path,$*)/* $$T && \
- docker build $(DOCKER_PLATFORM_ARGS) \
- -f "$$T/$(call dockerfile,$*)" \
- -t "$(call remote_image,$*)" \
- $$T && \
- rm -rf $$T
-
-# pull will check the "remote" image and pull if necessary. If the remote image
-# must be pulled, then it will tag with the latest local target. Note that pull
-# may fail if the remote image is not available.
-pull-%:
- docker pull $(DOCKER_PLATFORM_ARGS) $(call remote_image,$*)
-
-# load will either pull the "remote" or build it locally. This is the preferred
-# entrypoint, as it should never fail. The local tag should always be set after
-# this returns (either by the pull or the build).
-load-%:
- $(MAKE) -s pull-$* || $(MAKE) -s rebuild-$*
- docker tag $(call remote_image,$*) $(call local_image,$*)
-
-# push pushes the remote image, after either pulling (to validate that the tag
-# already exists) or building manually.
-push-%: load-%
- docker push $(call remote_image,$*)
-
-# register-cross registers the necessary qemu binaries for cross-compilation.
-# This may be used by any target that may execute containers that are not the
-# native format.
-register-cross:
-ifneq ($(ARCH),$(shell uname -m))
-ifeq (,$(wildcard /proc/sys/fs/binfmt_misc/qemu-*))
- docker run --rm --privileged multiarch/qemu-user-static --reset --persistent yes
-else
- @true # Already registered.
-endif
-else
- @true # No cross required.
-endif
-.PHONY: register-cross
diff --git a/images/agent/Dockerfile b/images/agent/Dockerfile
new file mode 100644
index 000000000..1d8979390
--- /dev/null
+++ b/images/agent/Dockerfile
@@ -0,0 +1,12 @@
+FROM golang:1.15 as build-agent
+RUN git clone --depth=1 --branch=v3.25.0 https://github.com/buildkite/agent
+RUN cd agent && go build -i -o /buildkite-agent .
+
+FROM golang:1.15 as build-agent-metrics
+RUN git clone --depth=1 --branch=v5.2.0 https://github.com/buildkite/buildkite-agent-metrics
+RUN cd buildkite-agent-metrics && go build -i -o /buildkite-agent-metrics .
+
+FROM gcr.io/distroless/base-debian10
+COPY --from=build-agent /buildkite-agent /
+COPY --from=build-agent-metrics /buildkite-agent-metrics /
+CMD ["/buildkite-agent"]
diff --git a/images/agent/README.md b/images/agent/README.md
new file mode 100644
index 000000000..acb57bd2f
--- /dev/null
+++ b/images/agent/README.md
@@ -0,0 +1,7 @@
+# Build Agent
+
+This is the image used by the build agent. It is built and bundled via a
+separate packaging mechanism in order to provide local caching and to ensure
+that there is better build provenance. Note that the continuous integration system
+will generally deploy new agents from the primary branch, and will only deploy
+as instances are recycled. Updates to this image should be made carefully.
diff --git a/images/arm-qemu/Dockerfile.x86_64 b/images/arm-qemu/Dockerfile.x86_64
new file mode 100644
index 000000000..1a2ecaf42
--- /dev/null
+++ b/images/arm-qemu/Dockerfile.x86_64
@@ -0,0 +1,12 @@
+FROM fedora:33
+
+RUN dnf install -y qemu-system-aarch64 gzip cpio wget
+
+WORKDIR /workdir
+RUN wget -4 http://dl-cdn.alpinelinux.org/alpine/edge/releases/aarch64/netboot/vmlinuz-lts
+RUN wget -4 http://dl-cdn.alpinelinux.org/alpine/edge/releases/aarch64/netboot/initramfs-lts
+
+COPY initramfs /workdir/initramfs
+COPY test.sh /workdir/
+
+CMD ./test.sh
diff --git a/images/arm-qemu/initramfs/init b/images/arm-qemu/initramfs/init
new file mode 100755
index 000000000..b355daadd
--- /dev/null
+++ b/images/arm-qemu/initramfs/init
@@ -0,0 +1,39 @@
+#!/bin/sh
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script is started as the init process in a test virtual machine,
+# it does all required initialization steps and runs a test command inside a
+# gVisor instance.
+
+set -x -e
+
+/bin/busybox mkdir -p /usr/bin /usr/sbin /proc /sys /dev /tmp
+
+/bin/busybox --install -s
+export PATH=/usr/bin:/bin:/usr/sbin:/sbin
+
+mount -t proc -o noexec,nosuid,nodev proc /proc
+mount -t sysfs -o noexec,nosuid,nodev sysfs /sys
+mount -t devtmpfs -o exec,nosuid,mode=0755,size=2M devtmpfs /dev
+
+uname -a
+/runsc --TESTONLY-unsafe-nonroot --rootless --network none --debug --alsologtostderr do uname -a
+echo "runsc exited with code $?"
+
+# Shut down the VM. poweroff and halt don't work for unknown reasons.
+# qemu is started with the -no-reboot flag, so the VM will be terminated.
+reboot -f
+exit 1
diff --git a/tools/vm/zone.sh b/images/arm-qemu/test.sh
index 79569fb19..2c9336015 100755
--- a/tools/vm/zone.sh
+++ b/images/arm-qemu/test.sh
@@ -14,4 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-exec gcloud config get-value compute/zone
+set -xeuo pipefail -m
+
+cd initramfs
+find . | cpio -v -o -c -R root:root | gzip -9 >> ../initramfs-lts
+cd ..
+
+qemu-system-aarch64 -M virt -m 512M -cpu cortex-a57 \
+ -kernel vmlinuz-lts -initrd initramfs-lts \
+ -append "console=ttyAMA0 panic=-1" -nographic -no-reboot \
+ | tee /dev/stderr | grep "runsc exited with code 0"
+
+echo "PASS"
diff --git a/images/basic/ping4test/Dockerfile b/images/basic/ping4test/Dockerfile
new file mode 100644
index 000000000..1536be376
--- /dev/null
+++ b/images/basic/ping4test/Dockerfile
@@ -0,0 +1,7 @@
+FROM ubuntu:bionic
+
+WORKDIR /root
+COPY ping4.sh .
+RUN chmod +x ping4.sh
+
+RUN apt-get update && apt-get install -y iputils-ping
diff --git a/images/basic/ping4test/ping4.sh b/images/basic/ping4test/ping4.sh
new file mode 100644
index 000000000..2a343712a
--- /dev/null
+++ b/images/basic/ping4test/ping4.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -euo pipefail
+
+# The docker API doesn't provide for starting a container, running a command,
+# and getting the exit status of the command in one go. The most straightforward
+# way to do this is to verify the output of the command, so we output nothing on
+# success and an error message on failure.
+if ! out=$(ping -c 10 127.0.0.1); then
+ echo "$out"
+fi
diff --git a/images/basic/ping6test/Dockerfile b/images/basic/ping6test/Dockerfile
new file mode 100644
index 000000000..cb740bd60
--- /dev/null
+++ b/images/basic/ping6test/Dockerfile
@@ -0,0 +1,7 @@
+FROM ubuntu:bionic
+
+WORKDIR /root
+COPY ping6.sh .
+RUN chmod +x ping6.sh
+
+RUN apt-get update && apt-get install -y iputils-ping iproute2
diff --git a/images/basic/ping6test/ping6.sh b/images/basic/ping6test/ping6.sh
new file mode 100644
index 000000000..4268951d0
--- /dev/null
+++ b/images/basic/ping6test/ping6.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -euo pipefail
+
+# Enable ipv6 on loopback if it's not already enabled. Runsc doesn't enable ipv6
+# loopback unless an ipv6 address was assigned to the container, which docker
+# does not do by default.
+if ! [[ $(ip -6 addr show dev lo) ]]; then
+ ip addr add ::1 dev lo
+fi
+
+# The docker API doesn't provide for starting a container, running a command,
+# and getting the exit status of the command in one go. The most straightforward
+# way to do this is to verify the output of the command, so we output nothing on
+# success and an error message on failure.
+if ! out=$(/bin/ping6 -c 10 ::1); then
+ echo "$out"
+fi
diff --git a/images/benchmarks/absl/Dockerfile b/images/benchmarks/absl/Dockerfile.x86_64
index b0dd97695..810c9ef5e 100644
--- a/images/benchmarks/absl/Dockerfile
+++ b/images/benchmarks/absl/Dockerfile.x86_64
@@ -12,6 +12,7 @@ RUN set -x \
unzip \
python3 \
&& rm -rf /var/lib/apt/lists/*
+
RUN wget https://github.com/bazelbuild/bazel/releases/download/0.27.0/bazel-0.27.0-installer-linux-x86_64.sh
RUN chmod +x bazel-0.27.0-installer-linux-x86_64.sh
RUN ./bazel-0.27.0-installer-linux-x86_64.sh
diff --git a/images/benchmarks/hey/Dockerfile b/images/benchmarks/hey/Dockerfile
index f586978b6..4b6a0f849 100644
--- a/images/benchmarks/hey/Dockerfile
+++ b/images/benchmarks/hey/Dockerfile
@@ -1,12 +1,13 @@
-FROM ubuntu:18.04
+FROM golang:1.15 as build
+RUN go get github.com/rakyll/hey
+WORKDIR /go/src/github.com/rakyll/hey
+RUN go mod download
+RUN CGO_ENABLED=0 go build -o /hey hey.go
+FROM ubuntu:18.04
RUN set -x \
&& apt-get update \
&& apt-get install -y \
wget \
&& rm -rf /var/lib/apt/lists/*
-
-RUN wget https://storage.googleapis.com/hey-release/hey_linux_amd64 \
- && chmod 777 hey_linux_amd64 \
- && cp hey_linux_amd64 /bin/hey \
- && rm hey_linux_amd64
+COPY --from=build /hey /bin/hey
diff --git a/images/benchmarks/runsc/Dockerfile b/images/benchmarks/runsc/Dockerfile.x86_64
index 6c3aafa57..28ae64816 100644
--- a/images/benchmarks/runsc/Dockerfile
+++ b/images/benchmarks/runsc/Dockerfile.x86_64
@@ -14,6 +14,7 @@ RUN set -x \
python3 \
python3-pip \
&& rm -rf /var/lib/apt/lists/*
+
RUN wget https://github.com/bazelbuild/bazel/releases/download/3.4.1/bazel-3.4.1-installer-linux-x86_64.sh
RUN chmod +x bazel-3.4.1-installer-linux-x86_64.sh
RUN ./bazel-3.4.1-installer-linux-x86_64.sh
diff --git a/images/default/Dockerfile b/images/default/Dockerfile
index d058b83cb..19b340237 100644
--- a/images/default/Dockerfile
+++ b/images/default/Dockerfile
@@ -1,16 +1,29 @@
-FROM fedora:31
-# Install bazel.
-RUN dnf install -y dnf-plugins-core && dnf copr enable -y vbatts/bazel
-RUN dnf install -y git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static libstdc++-static patch diffutils
-RUN pip install --no-cache-dir pycparser
-RUN dnf install -y bazel3
+FROM ubuntu:focal
+
+ENV DEBIAN_FRONTEND="noninteractive"
+RUN apt-get update && apt-get install -y curl gnupg2 git \
+ python python3 python3-distutils python3-pip \
+ build-essential crossbuild-essential-arm64 qemu-user-static \
+ openjdk-11-jdk-headless zip unzip \
+ apt-transport-https ca-certificates gnupg-agent \
+ software-properties-common \
+ pkg-config libffi-dev patch diffutils libssl-dev
+
+# Install Docker client for the website build.
+RUN curl -fsSL https://download.docker.com/linux/ubuntu/gpg | apt-key add -
+RUN add-apt-repository \
+ "deb https://download.docker.com/linux/ubuntu \
+ $(lsb_release -cs) \
+ stable"
+RUN apt-get install docker-ce-cli
+
# Install gcloud.
RUN curl https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-sdk-289.0.0-linux-x86_64.tar.gz | \
- tar zxvf - google-cloud-sdk && \
- google-cloud-sdk/install.sh && \
+ tar zxf - google-cloud-sdk && \
+ google-cloud-sdk/install.sh --quiet && \
ln -s /google-cloud-sdk/bin/gcloud /usr/bin/gcloud
-# Install Docker client for the website build.
-RUN dnf config-manager --add-repo https://download.docker.com/linux/fedora/docker-ce.repo
-RUN dnf install -y docker-ce-cli
+
+# Download the official bazel binary. The APT repository isn't used because there are no packages for arm64.
+RUN sh -c 'curl -o /usr/local/bin/bazel https://releases.bazel.build/3.5.1/release/bazel-3.5.1-linux-$(uname -m | sed s/aarch64/arm64/) && chmod ugo+x /usr/local/bin/bazel'
WORKDIR /workspace
-ENTRYPOINT ["/usr/bin/bazel"]
+ENTRYPOINT ["/usr/local/bin/bazel"]
diff --git a/images/runtimes/go1.12/Dockerfile b/images/runtimes/go1.12/Dockerfile.x86_64
index cb2944062..cb2944062 100644
--- a/images/runtimes/go1.12/Dockerfile
+++ b/images/runtimes/go1.12/Dockerfile.x86_64
diff --git a/nogo.yaml b/nogo.yaml
index 7a5edc305..0a5ca78dc 100644
--- a/nogo.yaml
+++ b/nogo.yaml
@@ -56,123 +56,8 @@ global:
- "should not use ALL_CAPS in Go names"
- "should not use underscores in Go names"
exclude:
- # A variety of staticcheck and stylecheck
- # rules apply here. These should be fixed
- # and removed from here, and the global
- # rules should be used sparingly.
- - pkg/abi/linux/fuse.go:22
- - pkg/abi/linux/fuse.go:25
- - pkg/abi/linux/socket.go:113
- - pkg/abi/linux/tty.go:73
- - pkg/cpuid/cpuid_x86.go:675
- - pkg/gohacks/gohacks_unsafe.go:33
- - pkg/log/json.go:30
- - pkg/log/log.go:359
- - pkg/metric/metric_test.go:20
- - pkg/p9/p9test/client_test.go:687
- - pkg/p9/transport_test.go:196
- - pkg/pool/pool.go:15
- - pkg/refs/refcounter.go:510
- - pkg/refs/refcounter_test.go:169
- - pkg/safemem/block_unsafe.go:89
- - pkg/seccomp/seccomp.go:82
- - pkg/segment/test/set_functions.go:15
- - pkg/sentry/arch/signal.go:166
- - pkg/sentry/arch/signal.go:171
- - pkg/sentry/control/pprof.go:196
- - pkg/sentry/devices/memdev/full.go:58
- - pkg/sentry/devices/memdev/null.go:59
- - pkg/sentry/devices/memdev/random.go:68
- - pkg/sentry/devices/memdev/zero.go:86
- - pkg/sentry/fdimport/fdimport.go:15
- - pkg/sentry/fs/attr.go:257
- - pkg/sentry/fsbridge/fs.go:116
- - pkg/sentry/fsbridge/vfs.go:124
- - pkg/sentry/fsbridge/vfs.go:70
- - pkg/sentry/fs/copy_up.go:365
- - pkg/sentry/fs/copy_up_test.go:65
- - pkg/sentry/fs/dev/net_tun.go:161
- - pkg/sentry/fs/dev/net_tun.go:63
- - pkg/sentry/fs/dev/null.go:97
- - pkg/sentry/fs/dirent_cache.go:64
- - pkg/sentry/fs/fdpipe/pipe_opener_test.go:366
- - pkg/sentry/fs/file_overlay.go:327
- - pkg/sentry/fs/file_overlay.go:524
- - pkg/sentry/fs/filetest/filetest.go:55
- - pkg/sentry/fs/filetest/filetest.go:60
- - pkg/sentry/fs/fs.go:77
- - pkg/sentry/fs/fsutil/file.go:290
- - pkg/sentry/fs/fsutil/file.go:346
- - pkg/sentry/fs/fsutil/host_file_mapper.go:105
- - pkg/sentry/fs/fsutil/inode_cached.go:676
- - pkg/sentry/fs/fsutil/inode_cached.go:772
- - pkg/sentry/fs/gofer/attr.go:120
- - pkg/sentry/fs/gofer/fifo.go:33
- - pkg/sentry/fs/gofer/inode.go:410
- - pkg/sentry/fsimpl/ext/disklayout/superblock_64.go:97
- - pkg/sentry/fsimpl/ext/disklayout/superblock_old.go:92
- - pkg/sentry/fsimpl/ext/disklayout/block_group_32.go:44
- - pkg/sentry/fsimpl/ext/disklayout/inode_new.go:91
- - pkg/sentry/fsimpl/ext/disklayout/inode_old.go:93
- - pkg/sentry/fsimpl/ext/disklayout/superblock_32.go:66
- - pkg/sentry/fsimpl/ext/disklayout/block_group_64.go:53
- - pkg/sentry/fsimpl/fuse/request_response.go:71
- - pkg/sentry/fsimpl/signalfd/signalfd.go:15
- - pkg/sentry/memmap/memmap.go:103
- - pkg/sentry/memmap/memmap.go:163
- - pkg/sentry/mm/aio_context.go:208
- - pkg/sentry/mm/pma.go:683
- - pkg/sentry/usage/cpu.go:42
- - pkg/shim/runsc/runsc.go:16
- - pkg/shim/runsc/utils.go:16
- - pkg/shim/v1/proc/deleted_state.go:16
- - pkg/shim/v1/proc/exec.go:16
- - pkg/shim/v1/proc/exec_state.go:16
- - pkg/shim/v1/proc/init.go:16
- - pkg/shim/v1/proc/init_state.go:16
- - pkg/shim/v1/proc/io.go:16
- - pkg/shim/v1/proc/process.go:16
- - pkg/shim/v1/proc/types.go:16
- - pkg/shim/v1/proc/utils.go:16
- - pkg/shim/v1/shim/api.go:16
- - pkg/shim/v1/shim/platform.go:16
- - pkg/shim/v1/shim/service.go:16
- - pkg/shim/v1/utils/annotations.go:15
- - pkg/shim/v1/utils/utils.go:15
- - pkg/shim/v1/utils/volumes.go:15
- - pkg/shim/v2/api.go:16
- - pkg/shim/v2/epoll.go:18
- - pkg/shim/v2/options/options.go:15
- - pkg/shim/v2/options/options.go:24
- - pkg/shim/v2/options/options.go:26
- - pkg/shim/v2/runtimeoptions/runtimeoptions.go:16
- - pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go # Generated: exempt all.
- - pkg/shim/v2/runtimeoptions/runtimeoptions_test.go:22
- - pkg/shim/v2/service.go:15
- - pkg/shim/v2/service_linux.go:18
- - pkg/state/tests/integer_test.go:23
- - pkg/state/tests/integer_test.go:28
- - pkg/sync/rwmutex_test.go:105
- - pkg/syserr/host_linux.go:35
- - pkg/usermem/addr.go:34
- - pkg/usermem/usermem.go:171
- - pkg/usermem/usermem.go:170
- - runsc/boot/compat.go:56
- - test/cmd/test_app/fds.go:171
- - test/iptables/filter_output.go:251
- - test/packetimpact/testbench/connections.go:77
- - tools/bigquery/bigquery.go:106
- - tools/checkescape/test1/test1.go:108
- - tools/checkescape/test1/test1.go:122
- - tools/checkescape/test1/test1.go:137
- - tools/checkescape/test1/test1.go:151
- - tools/checkescape/test1/test1.go:170
- - tools/checkescape/test1/test1.go:39
- - tools/checkescape/test1/test1.go:45
- - tools/checkescape/test1/test1.go:50
- - tools/checkescape/test1/test1.go:64
- - tools/checkescape/test1/test1.go:80
- - tools/checkescape/test1/test1.go:94
+ # Generated: exempt all.
+ - pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go
analyzers:
asmdecl:
external: # Enabled.
@@ -214,6 +99,8 @@ analyzers:
printf:
external: # Enabled.
shift:
+ generated: # Disabled for generated code; these shifts are well-defined.
+ exclude: [".*"]
external: # Enabled.
stringintconv:
external:
@@ -250,3 +137,22 @@ analyzers:
external: # Enabled.
checkescape:
external: # Enabled.
+ SA4016:
+ internal:
+ exclude:
+ - pkg/gohacks/gohacks_unsafe.go # x ^ 0 always equals x.
+ SA2001:
+ internal:
+ exclude:
+ - pkg/sentry/fs/fs.go # Intentional.
+ - pkg/sentry/fs/gofer/inode.go # Intentional.
+ - pkg/refs/refcounter_test.go # Intentional.
+ ST1021:
+ internal:
+ suppress:
+ - "comment on exported type Translation" # Intentional.
+ - "comment on exported type PinnedRange" # Intentional.
+ SA5011:
+ internal:
+ exclude:
+ - pkg/sentry/fs/fdpipe/pipe_opener_test.go # False positive.
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index a0654df2f..8fa61d6f7 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -21,6 +21,7 @@ go_library(
"epoll_amd64.go",
"epoll_arm64.go",
"errors.go",
+ "errqueue.go",
"eventfd.go",
"exec.go",
"fadvise.go",
diff --git a/pkg/abi/linux/errqueue.go b/pkg/abi/linux/errqueue.go
new file mode 100644
index 000000000..3905d4222
--- /dev/null
+++ b/pkg/abi/linux/errqueue.go
@@ -0,0 +1,93 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/marshal"
+)
+
+// Socket error origin codes as defined in include/uapi/linux/errqueue.h.
+const (
+ SO_EE_ORIGIN_NONE = 0
+ SO_EE_ORIGIN_LOCAL = 1
+ SO_EE_ORIGIN_ICMP = 2
+ SO_EE_ORIGIN_ICMP6 = 3
+)
+
+// SockExtendedErr represents struct sock_extended_err in Linux defined in
+// include/uapi/linux/errqueue.h.
+//
+// +marshal
+type SockExtendedErr struct {
+ Errno uint32
+ Origin uint8
+ Type uint8
+ Code uint8
+ Pad uint8
+ Info uint32
+ Data uint32
+}
+
+// SockErrCMsg represents the IP*_RECVERR control message.
+type SockErrCMsg interface {
+ marshal.Marshallable
+
+ CMsgLevel() uint32
+ CMsgType() uint32
+}
+
+// SockErrCMsgIPv4 is the IP_RECVERR control message used in
+// recvmsg(MSG_ERRQUEUE) by ipv4 sockets. This is equivalent to `struct errhdr`
+// defined in net/ipv4/ip_sockglue.c:ip_recv_error().
+//
+// +marshal
+type SockErrCMsgIPv4 struct {
+ SockExtendedErr
+ Offender SockAddrInet
+}
+
+var _ SockErrCMsg = (*SockErrCMsgIPv4)(nil)
+
+// CMsgLevel implements SockErrCMsg.CMsgLevel.
+func (*SockErrCMsgIPv4) CMsgLevel() uint32 {
+ return SOL_IP
+}
+
+// CMsgType implements SockErrCMsg.CMsgType.
+func (*SockErrCMsgIPv4) CMsgType() uint32 {
+ return IP_RECVERR
+}
+
+// SockErrCMsgIPv6 is the IPV6_RECVERR control message used in
+// recvmsg(MSG_ERRQUEUE) by ipv6 sockets. This is equivalent to `struct errhdr`
+// defined in net/ipv6/datagram.c:ipv6_recv_error().
+//
+// +marshal
+type SockErrCMsgIPv6 struct {
+ SockExtendedErr
+ Offender SockAddrInet6
+}
+
+var _ SockErrCMsg = (*SockErrCMsgIPv6)(nil)
+
+// CMsgLevel implements SockErrCMsg.CMsgLevel.
+func (*SockErrCMsgIPv6) CMsgLevel() uint32 {
+ return SOL_IPV6
+}
+
+// CMsgType implements SockErrCMsg.CMsgType.
+func (*SockErrCMsgIPv6) CMsgType() uint32 {
+ return IPV6_RECVERR
+}
diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go
index cc3571fad..d1ca56370 100644
--- a/pkg/abi/linux/fcntl.go
+++ b/pkg/abi/linux/fcntl.go
@@ -25,6 +25,8 @@ const (
F_SETLKW = 7
F_SETOWN = 8
F_GETOWN = 9
+ F_SETSIG = 10
+ F_GETSIG = 11
F_SETOWN_EX = 15
F_GETOWN_EX = 16
F_DUPFD_CLOEXEC = 1024 + 6
diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go
index d91c97a64..1070b457c 100644
--- a/pkg/abi/linux/fuse.go
+++ b/pkg/abi/linux/fuse.go
@@ -19,16 +19,22 @@ import (
"gvisor.dev/gvisor/pkg/marshal/primitive"
)
+// FUSEOpcode is a FUSE operation code.
+//
// +marshal
type FUSEOpcode uint32
+// FUSEOpID is a FUSE operation ID.
+//
// +marshal
type FUSEOpID uint64
// FUSE_ROOT_ID is the id of root inode.
const FUSE_ROOT_ID = 1
-// Opcodes for FUSE operations. Analogous to the opcodes in include/linux/fuse.h.
+// Opcodes for FUSE operations.
+//
+// Analogous to the opcodes in include/linux/fuse.h.
const (
FUSE_LOOKUP FUSEOpcode = 1
FUSE_FORGET = 2 /* no reply */
diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go
index 1b2f76c0b..2424884c1 100644
--- a/pkg/abi/linux/sem.go
+++ b/pkg/abi/linux/sem.go
@@ -32,6 +32,23 @@ const (
SEM_STAT_ANY = 20
)
+// Information about system-wide semaphore limits and parameters.
+//
+// Source: include/uapi/linux/sem.h
+const (
+ SEMMNI = 32000
+ SEMMSL = 32000
+ SEMMNS = SEMMNI * SEMMSL
+ SEMOPM = 500
+ SEMVMX = 32767
+ SEMAEM = SEMVMX
+
+ SEMUME = SEMOPM
+ SEMMNU = SEMMNS
+ SEMMAP = SEMMNS
+ SEMUSZ = 20
+)
+
const SEM_UNDO = 0x1000
// Sembuf is equivalent to struct sembuf.
@@ -42,3 +59,21 @@ type Sembuf struct {
SemOp int16
SemFlg int16
}
+
+// SemInfo is equivalent to struct seminfo.
+//
+// Source: include/uapi/linux/sem.h
+//
+// +marshal
+type SemInfo struct {
+ SemMap uint32
+ SemMni uint32
+ SemMns uint32
+ SemMnu uint32
+ SemMsl uint32
+ SemOpm uint32
+ SemUme uint32
+ SemUsz uint32
+ SemVmx uint32
+ SemAem uint32
+}
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index d156d41e4..8591acbf2 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -111,12 +111,12 @@ type SockType int
// Socket types, from linux/net.h.
const (
SOCK_STREAM SockType = 1
- SOCK_DGRAM = 2
- SOCK_RAW = 3
- SOCK_RDM = 4
- SOCK_SEQPACKET = 5
- SOCK_DCCP = 6
- SOCK_PACKET = 10
+ SOCK_DGRAM SockType = 2
+ SOCK_RAW SockType = 3
+ SOCK_RDM SockType = 4
+ SOCK_SEQPACKET SockType = 5
+ SOCK_DCCP SockType = 6
+ SOCK_PACKET SockType = 10
)
// SOCK_TYPE_MASK covers all of the above socket types. The remaining bits are
@@ -250,6 +250,12 @@ type SockAddrInet struct {
_ [8]uint8 // pad to sizeof(struct sockaddr).
}
+// Inet6MulticastRequest is struct ipv6_mreq, from uapi/linux/in6.h.
+type Inet6MulticastRequest struct {
+ MulticastAddr Inet6Addr
+ InterfaceIndex int32
+}
+
// InetMulticastRequest is struct ip_mreq, from uapi/linux/in.h.
type InetMulticastRequest struct {
MulticastAddr InetAddr
@@ -448,6 +454,8 @@ type ControlMessageCredentials struct {
// A ControlMessageIPPacketInfo is IP_PKTINFO socket control message.
//
// ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h.
+//
+// +stateify savable
type ControlMessageIPPacketInfo struct {
NIC int32
LocalAddr InetAddr
diff --git a/pkg/control/server/server.go b/pkg/control/server/server.go
index 41abe1f2d..629dae8f4 100644
--- a/pkg/control/server/server.go
+++ b/pkg/control/server/server.go
@@ -67,9 +67,10 @@ func (s *Server) Wait() {
// and the server should not be used afterwards.
func (s *Server) Stop() {
s.socket.Close()
- s.wg.Wait()
+ s.Wait()
- // This will cause existing clients to be terminated safely.
+ // This will cause existing clients to be terminated safely. If the
+ // registered handlers have a Stop callback, it will be called.
s.server.Stop()
}
diff --git a/pkg/coverage/coverage.go b/pkg/coverage/coverage.go
index a4f4b2c5e..fdfe31417 100644
--- a/pkg/coverage/coverage.go
+++ b/pkg/coverage/coverage.go
@@ -27,6 +27,7 @@ import (
"io"
"sort"
"sync/atomic"
+ "testing"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
@@ -34,12 +35,6 @@ import (
"github.com/bazelbuild/rules_go/go/tools/coverdata"
)
-// KcovAvailable returns whether the kcov coverage interface is available. It is
-// available as long as coverage is enabled for some files.
-func KcovAvailable() bool {
- return len(coverdata.Cover.Blocks) > 0
-}
-
// coverageMu must be held while accessing coverdata.Cover. This prevents
// concurrent reads/writes from multiple threads collecting coverage data.
var coverageMu sync.RWMutex
@@ -47,6 +42,22 @@ var coverageMu sync.RWMutex
// once ensures that globalData is only initialized once.
var once sync.Once
+// blockBitLength is the number of bits used to represent coverage block index
+// in a synthetic PC (the rest are used to represent the file index). Even
+// though a PC has 64 bits, we only use the lower 32 bits because some users
+// (e.g., syzkaller) may truncate that address to a 32-bit value.
+//
+// As of this writing, there are ~1200 files that can be instrumented and at
+// most ~1200 blocks per file, so 16 bits is more than enough to represent every
+// file and every block.
+const blockBitLength = 16
+
+// KcovAvailable returns whether the kcov coverage interface is available. It is
+// available as long as coverage is enabled for some files.
+func KcovAvailable() bool {
+ return len(coverdata.Cover.Blocks) > 0
+}
+
var globalData struct {
// files is the set of covered files sorted by filename. It is calculated at
// startup.
@@ -104,14 +115,14 @@ var coveragePool = sync.Pool{
// coverage tools, we reset the global coverage data every time this function is
// run.
func ConsumeCoverageData(w io.Writer) int {
- once.Do(initCoverageData)
+ InitCoverageData()
coverageMu.Lock()
defer coverageMu.Unlock()
total := 0
var pcBuffer [8]byte
- for fileIndex, file := range globalData.files {
+ for fileNum, file := range globalData.files {
counters := coverdata.Cover.Counters[file]
for index := 0; index < len(counters); index++ {
if atomic.LoadUint32(&counters[index]) == 0 {
@@ -119,7 +130,7 @@ func ConsumeCoverageData(w io.Writer) int {
}
// Non-zero coverage data found; consume it and report as a PC.
atomic.StoreUint32(&counters[index], 0)
- pc := globalData.syntheticPCs[fileIndex][index]
+ pc := globalData.syntheticPCs[fileNum][index]
usermem.ByteOrder.PutUint64(pcBuffer[:], pc)
n, err := w.Write(pcBuffer[:])
if err != nil {
@@ -142,31 +153,84 @@ func ConsumeCoverageData(w io.Writer) int {
return total
}
-// initCoverageData initializes globalData. It should only be called once,
-// before any kcov data is written.
-func initCoverageData() {
- // First, order all files. Then calculate synthetic PCs for every block
- // (using the well-defined ordering for files as well).
- for file := range coverdata.Cover.Blocks {
- globalData.files = append(globalData.files, file)
+// InitCoverageData initializes globalData. It should be called before any kcov
+// data is written.
+func InitCoverageData() {
+ once.Do(func() {
+ // First, order all files. Then calculate synthetic PCs for every block
+ // (using the well-defined ordering for files as well).
+ for file := range coverdata.Cover.Blocks {
+ globalData.files = append(globalData.files, file)
+ }
+ sort.Strings(globalData.files)
+
+ for fileNum, file := range globalData.files {
+ blocks := coverdata.Cover.Blocks[file]
+ pcs := make([]uint64, 0, len(blocks))
+ for blockNum := range blocks {
+ pcs = append(pcs, calculateSyntheticPC(fileNum, blockNum))
+ }
+ globalData.syntheticPCs = append(globalData.syntheticPCs, pcs)
+ }
+ })
+}
+
+// Symbolize prints information about the block corresponding to pc.
+func Symbolize(out io.Writer, pc uint64) error {
+ fileNum, blockNum := syntheticPCToIndexes(pc)
+ file, err := fileFromIndex(fileNum)
+ if err != nil {
+ return err
+ }
+ block, err := blockFromIndex(file, blockNum)
+ if err != nil {
+ return err
}
- sort.Strings(globalData.files)
-
- // nextSyntheticPC is the first PC that we generate for a block.
- //
- // This uses a standard-looking kernel range for simplicity.
- //
- // FIXME(b/160639712): This is only necessary because syzkaller requires
- // addresses in the kernel range. If we can remove this constraint, then we
- // should be able to use the actual addresses.
- var nextSyntheticPC uint64 = 0xffffffff80000000
- for _, file := range globalData.files {
- blocks := coverdata.Cover.Blocks[file]
- thisFile := make([]uint64, 0, len(blocks))
- for range blocks {
- thisFile = append(thisFile, nextSyntheticPC)
- nextSyntheticPC++ // Advance.
+ writeBlock(out, pc, file, block)
+ return nil
+}
+
+// WriteAllBlocks prints all information about all blocks along with their
+// corresponding synthetic PCs.
+func WriteAllBlocks(out io.Writer) {
+ for fileNum, file := range globalData.files {
+ for blockNum, block := range coverdata.Cover.Blocks[file] {
+ writeBlock(out, calculateSyntheticPC(fileNum, blockNum), file, block)
}
- globalData.syntheticPCs = append(globalData.syntheticPCs, thisFile)
}
}
+
+func calculateSyntheticPC(fileNum int, blockNum int) uint64 {
+ return (uint64(fileNum) << blockBitLength) + uint64(blockNum)
+}
+
+func syntheticPCToIndexes(pc uint64) (fileNum int, blockNum int) {
+ return int(pc >> blockBitLength), int(pc & ((1 << blockBitLength) - 1))
+}
+
+// fileFromIndex returns the name of the file in the sorted list of instrumented files.
+func fileFromIndex(i int) (string, error) {
+ total := len(globalData.files)
+ if i < 0 || i >= total {
+ return "", fmt.Errorf("file index out of range: [%d] with length %d", i, total)
+ }
+ return globalData.files[i], nil
+}
+
+// blockFromIndex returns the i-th block in the given file.
+func blockFromIndex(file string, i int) (testing.CoverBlock, error) {
+ blocks, ok := coverdata.Cover.Blocks[file]
+ if !ok {
+ return testing.CoverBlock{}, fmt.Errorf("instrumented file %s does not exist", file)
+ }
+ total := len(blocks)
+ if i < 0 || i >= total {
+ return testing.CoverBlock{}, fmt.Errorf("block index out of range: [%d] with length %d", i, total)
+ }
+ return blocks[i], nil
+}
+
+func writeBlock(out io.Writer, pc uint64, file string, block testing.CoverBlock) {
+ io.WriteString(out, fmt.Sprintf("%#x\n", pc))
+ io.WriteString(out, fmt.Sprintf("%s:%d.%d,%d.%d\n", file, block.Line0, block.Col0, block.Line1, block.Col1))
+}
diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go
index f7f9dbf86..69eeb7528 100644
--- a/pkg/cpuid/cpuid.go
+++ b/pkg/cpuid/cpuid.go
@@ -36,3 +36,14 @@ package cpuid
// On arm64, features are numbered according to the ELF HWCAP definition.
// arch/arm64/include/uapi/asm/hwcap.h
type Feature int
+
+// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a
+// subset of the host feature set.
+type ErrIncompatible struct {
+ message string
+}
+
+// Error implements error.
+func (e ErrIncompatible) Error() string {
+ return e.message
+}
diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go
index 17a89c00d..392711e8f 100644
--- a/pkg/cpuid/cpuid_x86.go
+++ b/pkg/cpuid/cpuid_x86.go
@@ -681,17 +681,6 @@ func (fs *FeatureSet) Intel() bool {
return fs.VendorID == intelVendorID
}
-// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a
-// subset of the host feature set.
-type ErrIncompatible struct {
- message string
-}
-
-// Error implements error.
-func (e ErrIncompatible) Error() string {
- return e.message
-}
-
// CheckHostCompatible returns nil if fs is a subset of the host feature set.
func (fs *FeatureSet) CheckHostCompatible() error {
hfs := HostFeatureSet()
diff --git a/pkg/crypto/BUILD b/pkg/crypto/BUILD
new file mode 100644
index 000000000..08fa772ca
--- /dev/null
+++ b/pkg/crypto/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "crypto",
+ srcs = [
+ "crypto.go",
+ "crypto_stdlib.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/sleep/empty.s b/pkg/crypto/crypto.go
index fb37360ac..b26b55d37 100644
--- a/pkg/sleep/empty.s
+++ b/pkg/crypto/crypto.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,4 +12,5 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Empty assembly file so empty func definitions work.
+// Package crypto wraps crypto primitives.
+package crypto
diff --git a/pkg/crypto/crypto_stdlib.go b/pkg/crypto/crypto_stdlib.go
new file mode 100644
index 000000000..74a55a123
--- /dev/null
+++ b/pkg/crypto/crypto_stdlib.go
@@ -0,0 +1,32 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package crypto
+
+import (
+ "crypto/ecdsa"
+ "crypto/sha512"
+ "math/big"
+)
+
+// EcdsaVerify verifies the signature in r, s of hash using ECDSA and the
+// public key, pub. Its return value records whether the signature is valid.
+func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) bool {
+ return ecdsa.Verify(pub, hash, r, s)
+}
+
+// SumSha384 returns the SHA384 checksum of the data.
+func SumSha384(data []byte) (sum384 [sha512.Size384]byte) {
+ return sha512.Sum384(data)
+}
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD
index aa8e4e1f3..cc31d0175 100644
--- a/pkg/flipcall/BUILD
+++ b/pkg/flipcall/BUILD
@@ -11,7 +11,8 @@ go_library(
"futex_linux.go",
"io.go",
"packet_window_allocator.go",
- "packet_window_mmap.go",
+ "packet_window_mmap_amd64.go",
+ "packet_window_mmap_arm64.go",
],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/flipcall/packet_window_mmap.go b/pkg/flipcall/packet_window_mmap_amd64.go
index 869183b11..869183b11 100644
--- a/pkg/flipcall/packet_window_mmap.go
+++ b/pkg/flipcall/packet_window_mmap_amd64.go
diff --git a/pkg/syncevent/waiter_asm_unsafe.go b/pkg/flipcall/packet_window_mmap_arm64.go
index 19d6b0b15..b9c9c44f6 100644
--- a/pkg/syncevent/waiter_asm_unsafe.go
+++ b/pkg/flipcall/packet_window_mmap_arm64.go
@@ -12,13 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64 arm64
+// +build arm64
-package syncevent
+package flipcall
import (
- "unsafe"
+ "syscall"
)
-// See waiter_noasm_unsafe.go for a description of waiterUnlock.
-func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool
+// packetWindowMmap returns a memory mapping of the pwd that can be shared outside the sandbox.
+func packetWindowMmap(pwd PacketWindowDescriptor) (uintptr, syscall.Errno) {
+ m, _, err := syscall.RawSyscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset))
+ return m, err
+}
diff --git a/pkg/log/json.go b/pkg/log/json.go
index bdf9d691e..8c52dcc87 100644
--- a/pkg/log/json.go
+++ b/pkg/log/json.go
@@ -27,8 +27,8 @@ type jsonLog struct {
}
// MarshalJSON implements json.Marshaler.MarshalJSON.
-func (lv Level) MarshalJSON() ([]byte, error) {
- switch lv {
+func (l Level) MarshalJSON() ([]byte, error) {
+ switch l {
case Warning:
return []byte(`"warning"`), nil
case Info:
@@ -36,20 +36,20 @@ func (lv Level) MarshalJSON() ([]byte, error) {
case Debug:
return []byte(`"debug"`), nil
default:
- return nil, fmt.Errorf("unknown level %v", lv)
+ return nil, fmt.Errorf("unknown level %v", l)
}
}
// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON. It can unmarshal
// from both string names and integers.
-func (lv *Level) UnmarshalJSON(b []byte) error {
+func (l *Level) UnmarshalJSON(b []byte) error {
switch s := string(b); s {
case "0", `"warning"`:
- *lv = Warning
+ *l = Warning
case "1", `"info"`:
- *lv = Info
+ *l = Info
case "2", `"debug"`:
- *lv = Debug
+ *l = Debug
default:
return fmt.Errorf("unknown level %q", s)
}
diff --git a/pkg/log/log.go b/pkg/log/log.go
index 37e0605ad..2e3408357 100644
--- a/pkg/log/log.go
+++ b/pkg/log/log.go
@@ -356,7 +356,7 @@ func CopyStandardLogTo(l Level) error {
case Warning:
f = Warningf
default:
- return fmt.Errorf("Unknown log level %v", l)
+ return fmt.Errorf("unknown log level %v", l)
}
stdlog.SetOutput(linewriter.NewWriter(func(p []byte) {
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index 6acee90ef..aea7dde38 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -350,9 +350,13 @@ type VerifyParams struct {
// For verifyMetadata, params.data is not needed. It only accesses params.tree
// for the raw root hash.
func verifyMetadata(params *VerifyParams, layout *Layout) error {
- root := make([]byte, layout.digestSize)
- if _, err := params.Tree.ReadAt(root, layout.blockOffset(layout.rootLevel(), 0 /* index */)); err != nil {
- return fmt.Errorf("failed to read root hash: %w", err)
+ var root []byte
+ // Only read the root hash if we expect that the Merkle tree file is non-empty.
+ if params.Size != 0 {
+ root = make([]byte, layout.digestSize)
+ if _, err := params.Tree.ReadAt(root, layout.blockOffset(layout.rootLevel(), 0 /* index */)); err != nil {
+ return fmt.Errorf("failed to read root hash: %w", err)
+ }
}
descriptor := VerityDescriptor{
Name: params.Name,
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index abd237f46..81ceb37c5 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -296,25 +296,6 @@ func (t *Tlopen) handle(cs *connState) message {
}
defer ref.DecRef()
- ref.openedMu.Lock()
- defer ref.openedMu.Unlock()
-
- // Has it been opened already?
- if ref.opened || !CanOpen(ref.mode) {
- return newErr(syscall.EINVAL)
- }
-
- if ref.mode.IsDir() {
- // Directory must be opened ReadOnly.
- if t.Flags&OpenFlagsModeMask != ReadOnly {
- return newErr(syscall.EISDIR)
- }
- // Directory not truncatable.
- if t.Flags&OpenTruncate != 0 {
- return newErr(syscall.EISDIR)
- }
- }
-
var (
qid QID
ioUnit uint32
@@ -326,6 +307,22 @@ func (t *Tlopen) handle(cs *connState) message {
return syscall.EINVAL
}
+ // Has it been opened already?
+ if ref.opened || !CanOpen(ref.mode) {
+ return syscall.EINVAL
+ }
+
+ if ref.mode.IsDir() {
+ // Directory must be opened ReadOnly.
+ if t.Flags&OpenFlagsModeMask != ReadOnly {
+ return syscall.EISDIR
+ }
+ // Directory not truncatable.
+ if t.Flags&OpenTruncate != 0 {
+ return syscall.EISDIR
+ }
+ }
+
osFile, qid, ioUnit, err = ref.file.Open(t.Flags)
return err
}); err != nil {
@@ -366,7 +363,7 @@ func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -437,7 +434,7 @@ func (t *Tsymlink) do(cs *connState, uid UID) (*Rsymlink, error) {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -476,7 +473,7 @@ func (t *Tlink) handle(cs *connState) message {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -518,7 +515,7 @@ func (t *Trenameat) handle(cs *connState) message {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -561,7 +558,7 @@ func (t *Tunlinkat) handle(cs *connState) message {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -701,13 +698,12 @@ func (t *Tread) handle(cs *connState) message {
)
if err := ref.safelyRead(func() (err error) {
// Has it been opened already?
- openFlags, opened := ref.OpenFlags()
- if !opened {
+ if !ref.opened {
return syscall.EINVAL
}
// Can it be read? Check permissions.
- if openFlags&OpenFlagsModeMask == WriteOnly {
+ if ref.openFlags&OpenFlagsModeMask == WriteOnly {
return syscall.EPERM
}
@@ -731,13 +727,12 @@ func (t *Twrite) handle(cs *connState) message {
var n int
if err := ref.safelyRead(func() (err error) {
// Has it been opened already?
- openFlags, opened := ref.OpenFlags()
- if !opened {
+ if !ref.opened {
return syscall.EINVAL
}
// Can it be written? Check permissions.
- if openFlags&OpenFlagsModeMask == ReadOnly {
+ if ref.openFlags&OpenFlagsModeMask == ReadOnly {
return syscall.EPERM
}
@@ -778,7 +773,7 @@ func (t *Tmknod) do(cs *connState, uid UID) (*Rmknod, error) {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -820,7 +815,7 @@ func (t *Tmkdir) do(cs *connState, uid UID) (*Rmkdir, error) {
}
// Not allowed on open directories.
- if _, opened := ref.OpenFlags(); opened {
+ if ref.opened {
return syscall.EINVAL
}
@@ -898,13 +893,12 @@ func (t *Tallocate) handle(cs *connState) message {
if err := ref.safelyWrite(func() error {
// Has it been opened already?
- openFlags, opened := ref.OpenFlags()
- if !opened {
+ if !ref.opened {
return syscall.EINVAL
}
// Can it be written? Check permissions.
- if openFlags&OpenFlagsModeMask == ReadOnly {
+ if ref.openFlags&OpenFlagsModeMask == ReadOnly {
return syscall.EBADF
}
@@ -1049,8 +1043,8 @@ func (t *Treaddir) handle(cs *connState) message {
return syscall.EINVAL
}
- // Has it been opened already?
- if _, opened := ref.OpenFlags(); !opened {
+ // Has it been opened yet?
+ if !ref.opened {
return syscall.EINVAL
}
@@ -1076,8 +1070,8 @@ func (t *Tfsync) handle(cs *connState) message {
defer ref.DecRef()
if err := ref.safelyRead(func() (err error) {
- // Has it been opened already?
- if _, opened := ref.OpenFlags(); !opened {
+ // Has it been opened yet?
+ if !ref.opened {
return syscall.EINVAL
}
@@ -1185,8 +1179,13 @@ func doWalk(cs *connState, ref *fidRef, names []string, getattr bool) (qids []QI
}
// Has it been opened already?
- if _, opened := ref.OpenFlags(); opened {
- err = syscall.EBUSY
+ err = ref.safelyRead(func() (err error) {
+ if ref.opened {
+ return syscall.EBUSY
+ }
+ return nil
+ })
+ if err != nil {
return
}
diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go
index 6e605b14c..2e3d427ae 100644
--- a/pkg/p9/p9test/client_test.go
+++ b/pkg/p9/p9test/client_test.go
@@ -678,16 +678,15 @@ func renameHelper(h *Harness, root p9.File, srcNames []string, dstNames []string
// case.
defer checkDeleted(h, dst)
} else {
+ // If the type is different than the destination, then
+ // we expect the rename to fail. We expect that this
+ // is returned.
+ //
+ // If the file is being renamed to itself, this is
+ // technically allowed and a no-op, but all the
+ // triggers will fire.
if !selfRename {
- // If the type is different than the
- // destination, then we expect the rename to
- // fail. We expect ensure that this is
- // returned.
expectedErr = syscall.EINVAL
- } else {
- // This is the file being renamed to itself.
- // This is technically allowed and a no-op, but
- // all the triggers will fire.
}
dst.Close()
}
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
index 3736f12a3..8c5c434fd 100644
--- a/pkg/p9/server.go
+++ b/pkg/p9/server.go
@@ -134,12 +134,11 @@ type fidRef struct {
// The node above will be closed only when refs reaches zero.
refs int64
- // openedMu protects opened and openFlags.
- openedMu sync.Mutex
-
// opened indicates whether this has been opened already.
//
// This is updated in handlers.go.
+ //
+ // opened is protected by pathNode.opMu or renameMu (for write).
opened bool
// mode is the fidRef's mode from the walk. Only the type bits are
@@ -151,6 +150,8 @@ type fidRef struct {
// openFlags is the mode used in the open.
//
// This is updated in handlers.go.
+ //
+ // openFlags is protected by pathNode.opMu or renameMu (for write).
openFlags OpenFlags
// pathNode is the current pathNode for this FID.
@@ -177,13 +178,6 @@ type fidRef struct {
deleted uint32
}
-// OpenFlags returns the flags the file was opened with and true iff the fid was opened previously.
-func (f *fidRef) OpenFlags() (OpenFlags, bool) {
- f.openedMu.Lock()
- defer f.openedMu.Unlock()
- return f.openFlags, f.opened
-}
-
// IncRef increases the references on a fid.
func (f *fidRef) IncRef() {
atomic.AddInt64(&f.refs, 1)
diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go
index e7406b374..a29f06ddb 100644
--- a/pkg/p9/transport_test.go
+++ b/pkg/p9/transport_test.go
@@ -197,33 +197,33 @@ func BenchmarkSendRecv(b *testing.B) {
for i := 0; i < b.N; i++ {
tag, m, err := recv(server, maximumLength, msgRegistry.get)
if err != nil {
- b.Fatalf("recv got err %v expected nil", err)
+ b.Errorf("recv got err %v expected nil", err)
}
if tag != Tag(1) {
- b.Fatalf("got tag %v expected 1", tag)
+ b.Errorf("got tag %v expected 1", tag)
}
if _, ok := m.(*Rflush); !ok {
- b.Fatalf("got message %T expected *Rflush", m)
+ b.Errorf("got message %T expected *Rflush", m)
}
if err := send(server, Tag(2), &Rflush{}); err != nil {
- b.Fatalf("send got err %v expected nil", err)
+ b.Errorf("send got err %v expected nil", err)
}
}
}()
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := send(client, Tag(1), &Rflush{}); err != nil {
- b.Fatalf("send got err %v expected nil", err)
+ b.Errorf("send got err %v expected nil", err)
}
tag, m, err := recv(client, maximumLength, msgRegistry.get)
if err != nil {
- b.Fatalf("recv got err %v expected nil", err)
+ b.Errorf("recv got err %v expected nil", err)
}
if tag != Tag(2) {
- b.Fatalf("got tag %v expected 2", tag)
+ b.Errorf("got tag %v expected 2", tag)
}
if _, ok := m.(*Rflush); !ok {
- b.Fatalf("got message %v expected *Rflush", m)
+ b.Errorf("got message %v expected *Rflush", m)
}
}
}
diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go
index a1b2e0cfe..54e825b28 100644
--- a/pkg/pool/pool.go
+++ b/pkg/pool/pool.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package pool provides a trivial integer pool.
package pool
import (
diff --git a/pkg/refsvfs2/BUILD b/pkg/refsvfs2/BUILD
index bfa1daa10..0377c0876 100644
--- a/pkg/refsvfs2/BUILD
+++ b/pkg/refsvfs2/BUILD
@@ -9,7 +9,7 @@ go_template(
"refs_template.go",
],
opt_consts = [
- "logTrace",
+ "enableLogging",
],
types = [
"T",
diff --git a/pkg/refsvfs2/refs_template.go b/pkg/refsvfs2/refs_template.go
index f64b6c6ae..3fbc91aa5 100644
--- a/pkg/refsvfs2/refs_template.go
+++ b/pkg/refsvfs2/refs_template.go
@@ -74,11 +74,6 @@ func (r *Refs) LogRefs() bool {
return enableLogging
}
-// EnableLeakCheck enables reference leak checking on r.
-func (r *Refs) EnableLeakCheck() {
- refsvfs2.Register(r)
-}
-
// ReadRefs returns the current number of references. The returned count is
// inherently racy and is unsafe to use without external synchronization.
func (r *Refs) ReadRefs() int64 {
@@ -136,7 +131,7 @@ func (r *Refs) TryIncRef() bool {
func (r *Refs) DecRef(destroy func()) {
v := atomic.AddInt64(&r.refCount, -1)
if enableLogging {
- refsvfs2.LogDecRef(r, v+1)
+ refsvfs2.LogDecRef(r, v)
}
switch {
case v < 0:
@@ -153,6 +148,6 @@ func (r *Refs) DecRef(destroy func()) {
func (r *Refs) afterLoad() {
if r.ReadRefs() > 0 {
- r.EnableLeakCheck()
+ refsvfs2.Register(r)
}
}
diff --git a/pkg/safemem/block_unsafe.go b/pkg/safemem/block_unsafe.go
index e7fd30743..7857f5853 100644
--- a/pkg/safemem/block_unsafe.go
+++ b/pkg/safemem/block_unsafe.go
@@ -68,29 +68,29 @@ func blockFromSlice(slice []byte, needSafecopy bool) Block {
}
}
-// BlockFromSafePointer returns a Block equivalent to [ptr, ptr+len), which is
+// BlockFromSafePointer returns a Block equivalent to [ptr, ptr+length), which is
// safe to access without safecopy.
//
-// Preconditions: ptr+len does not overflow.
-func BlockFromSafePointer(ptr unsafe.Pointer, len int) Block {
- return blockFromPointer(ptr, len, false)
+// Preconditions: ptr+length does not overflow.
+func BlockFromSafePointer(ptr unsafe.Pointer, length int) Block {
+ return blockFromPointer(ptr, length, false)
}
// BlockFromUnsafePointer returns a Block equivalent to [ptr, ptr+length), which
// is not safe to access without safecopy.
//
// Preconditions: ptr+length does not overflow.
-func BlockFromUnsafePointer(ptr unsafe.Pointer, len int) Block {
- return blockFromPointer(ptr, len, true)
+func BlockFromUnsafePointer(ptr unsafe.Pointer, length int) Block {
+ return blockFromPointer(ptr, length, true)
}
-func blockFromPointer(ptr unsafe.Pointer, len int, needSafecopy bool) Block {
- if uptr := uintptr(ptr); uptr+uintptr(len) < uptr {
- panic(fmt.Sprintf("ptr %#x + len %#x overflows", ptr, len))
+func blockFromPointer(ptr unsafe.Pointer, length int, needSafecopy bool) Block {
+ if uptr := uintptr(ptr); uptr+uintptr(length) < uptr {
+ panic(fmt.Sprintf("ptr %#x + len %#x overflows", uptr, length))
}
return Block{
start: ptr,
- length: len,
+ length: length,
needSafecopy: needSafecopy,
}
}
diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go
index 752e2dc32..ec17ebc4d 100644
--- a/pkg/seccomp/seccomp.go
+++ b/pkg/seccomp/seccomp.go
@@ -79,7 +79,7 @@ func Install(rules SyscallRules) error {
// Perform the actual installation.
if errno := SetFilter(instrs); errno != 0 {
- return fmt.Errorf("Failed to set filter: %v", errno)
+ return fmt.Errorf("failed to set filter: %v", errno)
}
log.Infof("Seccomp filters installed.")
diff --git a/pkg/segment/test/set_functions.go b/pkg/segment/test/set_functions.go
index 7cd895cc7..652c010da 100644
--- a/pkg/segment/test/set_functions.go
+++ b/pkg/segment/test/set_functions.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package segment is a test package.
package segment
type setFunctions struct{}
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go
index d75d665ae..dd2effdf9 100644
--- a/pkg/sentry/arch/arch.go
+++ b/pkg/sentry/arch/arch.go
@@ -365,3 +365,18 @@ func (a SyscallArgument) SizeT() uint {
func (a SyscallArgument) ModeT() uint {
return uint(uint16(a.Value))
}
+
+// ErrFloatingPoint indicates a failed restore due to unusable floating point
+// state.
+type ErrFloatingPoint struct {
+ // supported is the supported floating point state.
+ supported uint64
+
+ // saved is the saved floating point state.
+ saved uint64
+}
+
+// Error returns a sensible description of the restore error.
+func (e ErrFloatingPoint) Error() string {
+ return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved)
+}
diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go
index 19ce99d25..840e53d33 100644
--- a/pkg/sentry/arch/arch_state_x86.go
+++ b/pkg/sentry/arch/arch_state_x86.go
@@ -17,27 +17,10 @@
package arch
import (
- "fmt"
-
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/usermem"
)
-// ErrFloatingPoint indicates a failed restore due to unusable floating point
-// state.
-type ErrFloatingPoint struct {
- // supported is the supported floating point state.
- supported uint64
-
- // saved is the saved floating point state.
- saved uint64
-}
-
-// Error returns a sensible description of the restore error.
-func (e ErrFloatingPoint) Error() string {
- return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved)
-}
-
// XSTATE_BV does not exist if FXSAVE is used, but FXSAVE implicitly saves x87
// and SSE state, so this is the equivalent XSTATE_BV value.
const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE
diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go
index c9fb55d00..35d2e07c3 100644
--- a/pkg/sentry/arch/signal.go
+++ b/pkg/sentry/arch/signal.go
@@ -152,23 +152,23 @@ func (s *SignalInfo) FixSignalCodeForUser() {
}
}
-// Pid returns the si_pid field.
-func (s *SignalInfo) Pid() int32 {
+// PID returns the si_pid field.
+func (s *SignalInfo) PID() int32 {
return int32(usermem.ByteOrder.Uint32(s.Fields[0:4]))
}
-// SetPid mutates the si_pid field.
-func (s *SignalInfo) SetPid(val int32) {
+// SetPID mutates the si_pid field.
+func (s *SignalInfo) SetPID(val int32) {
usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
}
-// Uid returns the si_uid field.
-func (s *SignalInfo) Uid() int32 {
+// UID returns the si_uid field.
+func (s *SignalInfo) UID() int32 {
return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
}
-// SetUid mutates the si_uid field.
-func (s *SignalInfo) SetUid(val int32) {
+// SetUID mutates the si_uid field.
+func (s *SignalInfo) SetUID(val int32) {
usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
}
@@ -251,3 +251,26 @@ func (s *SignalInfo) Arch() uint32 {
func (s *SignalInfo) SetArch(val uint32) {
usermem.ByteOrder.PutUint32(s.Fields[12:16], val)
}
+
+// Band returns the si_band field.
+func (s *SignalInfo) Band() int64 {
+ return int64(usermem.ByteOrder.Uint64(s.Fields[0:8]))
+}
+
+// SetBand mutates the si_band field.
+func (s *SignalInfo) SetBand(val int64) {
+ // Note: this assumes the platform uses `long` as `__ARCH_SI_BAND_T`.
+ // On some platforms, which gVisor doesn't support, `__ARCH_SI_BAND_T` is
+ // `int`. See siginfo.h.
+ usermem.ByteOrder.PutUint64(s.Fields[0:8], uint64(val))
+}
+
+// FD returns the si_fd field.
+func (s *SignalInfo) FD() uint32 {
+ return usermem.ByteOrder.Uint32(s.Fields[8:12])
+}
+
+// SetFD mutates the si_fd field.
+func (s *SignalInfo) SetFD(val uint32) {
+ usermem.ByteOrder.PutUint32(s.Fields[8:12], val)
+}
diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go
index 2bf3c45e1..2f3664c57 100644
--- a/pkg/sentry/control/pprof.go
+++ b/pkg/sentry/control/pprof.go
@@ -15,10 +15,10 @@
package control
import (
- "errors"
"runtime"
"runtime/pprof"
"runtime/trace"
+ "time"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -26,184 +26,263 @@ import (
"gvisor.dev/gvisor/pkg/urpc"
)
-var errNoOutput = errors.New("no output writer provided")
+// Profile includes profile-related RPC stubs. It provides a way to
+// control the built-in runtime profiling facilities.
+//
+// The profile object must be instantiated via NewProfile.
+type Profile struct {
+ // kernel is the kernel under profile. It's immutable.
+ kernel *kernel.Kernel
-// ProfileOpts contains options for the StartCPUProfile/Goroutine RPC call.
-type ProfileOpts struct {
- // File is the filesystem path for the profile.
- File string `json:"path"`
+ // cpuMu protects CPU profiling.
+ cpuMu sync.Mutex
- // FilePayload is the destination for the profiling output.
- urpc.FilePayload
+ // blockMu protects block profiling.
+ blockMu sync.Mutex
+
+ // mutexMu protects mutex profiling.
+ mutexMu sync.Mutex
+
+ // traceMu protects trace profiling.
+ traceMu sync.Mutex
+
+ // done is closed when profiling is done.
+ done chan struct{}
}
-// Profile includes profile-related RPC stubs. It provides a way to
-// control the built-in pprof facility in sentry via sentryctl.
-//
-// The following options to sentryctl are added:
-//
-// - collect CPU profile on-demand.
-// sentryctl -pid <pid> pprof-cpu-start
-// sentryctl -pid <pid> pprof-cpu-stop
-//
-// - dump out the stack trace of current go routines.
-// sentryctl -pid <pid> pprof-goroutine
-type Profile struct {
- // Kernel is the kernel under profile. It's immutable.
- Kernel *kernel.Kernel
+// NewProfile returns a new Profile object.
+func NewProfile(k *kernel.Kernel) *Profile {
+ return &Profile{
+ kernel: k,
+ done: make(chan struct{}),
+ }
+}
- // mu protects the fields below.
- mu sync.Mutex
+// Stop implements urpc.Stopper.Stop.
+func (p *Profile) Stop() {
+ close(p.done)
+}
- // cpuFile is the current CPU profile output file.
- cpuFile *fd.FD
+// CPUProfileOpts contains options specifically for CPU profiles.
+type CPUProfileOpts struct {
+ // FilePayload is the destination for the profiling output.
+ urpc.FilePayload
- // traceFile is the current execution trace output file.
- traceFile *fd.FD
+ // Duration is the duration of the profile.
+ Duration time.Duration `json:"duration"`
}
-// StartCPUProfile is an RPC stub which starts recording the CPU profile in a
-// file.
-func (p *Profile) StartCPUProfile(o *ProfileOpts, _ *struct{}) error {
+// CPU is an RPC stub which collects a CPU profile.
+func (p *Profile) CPU(o *CPUProfileOpts, _ *struct{}) error {
if len(o.FilePayload.Files) < 1 {
- return errNoOutput
+ return nil // Allowed.
}
- output, err := fd.NewFromFile(o.FilePayload.Files[0])
- if err != nil {
- return err
- }
+ output := o.FilePayload.Files[0]
+ defer output.Close()
- p.mu.Lock()
- defer p.mu.Unlock()
+ p.cpuMu.Lock()
+ defer p.cpuMu.Unlock()
// Returns an error if profiling is already started.
if err := pprof.StartCPUProfile(output); err != nil {
- output.Close()
return err
}
+ defer pprof.StopCPUProfile()
+
+ // Collect the profile.
+ select {
+ case <-time.After(o.Duration):
+ case <-p.done:
+ }
- p.cpuFile = output
return nil
}
-// StopCPUProfile is an RPC stub which stops the CPU profiling and flush out the
-// profile data. It takes no argument.
-func (p *Profile) StopCPUProfile(_, _ *struct{}) error {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- if p.cpuFile == nil {
- return errors.New("CPU profiling not started")
- }
+// HeapProfileOpts contains options specifically for heap profiles.
+type HeapProfileOpts struct {
+ // FilePayload is the destination for the profiling output.
+ urpc.FilePayload
- pprof.StopCPUProfile()
- p.cpuFile.Close()
- p.cpuFile = nil
- return nil
+ // Delay is the sleep time, similar to Duration. This may
+ // not affect the data collected however, as the heap profile
+ // will contain only the memory associated with the last alloc.
+ Delay time.Duration `json:"delay"`
}
-// HeapProfile generates a heap profile for the sentry.
-func (p *Profile) HeapProfile(o *ProfileOpts, _ *struct{}) error {
+// Heap generates a heap profile.
+func (p *Profile) Heap(o *HeapProfileOpts, _ *struct{}) error {
if len(o.FilePayload.Files) < 1 {
- return errNoOutput
+ return nil // Allowed.
}
+
output := o.FilePayload.Files[0]
defer output.Close()
- runtime.GC() // Get up-to-date statistics.
- if err := pprof.WriteHeapProfile(output); err != nil {
- return err
+
+ // Wait for the given delay.
+ select {
+ case <-time.After(o.Delay):
+ case <-p.done:
}
- return nil
+
+ // Get up-to-date statistics.
+ runtime.GC()
+
+ // Write the given profile.
+ return pprof.WriteHeapProfile(output)
+}
+
+// GoroutineProfileOpts contains options specifically for goroutine profiles.
+type GoroutineProfileOpts struct {
+ // FilePayload is the destination for the profiling output.
+ urpc.FilePayload
}
-// GoroutineProfile is an RPC stub which dumps out the stack trace for all
-// running goroutines.
-func (p *Profile) GoroutineProfile(o *ProfileOpts, _ *struct{}) error {
+// Goroutine dumps out the stack trace for all running goroutines.
+func (p *Profile) Goroutine(o *GoroutineProfileOpts, _ *struct{}) error {
if len(o.FilePayload.Files) < 1 {
- return errNoOutput
+ return nil // Allowed.
}
+
output := o.FilePayload.Files[0]
defer output.Close()
- if err := pprof.Lookup("goroutine").WriteTo(output, 2); err != nil {
- return err
- }
- return nil
+
+ return pprof.Lookup("goroutine").WriteTo(output, 2)
+}
+
+// BlockProfileOpts contains options specifically for block profiles.
+type BlockProfileOpts struct {
+ // FilePayload is the destination for the profiling output.
+ urpc.FilePayload
+
+ // Duration is the duration of the profile.
+ Duration time.Duration `json:"duration"`
+
+ // Rate is the block profile rate.
+ Rate int `json:"rate"`
}
-// BlockProfile is an RPC stub which dumps out the stack trace that led to
-// blocking on synchronization primitives.
-func (p *Profile) BlockProfile(o *ProfileOpts, _ *struct{}) error {
+// Block dumps a blocking profile.
+func (p *Profile) Block(o *BlockProfileOpts, _ *struct{}) error {
if len(o.FilePayload.Files) < 1 {
- return errNoOutput
+ return nil // Allowed.
}
+
output := o.FilePayload.Files[0]
defer output.Close()
- if err := pprof.Lookup("block").WriteTo(output, 0); err != nil {
- return err
+
+ p.blockMu.Lock()
+ defer p.blockMu.Unlock()
+
+ // Always set the rate. We then wait to collect a profile at this rate,
+ // and disable when we're done. Note that the default here is 10, which
+ // aims to sample one blocking event per 10 nanoseconds spent blocked.
+ // Since these events should not be super frequent, we expect this to
+ // achieve a reasonable balance between collecting the data we need and
+ // imposing a high performance cost (e.g. skewing even the CPU profile).
+ rate := 10
+ if o.Rate != 0 {
+ rate = o.Rate
}
- return nil
+ runtime.SetBlockProfileRate(rate)
+ defer runtime.SetBlockProfileRate(0)
+
+ // Collect the profile.
+ select {
+ case <-time.After(o.Duration):
+ case <-p.done:
+ }
+
+ return pprof.Lookup("block").WriteTo(output, 0)
+}
+
+// MutexProfileOpts contains options specifically for mutex profiles.
+type MutexProfileOpts struct {
+ // FilePayload is the destination for the profiling output.
+ urpc.FilePayload
+
+ // Duration is the duration of the profile.
+ Duration time.Duration `json:"duration"`
+
+ // Fraction is the mutex profile fraction.
+ Fraction int `json:"fraction"`
}
-// MutexProfile is an RPC stub which dumps out the stack trace of holders of
-// contended mutexes.
-func (p *Profile) MutexProfile(o *ProfileOpts, _ *struct{}) error {
+// Mutex dumps a mutex profile.
+func (p *Profile) Mutex(o *MutexProfileOpts, _ *struct{}) error {
if len(o.FilePayload.Files) < 1 {
- return errNoOutput
+ return nil // Allowed.
}
+
output := o.FilePayload.Files[0]
defer output.Close()
- if err := pprof.Lookup("mutex").WriteTo(output, 0); err != nil {
- return err
+
+ p.mutexMu.Lock()
+ defer p.mutexMu.Unlock()
+
+ // Always set the fraction. Like the block rate above, we use
+ // a default rate of 10% for the same reasons.
+ fraction := 10
+ if o.Fraction != 0 {
+ fraction = o.Fraction
}
- return nil
+ runtime.SetMutexProfileFraction(fraction)
+ defer runtime.SetMutexProfileFraction(0)
+
+ // Collect the profile.
+ select {
+ case <-time.After(o.Duration):
+ case <-p.done:
+ }
+
+ return pprof.Lookup("mutex").WriteTo(output, 0)
}
-// StartTrace is an RPC stub which starts collection of an execution trace.
-func (p *Profile) StartTrace(o *ProfileOpts, _ *struct{}) error {
+// TraceProfileOpts contains options specifically for traces.
+type TraceProfileOpts struct {
+ // FilePayload is the destination for the profiling output.
+ urpc.FilePayload
+
+ // Duration is the duration of the profile.
+ Duration time.Duration `json:"duration"`
+}
+
+// Trace is an RPC stub which collects an execution trace.
+func (p *Profile) Trace(o *TraceProfileOpts, _ *struct{}) error {
if len(o.FilePayload.Files) < 1 {
- return errNoOutput
+ return nil // Allowed.
}
output, err := fd.NewFromFile(o.FilePayload.Files[0])
if err != nil {
return err
}
+ defer output.Close()
- p.mu.Lock()
- defer p.mu.Unlock()
+ p.traceMu.Lock()
+ defer p.traceMu.Unlock()
// Returns an error if profiling is already started.
if err := trace.Start(output); err != nil {
output.Close()
return err
}
+ defer trace.Stop()
// Ensure all trace contexts are registered.
- p.Kernel.RebuildTraceContexts()
-
- p.traceFile = output
- return nil
-}
-
-// StopTrace is an RPC stub which stops collection of an ongoing execution
-// trace and flushes the trace data. It takes no argument.
-func (p *Profile) StopTrace(_, _ *struct{}) error {
- p.mu.Lock()
- defer p.mu.Unlock()
+ p.kernel.RebuildTraceContexts()
- if p.traceFile == nil {
- return errors.New("Execution tracing not started")
+ // Wait for the trace.
+ select {
+ case <-time.After(o.Duration):
+ case <-p.done:
}
// Similarly to the case above, if tasks have not ended traces, we will
// lose information. Thus we need to rebuild the tasks in order to have
// complete information. This will not lose information if multiple
// traces are overlapping.
- p.Kernel.RebuildTraceContexts()
+ p.kernel.RebuildTraceContexts()
- trace.Stop()
- p.traceFile.Close()
- p.traceFile = nil
return nil
}
diff --git a/pkg/sentry/control/state.go b/pkg/sentry/control/state.go
index d800f2c85..62eaca965 100644
--- a/pkg/sentry/control/state.go
+++ b/pkg/sentry/control/state.go
@@ -62,6 +62,7 @@ func (s *State) Save(o *SaveOpts, _ *struct{}) error {
Callback: func(err error) {
if err == nil {
log.Infof("Save succeeded: exiting...")
+ s.Kernel.SetSaveSuccess(false /* autosave */)
} else {
log.Warningf("Save failed: exiting...")
s.Kernel.SetSaveError(err)
diff --git a/pkg/sentry/fdimport/fdimport.go b/pkg/sentry/fdimport/fdimport.go
index 314661475..badd5b073 100644
--- a/pkg/sentry/fdimport/fdimport.go
+++ b/pkg/sentry/fdimport/fdimport.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package fdimport provides the Import function.
package fdimport
import (
diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go
index ff2fe6712..8e0aa9019 100644
--- a/pkg/sentry/fs/copy_up.go
+++ b/pkg/sentry/fs/copy_up.go
@@ -336,7 +336,12 @@ func cleanupUpper(ctx context.Context, parent *Inode, name string, copyUpErr err
// copyUpBuffers is a buffer pool for copying file content. The buffer
// size is the same used by io.Copy.
-var copyUpBuffers = sync.Pool{New: func() interface{} { return make([]byte, 8*usermem.PageSize) }}
+var copyUpBuffers = sync.Pool{
+ New: func() interface{} {
+ b := make([]byte, 8*usermem.PageSize)
+ return &b
+ },
+}
// copyContentsLocked copies the contents of lower to upper. It panics if
// less than size bytes can be copied.
@@ -361,7 +366,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in
defer lowerFile.DecRef(ctx)
// Use a buffer pool to minimize allocations.
- buf := copyUpBuffers.Get().([]byte)
+ buf := copyUpBuffers.Get().(*[]byte)
defer copyUpBuffers.Put(buf)
// Transfer the contents.
@@ -371,7 +376,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in
// optimizations could be self-defeating. So we leave this as simple as possible.
var offset int64
for {
- nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(buf), offset)
+ nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(*buf), offset)
if err != nil && err != io.EOF {
return err
}
@@ -383,7 +388,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in
}
return nil
}
- nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence(buf[:nr]), offset)
+ nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence((*buf)[:nr]), offset)
if err != nil {
return err
}
diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go
index c7a11eec1..e04784db2 100644
--- a/pkg/sentry/fs/copy_up_test.go
+++ b/pkg/sentry/fs/copy_up_test.go
@@ -64,7 +64,7 @@ func TestConcurrentCopyUp(t *testing.T) {
wg.Add(1)
go func(o *overlayTestFile) {
if err := o.File.Dirent.Inode.Truncate(ctx, o.File.Dirent, truncateFileSize); err != nil {
- t.Fatalf("failed to copy up: %v", err)
+ t.Errorf("failed to copy up: %v", err)
}
wg.Done()
}(file)
diff --git a/pkg/sentry/fs/filetest/filetest.go b/pkg/sentry/fs/filetest/filetest.go
index 8049538f2..ec3d3f96c 100644
--- a/pkg/sentry/fs/filetest/filetest.go
+++ b/pkg/sentry/fs/filetest/filetest.go
@@ -52,10 +52,10 @@ func NewTestFile(tb testing.TB) *fs.File {
// Read just fails the request.
func (*TestFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
- return 0, fmt.Errorf("Readv not implemented")
+ return 0, fmt.Errorf("TestFileOperations.Read not implemented")
}
// Write just fails the request.
func (*TestFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
- return 0, fmt.Errorf("Writev not implemented")
+ return 0, fmt.Errorf("TestFileOperations.Write not implemented")
}
diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go
index d2dbff268..a020da53b 100644
--- a/pkg/sentry/fs/fs.go
+++ b/pkg/sentry/fs/fs.go
@@ -65,7 +65,7 @@ var (
// runs with the lock held for reading. AsyncBarrier will take the lock
// for writing, thus ensuring that all Async work completes before
// AsyncBarrier returns.
- workMu sync.RWMutex
+ workMu sync.CrossGoroutineRWMutex
// asyncError is used to store up to one asynchronous execution error.
asyncError = make(chan error, 1)
diff --git a/pkg/sentry/fs/gofer/attr.go b/pkg/sentry/fs/gofer/attr.go
index d481baf77..e5579095b 100644
--- a/pkg/sentry/fs/gofer/attr.go
+++ b/pkg/sentry/fs/gofer/attr.go
@@ -117,8 +117,6 @@ func ntype(pattr p9.Attr) fs.InodeType {
return fs.BlockDevice
case pattr.Mode.IsSocket():
return fs.Socket
- case pattr.Mode.IsRegular():
- fallthrough
default:
return fs.RegularFile
}
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index 9d6fdd08f..e840b6f5e 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -475,6 +475,9 @@ func (i *inodeOperations) Check(ctx context.Context, inode *fs.Inode, p fs.PermM
func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
switch d.Inode.StableAttr.Type {
case fs.Socket:
+ if i.session().overrides != nil {
+ return nil, syserror.ENXIO
+ }
return i.getFileSocket(ctx, d, flags)
case fs.Pipe:
return i.getFilePipe(ctx, d, flags)
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index fbfba1b58..2c14aa6d9 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -276,6 +276,10 @@ func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) transport.
// GetFile implements fs.InodeOperations.GetFile.
func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ if fs.IsSocket(d.Inode.StableAttr) {
+ return nil, syserror.ENXIO
+ }
+
return newFile(ctx, d, flags, i), nil
}
diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go
index f8aad2dbd..b998fb75d 100644
--- a/pkg/sentry/fs/proc/sys.go
+++ b/pkg/sentry/fs/proc/sys.go
@@ -84,6 +84,7 @@ func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode
children := map[string]*fs.Inode{
"hostname": newProcInode(ctx, &h, msrc, fs.SpecialFile, nil),
+ "sem": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))),
"shmall": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMALL, 10))),
"shmmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMAX, 10))),
"shmmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMNI, 10))),
diff --git a/pkg/sentry/fs/ramfs/socket.go b/pkg/sentry/fs/ramfs/socket.go
index 29ff004f2..d0c565879 100644
--- a/pkg/sentry/fs/ramfs/socket.go
+++ b/pkg/sentry/fs/ramfs/socket.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -63,7 +64,7 @@ func (s *Socket) BoundEndpoint(*fs.Inode, string) transport.BoundEndpoint {
// GetFile implements fs.FileOperations.GetFile.
func (s *Socket) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
- return fs.NewFile(ctx, dirent, flags, &socketFileOperations{}), nil
+ return nil, syserror.ENXIO
}
// +stateify savable
diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go
index e04cd608d..ad4aea282 100644
--- a/pkg/sentry/fs/tmpfs/inode_file.go
+++ b/pkg/sentry/fs/tmpfs/inode_file.go
@@ -148,6 +148,10 @@ func (*fileInodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldPare
// GetFile implements fs.InodeOperations.GetFile.
func (f *fileInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ if fs.IsSocket(d.Inode.StableAttr) {
+ return nil, syserror.ENXIO
+ }
+
if flags.Write {
fsmetric.TmpfsOpensW.Increment()
} else if flags.Read {
diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go
index 9009ba3c7..4a555bf72 100644
--- a/pkg/sentry/fsimpl/ext/inode.go
+++ b/pkg/sentry/fsimpl/ext/inode.go
@@ -200,7 +200,9 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOpt
}
var fd symlinkFD
fd.LockFD.Init(&in.locks)
- fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
+ if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
return &fd.vfsfd, nil
default:
panic(fmt.Sprintf("unknown inode type: %T", in.impl))
diff --git a/pkg/sentry/fsimpl/fuse/connection_control.go b/pkg/sentry/fsimpl/fuse/connection_control.go
index 1b3459c1d..4ab894965 100644
--- a/pkg/sentry/fsimpl/fuse/connection_control.go
+++ b/pkg/sentry/fsimpl/fuse/connection_control.go
@@ -84,11 +84,7 @@ func (conn *connection) InitSend(creds *auth.Credentials, pid uint32) error {
Flags: fuseDefaultInitFlags,
}
- req, err := conn.NewRequest(creds, pid, 0, linux.FUSE_INIT, &in)
- if err != nil {
- return err
- }
-
+ req := conn.NewRequest(creds, pid, 0, linux.FUSE_INIT, &in)
// Since there is no task to block on and FUSE_INIT is the request
// to unblock other requests, use nil.
return conn.CallAsync(nil, req)
diff --git a/pkg/sentry/fsimpl/fuse/connection_test.go b/pkg/sentry/fsimpl/fuse/connection_test.go
index 91d16c1cf..d8b0d7657 100644
--- a/pkg/sentry/fsimpl/fuse/connection_test.go
+++ b/pkg/sentry/fsimpl/fuse/connection_test.go
@@ -76,10 +76,7 @@ func TestConnectionAbort(t *testing.T) {
var futNormal []*futureResponse
for i := 0; i < int(numRequests); i++ {
- req, err := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj)
- if err != nil {
- t.Fatalf("NewRequest creation failed: %v", err)
- }
+ req := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj)
fut, err := conn.callFutureLocked(task, req)
if err != nil {
t.Fatalf("callFutureLocked failed: %v", err)
@@ -105,10 +102,7 @@ func TestConnectionAbort(t *testing.T) {
}
// After abort, Call() should return directly with ENOTCONN.
- req, err := conn.NewRequest(creds, 0, 0, 0, testObj)
- if err != nil {
- t.Fatalf("NewRequest creation failed: %v", err)
- }
+ req := conn.NewRequest(creds, 0, 0, 0, testObj)
_, err = conn.Call(task, req)
if err != syserror.ENOTCONN {
t.Fatalf("Incorrect error code received for Call() after connection aborted")
diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go
index 95c475a65..bb2d0d31a 100644
--- a/pkg/sentry/fsimpl/fuse/dev_test.go
+++ b/pkg/sentry/fsimpl/fuse/dev_test.go
@@ -219,10 +219,7 @@ func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *con
data: rand.Uint32(),
}
- req, err := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj)
- if err != nil {
- t.Fatalf("NewRequest creation failed: %v", err)
- }
+ req := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj)
// Queue up a request.
// Analogous to Call except it doesn't block on the task.
diff --git a/pkg/sentry/fsimpl/fuse/directory.go b/pkg/sentry/fsimpl/fuse/directory.go
index 8f220a04b..fcc5d9a2a 100644
--- a/pkg/sentry/fsimpl/fuse/directory.go
+++ b/pkg/sentry/fsimpl/fuse/directory.go
@@ -68,11 +68,7 @@ func (dir *directoryFD) IterDirents(ctx context.Context, callback vfs.IterDirent
}
// TODO(gVisor.dev/issue/3404): Support FUSE_READDIRPLUS.
- req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), dir.inode().nodeID, linux.FUSE_READDIR, &in)
- if err != nil {
- return err
- }
-
+ req := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), dir.inode().nodeID, linux.FUSE_READDIR, &in)
res, err := fusefs.conn.Call(task, req)
if err != nil {
return err
diff --git a/pkg/sentry/fsimpl/fuse/file.go b/pkg/sentry/fsimpl/fuse/file.go
index 83f2816b7..e138b11f8 100644
--- a/pkg/sentry/fsimpl/fuse/file.go
+++ b/pkg/sentry/fsimpl/fuse/file.go
@@ -83,12 +83,8 @@ func (fd *fileDescription) Release(ctx context.Context) {
opcode = linux.FUSE_RELEASE
}
kernelTask := kernel.TaskFromContext(ctx)
- // ignoring errors and FUSE server reply is analogous to Linux's behavior.
- req, err := conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), fd.inode().nodeID, opcode, &in)
- if err != nil {
- // No way to invoke Call() with an errored request.
- return
- }
+ // Ignoring errors and FUSE server reply is analogous to Linux's behavior.
+ req := conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), fd.inode().nodeID, opcode, &in)
// The reply will be ignored since no callback is defined in asyncCallBack().
conn.CallAsync(kernelTask, req)
}
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
index 23e827f90..204d8d143 100644
--- a/pkg/sentry/fsimpl/fuse/fusefs.go
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -119,7 +119,8 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */)
if err != nil {
- return nil, nil, err
+ log.Debugf("%s.GetFilesystem: device FD '%v' not parsable: %v", fsType.Name(), deviceDescriptorStr, err)
+ return nil, nil, syserror.EINVAL
}
kernelTask := kernel.TaskFromContext(ctx)
@@ -128,6 +129,9 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, syserror.EINVAL
}
fuseFDGeneric := kernelTask.GetFileVFS2(int32(deviceDescriptor))
+ if fuseFDGeneric == nil {
+ return nil, nil, syserror.EINVAL
+ }
defer fuseFDGeneric.DecRef(ctx)
fuseFD, ok := fuseFDGeneric.Impl().(*DeviceFD)
if !ok {
@@ -360,12 +364,8 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentr
in.Flags &= ^uint32(linux.O_TRUNC)
}
- req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, &in)
- if err != nil {
- return nil, err
- }
-
// Send the request and receive the reply.
+ req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, &in)
res, err := i.fs.conn.Call(kernelTask, req)
if err != nil {
return nil, err
@@ -485,10 +485,7 @@ func (i *inode) Unlink(ctx context.Context, name string, child kernfs.Inode) err
return syserror.EINVAL
}
in := linux.FUSEUnlinkIn{Name: name}
- req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in)
- if err != nil {
- return err
- }
+ req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in)
res, err := i.fs.conn.Call(kernelTask, req)
if err != nil {
return err
@@ -515,11 +512,7 @@ func (i *inode) RmDir(ctx context.Context, name string, child kernfs.Inode) erro
task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx)
in := linux.FUSERmDirIn{Name: name}
- req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in)
- if err != nil {
- return err
- }
-
+ req := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in)
res, err := i.fs.conn.Call(task, req)
if err != nil {
return err
@@ -535,10 +528,7 @@ func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMo
log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID)
return nil, syserror.EINVAL
}
- req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload)
- if err != nil {
- return nil, err
- }
+ req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload)
res, err := i.fs.conn.Call(kernelTask, req)
if err != nil {
return nil, err
@@ -574,10 +564,7 @@ func (i *inode) Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) {
log.Warningf("fusefs.Inode.Readlink: couldn't get kernel task from context")
return "", syserror.EINVAL
}
- req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{})
- if err != nil {
- return "", err
- }
+ req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{})
res, err := i.fs.conn.Call(kernelTask, req)
if err != nil {
return "", err
@@ -680,11 +667,7 @@ func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOp
GetAttrFlags: flags,
Fh: fh,
}
- req, err := i.fs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_GETATTR, &in)
- if err != nil {
- return linux.FUSEAttr{}, err
- }
-
+ req := i.fs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_GETATTR, &in)
res, err := i.fs.conn.Call(task, req)
if err != nil {
return linux.FUSEAttr{}, err
@@ -803,11 +786,7 @@ func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
UID: opts.Stat.UID,
GID: opts.Stat.GID,
}
- req, err := conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_SETATTR, &in)
- if err != nil {
- return err
- }
-
+ req := conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_SETATTR, &in)
res, err := conn.Call(task, req)
if err != nil {
return err
diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go
index 2d396e84c..23ce91849 100644
--- a/pkg/sentry/fsimpl/fuse/read_write.go
+++ b/pkg/sentry/fsimpl/fuse/read_write.go
@@ -79,13 +79,9 @@ func (fs *filesystem) ReadInPages(ctx context.Context, fd *regularFileFD, off ui
in.Offset = off + (uint64(pagesRead) << usermem.PageShift)
in.Size = pagesCanRead << usermem.PageShift
- req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_READ, &in)
- if err != nil {
- return nil, 0, err
- }
-
// TODO(gvisor.dev/issue/3247): support async read.
+ req := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_READ, &in)
res, err := fs.conn.Call(t, req)
if err != nil {
return nil, 0, err
@@ -204,11 +200,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64,
in.Offset = off + uint64(written)
in.Size = toWrite
- req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in)
- if err != nil {
- return 0, err
- }
-
+ req := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in)
req.payload = data[written : written+toWrite]
// TODO(gvisor.dev/issue/3247): support async write.
diff --git a/pkg/sentry/fsimpl/fuse/request_response.go b/pkg/sentry/fsimpl/fuse/request_response.go
index 7fa00569b..41d679358 100644
--- a/pkg/sentry/fsimpl/fuse/request_response.go
+++ b/pkg/sentry/fsimpl/fuse/request_response.go
@@ -70,6 +70,7 @@ func (r *fuseInitRes) UnmarshalBytes(src []byte) {
out.MaxPages = uint16(usermem.ByteOrder.Uint16(src[:2]))
src = src[2:]
}
+ _ = src // Remove unused warning.
}
// SizeBytes is the size of the payload of the FUSE_INIT response.
@@ -104,7 +105,7 @@ type Request struct {
}
// NewRequest creates a new request that can be sent to the FUSE server.
-func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) {
+func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) *Request {
conn.fd.mu.Lock()
defer conn.fd.mu.Unlock()
conn.fd.nextOpID += linux.FUSEOpID(reqIDStep)
@@ -130,7 +131,7 @@ func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint
id: hdr.Unique,
hdr: &hdr,
data: buf,
- }, nil
+ }
}
// futureResponse represents an in-flight request, that may or may not have
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 435a21d77..36a3f6810 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -31,6 +31,7 @@ import (
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/hostfd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix"
@@ -499,6 +500,10 @@ func (i *inode) open(ctx context.Context, d *kernfs.Dentry, mnt *vfs.Mount, flag
fileDescription: fileDescription{inode: i},
termios: linux.DefaultReplicaTermios,
}
+ if task := kernel.TaskFromContext(ctx); task != nil {
+ fd.fgProcessGroup = task.ThreadGroup().ProcessGroup()
+ fd.session = fd.fgProcessGroup.Session()
+ }
fd.LockFD.Init(&i.locks)
vfsfd := &fd.vfsfd
if err := vfsfd.Init(fd, flags, mnt, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil {
diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go
index 469f3a33d..27b00cf6f 100644
--- a/pkg/sentry/fsimpl/overlay/copy_up.go
+++ b/pkg/sentry/fsimpl/overlay/copy_up.go
@@ -16,7 +16,6 @@ package overlay
import (
"fmt"
- "io"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -129,25 +128,9 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
return err
}
defer newFD.DecRef(ctx)
- bufIOSeq := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size
- for {
- readN, readErr := oldFD.Read(ctx, bufIOSeq, vfs.ReadOptions{})
- if readErr != nil && readErr != io.EOF {
- cleanupUndoCopyUp()
- return readErr
- }
- total := int64(0)
- for total < readN {
- writeN, writeErr := newFD.Write(ctx, bufIOSeq.DropFirst64(total), vfs.WriteOptions{})
- total += writeN
- if writeErr != nil {
- cleanupUndoCopyUp()
- return writeErr
- }
- }
- if readErr == io.EOF {
- break
- }
+ if _, err := vfs.CopyRegularFileData(ctx, newFD, oldFD); err != nil {
+ cleanupUndoCopyUp()
+ return err
}
d.mapsMu.Lock()
defer d.mapsMu.Unlock()
diff --git a/pkg/sentry/fsimpl/overlay/regular_file.go b/pkg/sentry/fsimpl/overlay/regular_file.go
index 2b89a7a6d..25c785fd4 100644
--- a/pkg/sentry/fsimpl/overlay/regular_file.go
+++ b/pkg/sentry/fsimpl/overlay/regular_file.go
@@ -103,8 +103,8 @@ func (fd *regularFileFD) currentFDLocked(ctx context.Context) (*vfs.FileDescript
for e, mask := range fd.lowerWaiters {
fd.cachedFD.EventUnregister(e)
upperFD.EventRegister(e, mask)
- if ready&mask != 0 {
- e.Callback.Callback(e)
+ if m := ready & mask; m != 0 {
+ e.Callback.Callback(e, m)
}
}
}
diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go
index 0ecb592cf..429733c10 100644
--- a/pkg/sentry/fsimpl/pipefs/pipefs.go
+++ b/pkg/sentry/fsimpl/pipefs/pipefs.go
@@ -164,11 +164,11 @@ func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, e
// and write ends of a newly-created pipe, as for pipe(2) and pipe2(2).
//
// Preconditions: mnt.Filesystem() must have been returned by NewFilesystem().
-func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, *vfs.FileDescription) {
+func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, *vfs.FileDescription, error) {
fs := mnt.Filesystem().Impl().(*filesystem)
inode := newInode(ctx, fs)
var d kernfs.Dentry
d.Init(&fs.Filesystem, inode)
defer d.DecRef(ctx)
- return inode.pipe.ReaderWriterPair(mnt, d.VFSDentry(), flags)
+ return inode.pipe.ReaderWriterPair(ctx, mnt, d.VFSDentry(), flags)
}
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index a3780b222..75be6129f 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -57,9 +57,6 @@ func getMM(task *kernel.Task) *mm.MemoryManager {
// MemoryManager's users count is incremented, and must be decremented by the
// caller when it is no longer in use.
func getMMIncRef(task *kernel.Task) (*mm.MemoryManager, error) {
- if task.ExitState() == kernel.TaskExitDead {
- return nil, syserror.ESRCH
- }
var m *mm.MemoryManager
task.WithMuLocked(func(t *kernel.Task) {
m = t.MemoryManager()
@@ -111,9 +108,13 @@ var _ dynamicInode = (*auxvData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
func (d *auxvData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if d.task.ExitState() == kernel.TaskExitDead {
+ return syserror.ESRCH
+ }
m, err := getMMIncRef(d.task)
if err != nil {
- return err
+ // Return empty file.
+ return nil
}
defer m.DecUsers(ctx)
@@ -157,9 +158,13 @@ var _ dynamicInode = (*cmdlineData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if d.task.ExitState() == kernel.TaskExitDead {
+ return syserror.ESRCH
+ }
m, err := getMMIncRef(d.task)
if err != nil {
- return err
+ // Return empty file.
+ return nil
}
defer m.DecUsers(ctx)
@@ -472,7 +477,7 @@ func (fd *memFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64
}
m, err := getMMIncRef(fd.inode.task)
if err != nil {
- return 0, nil
+ return 0, err
}
defer m.DecUsers(ctx)
// Buffer the read data because of MM locks
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index 7c7afdcfa..25c407d98 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -44,6 +44,7 @@ func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k *
return fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
"kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
"hostname": fs.newInode(ctx, root, 0444, &hostnameData{}),
+ "sem": fs.newInode(ctx, root, 0444, newStaticFile(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))),
"shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)),
"shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)),
"shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)),
diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go
index 10f1452ef..246bd87bc 100644
--- a/pkg/sentry/fsimpl/signalfd/signalfd.go
+++ b/pkg/sentry/fsimpl/signalfd/signalfd.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package signalfd provides basic signalfd file implementations.
package signalfd
import (
@@ -98,8 +99,8 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen
Signo: uint32(info.Signo),
Errno: info.Errno,
Code: info.Code,
- PID: uint32(info.Pid()),
- UID: uint32(info.Uid()),
+ PID: uint32(info.PID()),
+ UID: uint32(info.UID()),
Status: info.Status(),
Overrun: uint32(info.Overrun()),
Addr: info.Addr(),
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 59fcff498..a4ad625bb 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -163,7 +163,7 @@ afterSymlink:
// verifyChildLocked verifies the hash of child against the already verified
// hash of the parent to ensure the child is expected. verifyChild triggers a
// sentry panic if unexpected modifications to the file system are detected. In
-// noCrashOnVerificationFailure mode it returns a syserror instead.
+// ErrorOnViolation mode it returns a syserror instead.
//
// Preconditions:
// * fs.renameMu must be locked.
@@ -254,7 +254,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
}
- fdReader := vfs.FileReadWriteSeeker{
+ fdReader := FileReadWriteSeeker{
FD: parentMerkleFD,
Ctx: ctx,
}
@@ -397,7 +397,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
}
}
- fdReader := vfs.FileReadWriteSeeker{
+ fdReader := FileReadWriteSeeker{
FD: fd,
Ctx: ctx,
}
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index add65bee6..a5171b5ad 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -64,6 +64,10 @@ const (
// tree file for "/foo" is "/.merkle.verity.foo".
merklePrefix = ".merkle.verity."
+ // merkleRootPrefix is the prefix of the Merkle tree root file. This
+ // needs to be different from merklePrefix to avoid name collision.
+ merkleRootPrefix = ".merkleroot.verity."
+
// merkleOffsetInParentXattr is the extended attribute name specifying the
// offset of the child hash in its parent's Merkle tree.
merkleOffsetInParentXattr = "user.merkle.offset"
@@ -88,10 +92,8 @@ const (
)
var (
- // noCrashOnVerificationFailure indicates whether the sandbox should panic
- // whenever verification fails. If true, an error is returned instead of
- // panicking. This should only be set for tests.
- noCrashOnVerificationFailure bool
+ // action specifies the action to take when a violation is detected.
+ action ViolationAction
// verityMu synchronizes concurrent operations that enable verity and perform
// verification checks.
@@ -102,6 +104,18 @@ var (
// content.
type HashAlgorithm int
+// ViolationAction is a type specifying the action when an integrity violation
+// is detected.
+type ViolationAction int
+
+const (
+ // PanicOnViolation terminates the sentry on detected violation.
+ PanicOnViolation ViolationAction = 0
+ // ErrorOnViolation returns an error from the violating system call on
+ // detected violation.
+ ErrorOnViolation = 1
+)
+
// Currently supported hashing algorithms include SHA256 and SHA512.
const (
SHA256 HashAlgorithm = iota
@@ -166,7 +180,7 @@ type filesystem struct {
// its children. So they shouldn't be enabled the same time. This lock
// is for the whole file system to ensure that no more than one file is
// enabled the same time.
- verityMu sync.RWMutex
+ verityMu sync.RWMutex `state:"nosave"`
}
// InternalFilesystemOptions may be passed as
@@ -196,10 +210,8 @@ type InternalFilesystemOptions struct {
// system wrapped by verity file system.
LowerGetFSOptions vfs.GetFilesystemOptions
- // NoCrashOnVerificationFailure indicates whether the sandbox should
- // panic whenever verification fails. If true, an error is returned
- // instead of panicking. This should only be set for tests.
- NoCrashOnVerificationFailure bool
+ // Action specifies the action on an integrity violation.
+ Action ViolationAction
}
// Name implements vfs.FilesystemType.Name.
@@ -211,10 +223,10 @@ func (FilesystemType) Name() string {
func (FilesystemType) Release(ctx context.Context) {}
// alertIntegrityViolation alerts a violation of integrity, which usually means
-// unexpected modification to the file system is detected. In
-// noCrashOnVerificationFailure mode, it returns EIO, otherwise it panic.
+// unexpected modification to the file system is detected. In ErrorOnViolation
+// mode, it returns EIO, otherwise it panics.
func alertIntegrityViolation(msg string) error {
- if noCrashOnVerificationFailure {
+ if action == ErrorOnViolation {
return syserror.EIO
}
panic(msg)
@@ -227,7 +239,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs")
return nil, nil, syserror.EINVAL
}
- noCrashOnVerificationFailure = iopts.NoCrashOnVerificationFailure
+ action = iopts.Action
// Mount the lower file system. The lower file system is wrapped inside
// verity, and should not be exposed or connected.
@@ -255,7 +267,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
lowerVD.IncRef()
d.lowerVD = lowerVD
- rootMerkleName := merklePrefix + iopts.RootMerkleFileName
+ rootMerkleName := merkleRootPrefix + iopts.RootMerkleFileName
lowerMerkleVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{
Root: lowerVD,
@@ -744,20 +756,20 @@ func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32)
// file /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The
// hash of the generated Merkle tree and the data size is returned. If fd
// points to a regular file, the data is the content of the file. If fd points
-// to a directory, the data is all hahes of its children, written to the Merkle
+// to a directory, the data is all hashes of its children, written to the Merkle
// tree file.
//
// Preconditions: fd.d.fs.verityMu must be locked.
func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, uint64, error) {
- fdReader := vfs.FileReadWriteSeeker{
+ fdReader := FileReadWriteSeeker{
FD: fd.lowerFD,
Ctx: ctx,
}
- merkleReader := vfs.FileReadWriteSeeker{
+ merkleReader := FileReadWriteSeeker{
FD: fd.merkleReader,
Ctx: ctx,
}
- merkleWriter := vfs.FileReadWriteSeeker{
+ merkleWriter := FileReadWriteSeeker{
FD: fd.merkleWriter,
Ctx: ctx,
}
@@ -1047,12 +1059,12 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
}
- dataReader := vfs.FileReadWriteSeeker{
+ dataReader := FileReadWriteSeeker{
FD: fd.lowerFD,
Ctx: ctx,
}
- merkleReader := vfs.FileReadWriteSeeker{
+ merkleReader := FileReadWriteSeeker{
FD: fd.merkleReader,
Ctx: ctx,
}
@@ -1101,3 +1113,45 @@ func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t
func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
return fd.lowerFD.UnlockPOSIX(ctx, uid, start, length, whence)
}
+
+// FileReadWriteSeeker is a helper struct to pass a vfs.FileDescription as
+// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc.
+type FileReadWriteSeeker struct {
+ FD *vfs.FileDescription
+ Ctx context.Context
+ ROpts vfs.ReadOptions
+ WOpts vfs.WriteOptions
+}
+
+// ReadAt implements io.ReaderAt.ReadAt.
+func (f *FileReadWriteSeeker) ReadAt(p []byte, off int64) (int, error) {
+ dst := usermem.BytesIOSequence(p)
+ n, err := f.FD.PRead(f.Ctx, dst, off, f.ROpts)
+ return int(n), err
+}
+
+// Read implements io.ReadWriteSeeker.Read.
+func (f *FileReadWriteSeeker) Read(p []byte) (int, error) {
+ dst := usermem.BytesIOSequence(p)
+ n, err := f.FD.Read(f.Ctx, dst, f.ROpts)
+ return int(n), err
+}
+
+// Seek implements io.ReadWriteSeeker.Seek.
+func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) {
+ return f.FD.Seek(f.Ctx, offset, int32(whence))
+}
+
+// WriteAt implements io.WriterAt.WriteAt.
+func (f *FileReadWriteSeeker) WriteAt(p []byte, off int64) (int, error) {
+ dst := usermem.BytesIOSequence(p)
+ n, err := f.FD.PWrite(f.Ctx, dst, off, f.WOpts)
+ return int(n), err
+}
+
+// Write implements io.ReadWriteSeeker.Write.
+func (f *FileReadWriteSeeker) Write(p []byte) (int, error) {
+ buf := usermem.BytesIOSequence(p)
+ n, err := f.FD.Write(f.Ctx, buf, f.WOpts)
+ return int(n), err
+}
diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go
index 5d1f5de08..30d8b4355 100644
--- a/pkg/sentry/fsimpl/verity/verity_test.go
+++ b/pkg/sentry/fsimpl/verity/verity_test.go
@@ -35,14 +35,16 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// rootMerkleFilename is the name of the root Merkle tree file.
-const rootMerkleFilename = "root.verity"
+const (
+ // rootMerkleFilename is the name of the root Merkle tree file.
+ rootMerkleFilename = "root.verity"
+ // maxDataSize is the maximum data size of a test file.
+ maxDataSize = 100000
+)
-// maxDataSize is the maximum data size written to the file for test.
-const maxDataSize = 100000
+var hashAlgs = []HashAlgorithm{SHA256, SHA512}
-// getD returns a *dentry corresponding to VD.
-func getD(t *testing.T, vd vfs.VirtualDentry) *dentry {
+func dentryFromVD(t *testing.T, vd vfs.VirtualDentry) *dentry {
t.Helper()
d, ok := vd.Dentry().Impl().(*dentry)
if !ok {
@@ -51,10 +53,21 @@ func getD(t *testing.T, vd vfs.VirtualDentry) *dentry {
return d
}
+// dentryFromFD returns the dentry corresponding to fd.
+func dentryFromFD(t *testing.T, fd *vfs.FileDescription) *dentry {
+ t.Helper()
+ f, ok := fd.Impl().(*fileDescription)
+ if !ok {
+ t.Fatalf("can't assert %T as a *fileDescription", fd)
+ }
+ return f.d
+}
+
// newVerityRoot creates a new verity mount, and returns the root. The
// underlying file system is tmpfs. If the error is not nil, then cleanup
// should be called when the root is no longer needed.
func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) {
+ t.Helper()
k, err := testutil.Boot()
if err != nil {
t.Fatalf("testutil.Boot: %v", err)
@@ -79,11 +92,11 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem,
mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{
GetFilesystemOptions: vfs.GetFilesystemOptions{
InternalData: InternalFilesystemOptions{
- RootMerkleFileName: rootMerkleFilename,
- LowerName: "tmpfs",
- Alg: hashAlg,
- AllowRuntimeEnable: true,
- NoCrashOnVerificationFailure: true,
+ RootMerkleFileName: rootMerkleFilename,
+ LowerName: "tmpfs",
+ Alg: hashAlg,
+ AllowRuntimeEnable: true,
+ Action: ErrorOnViolation,
},
},
})
@@ -102,7 +115,6 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem,
t.Fatalf("testutil.CreateTask: %v", err)
}
- t.Helper()
t.Cleanup(func() {
root.DecRef(ctx)
mntns.DecRef(ctx)
@@ -111,6 +123,8 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem,
}
// openVerityAt opens a verity file.
+//
+// TODO(chongc): release reference from opening the file when done.
func openVerityAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, vd vfs.VirtualDentry, path string, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) {
return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
Root: vd,
@@ -123,6 +137,8 @@ func openVerityAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, vd vfs.Vir
}
// openLowerAt opens the file in the underlying file system.
+//
+// TODO(chongc): release reference from opening the file when done.
func (d *dentry) openLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) {
return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
Root: d.lowerVD,
@@ -135,6 +151,8 @@ func (d *dentry) openLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem,
}
// openLowerMerkleAt opens the Merkle file in the underlying file system.
+//
+// TODO(chongc): release reference from opening the file when done.
func (d *dentry) openLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) {
return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
Root: d.lowerMerkleVD,
@@ -190,21 +208,11 @@ func (d *dentry) renameLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFil
}, &vfs.RenameOptions{})
}
-// getDentry returns a *dentry corresponds to fd.
-func getDentry(t *testing.T, fd *vfs.FileDescription) *dentry {
- t.Helper()
- f, ok := fd.Impl().(*fileDescription)
- if !ok {
- t.Fatalf("can't assert %T as a *fileDescription", fd)
- }
- return f.d
-}
-
// newFileFD creates a new file in the verity mount, and returns the FD. The FD
// points to a file that has random data generated.
func newFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) {
// Create the file in the underlying file system.
- lowerFD, err := getD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode)
+ lowerFD, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode)
if err != nil {
return nil, 0, err
}
@@ -231,9 +239,20 @@ func newFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem,
return fd, dataSize, err
}
-// corruptRandomBit randomly flips a bit in the file represented by fd.
-func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error {
- // Flip a random bit in the underlying file.
+// newEmptyFileFD creates a new empty file in the verity mount, and returns the FD.
+func newEmptyFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, error) {
+ // Create the file in the underlying file system.
+ _, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode)
+ if err != nil {
+ return nil, err
+ }
+ // Now open the verity file descriptor.
+ fd, err := openVerityAt(ctx, vfsObj, root, filePath, linux.O_RDONLY, mode)
+ return fd, err
+}
+
+// flipRandomBit randomly flips a bit in the file represented by fd.
+func flipRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error {
randomPos := int64(rand.Intn(size))
byteToModify := make([]byte, 1)
if _, err := fd.PRead(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.ReadOptions{}); err != nil {
@@ -246,7 +265,14 @@ func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) er
return nil
}
-var hashAlgs = []HashAlgorithm{SHA256, SHA512}
+func enableVerity(ctx context.Context, t *testing.T, fd *vfs.FileDescription) {
+ t.Helper()
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("enable verity: %v", err)
+ }
+}
// TestOpen ensures that when a file is created, the corresponding Merkle tree
// file and the root Merkle tree file exist.
@@ -264,12 +290,12 @@ func TestOpen(t *testing.T) {
}
// Ensure that the corresponding Merkle tree file is created.
- if _, err = getDentry(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil {
+ if _, err = dentryFromFD(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil {
t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err)
}
// Ensure the root merkle tree file is created.
- if _, err = getD(t, root).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil {
+ if _, err = dentryFromVD(t, root).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil {
t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err)
}
}
@@ -291,11 +317,7 @@ func TestPReadUnmodifiedFileSucceeds(t *testing.T) {
}
// Enable verity on the file and confirm a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
buf := make([]byte, size)
n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{})
@@ -325,11 +347,7 @@ func TestReadUnmodifiedFileSucceeds(t *testing.T) {
}
// Enable verity on the file and confirm a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
buf := make([]byte, size)
n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
@@ -343,6 +361,36 @@ func TestReadUnmodifiedFileSucceeds(t *testing.T) {
}
}
+// TestReadUnmodifiedEmptyFileSucceeds ensures that reading from an untouched
+// empty verity file succeeds after enabling verity for it.
+func TestReadUnmodifiedEmptyFileSucceeds(t *testing.T) {
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-empty-file"
+ fd, err := newEmptyFileFD(ctx, t, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newEmptyFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirm a normal read succeeds.
+ enableVerity(ctx, t, fd)
+
+ var buf []byte
+ n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.Read: %v", err)
+ }
+
+ if n != 0 {
+ t.Errorf("fd.Read got read length %d, expected 0", n)
+ }
+ }
+}
+
// TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file
// succeeds after enabling verity for it.
func TestReopenUnmodifiedFileSucceeds(t *testing.T) {
@@ -359,11 +407,7 @@ func TestReopenUnmodifiedFileSucceeds(t *testing.T) {
}
// Enable verity on the file and confirms a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Ensure reopening the verity enabled file succeeds.
if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != nil {
@@ -387,21 +431,14 @@ func TestOpenNonexistentFile(t *testing.T) {
}
// Enable verity on the file and confirms a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Enable verity on the parent directory.
parentFD, err := openVerityAt(ctx, vfsObj, root, "", linux.O_RDONLY, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
-
- if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, parentFD)
// Ensure open an unexpected file in the parent directory fails with
// ENOENT rather than verification failure.
@@ -426,20 +463,16 @@ func TestPReadModifiedFileFails(t *testing.T) {
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Open a new lowerFD that's read/writable.
- lowerFD, err := getDentry(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular)
+ lowerFD, err := dentryFromFD(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
- if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
+ if err := flipRandomBit(ctx, lowerFD, size); err != nil {
+ t.Fatalf("flipRandomBit: %v", err)
}
// Confirm that read from the modified file fails.
@@ -466,20 +499,16 @@ func TestReadModifiedFileFails(t *testing.T) {
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Open a new lowerFD that's read/writable.
- lowerFD, err := getDentry(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular)
+ lowerFD, err := dentryFromFD(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
- if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
+ if err := flipRandomBit(ctx, lowerFD, size); err != nil {
+ t.Fatalf("flipRandomBit: %v", err)
}
// Confirm that read from the modified file fails.
@@ -506,14 +535,10 @@ func TestModifiedMerkleFails(t *testing.T) {
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Open a new lowerMerkleFD that's read/writable.
- lowerMerkleFD, err := getDentry(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular)
+ lowerMerkleFD, err := dentryFromFD(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
@@ -524,14 +549,13 @@ func TestModifiedMerkleFails(t *testing.T) {
t.Errorf("lowerMerkleFD.Stat: %v", err)
}
- if err := corruptRandomBit(ctx, lowerMerkleFD, int(stat.Size)); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
+ if err := flipRandomBit(ctx, lowerMerkleFD, int(stat.Size)); err != nil {
+ t.Fatalf("flipRandomBit: %v", err)
}
// Confirm that read from a file with modified Merkle tree fails.
buf := make([]byte, size)
if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
- fmt.Println(buf)
t.Fatalf("fd.PRead succeeded with modified Merkle file")
}
}
@@ -554,24 +578,17 @@ func TestModifiedParentMerkleFails(t *testing.T) {
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Enable verity on the parent directory.
parentFD, err := openVerityAt(ctx, vfsObj, root, "", linux.O_RDONLY, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
-
- if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, parentFD)
// Open a new lowerMerkleFD that's read/writable.
- parentLowerMerkleFD, err := getDentry(t, fd).parent.openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular)
+ parentLowerMerkleFD, err := dentryFromFD(t, fd).parent.openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
@@ -591,8 +608,8 @@ func TestModifiedParentMerkleFails(t *testing.T) {
if err != nil {
t.Fatalf("Failed convert size to int: %v", err)
}
- if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
+ if err := flipRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil {
+ t.Fatalf("flipRandomBit: %v", err)
}
parentLowerMerkleFD.DecRef(ctx)
@@ -619,13 +636,8 @@ func TestUnmodifiedStatSucceeds(t *testing.T) {
t.Fatalf("newFileFD: %v", err)
}
- // Enable verity on the file and confirms stat succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("fd.Ioctl: %v", err)
- }
-
+ // Enable verity on the file and confirm that stat succeeds.
+ enableVerity(ctx, t, fd)
if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil {
t.Errorf("fd.Stat: %v", err)
}
@@ -648,11 +660,7 @@ func TestModifiedStatFails(t *testing.T) {
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("fd.Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
lowerFD := fd.Impl().(*fileDescription).lowerFD
// Change the stat of the underlying file, and check that stat fails.
@@ -711,19 +719,15 @@ func TestOpenDeletedFileFails(t *testing.T) {
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
if tc.changeFile {
- if err := getD(t, root).unlinkLowerAt(ctx, vfsObj, filename); err != nil {
+ if err := dentryFromVD(t, root).unlinkLowerAt(ctx, vfsObj, filename); err != nil {
t.Fatalf("UnlinkAt: %v", err)
}
}
if tc.changeMerkleFile {
- if err := getD(t, root).unlinkLowerMerkleAt(ctx, vfsObj, filename); err != nil {
+ if err := dentryFromVD(t, root).unlinkLowerMerkleAt(ctx, vfsObj, filename); err != nil {
t.Fatalf("UnlinkAt: %v", err)
}
}
@@ -776,20 +780,16 @@ func TestOpenRenamedFileFails(t *testing.T) {
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
newFilename := "renamed-test-file"
if tc.changeFile {
- if err := getD(t, root).renameLowerAt(ctx, vfsObj, filename, newFilename); err != nil {
+ if err := dentryFromVD(t, root).renameLowerAt(ctx, vfsObj, filename, newFilename); err != nil {
t.Fatalf("RenameAt: %v", err)
}
}
if tc.changeMerkleFile {
- if err := getD(t, root).renameLowerMerkleAt(ctx, vfsObj, filename, newFilename); err != nil {
+ if err := dentryFromVD(t, root).renameLowerMerkleAt(ctx, vfsObj, filename, newFilename); err != nil {
t.Fatalf("UnlinkAt: %v", err)
}
}
diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go
index 15519f0df..61aeca044 100644
--- a/pkg/sentry/kernel/epoll/epoll.go
+++ b/pkg/sentry/kernel/epoll/epoll.go
@@ -273,7 +273,7 @@ func (e *EventPoll) ReadEvents(max int) []linux.EpollEvent {
//
// Callback is called when one of the files we're polling becomes ready. It
// moves said file to the readyList if it's currently in the waiting list.
-func (p *pollEntry) Callback(*waiter.Entry) {
+func (p *pollEntry) Callback(*waiter.Entry, waiter.EventMask) {
e := p.epoll
e.listsMu.Lock()
@@ -306,9 +306,8 @@ func (e *EventPoll) initEntryReadiness(entry *pollEntry) {
f.EventRegister(&entry.waiter, entry.mask)
// Check if the file happens to already be in a ready state.
- ready := f.Readiness(entry.mask) & entry.mask
- if ready != 0 {
- entry.Callback(&entry.waiter)
+ if ready := f.Readiness(entry.mask) & entry.mask; ready != 0 {
+ entry.Callback(&entry.waiter, ready)
}
}
diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD
index 2b3955598..f855f038b 100644
--- a/pkg/sentry/kernel/fasync/BUILD
+++ b/pkg/sentry/kernel/fasync/BUILD
@@ -8,11 +8,13 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
+ "//pkg/sentry/arch",
"//pkg/sentry/fs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
"//pkg/sync",
+ "//pkg/syserror",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go
index 153d2cd9b..b66d61c6f 100644
--- a/pkg/sentry/kernel/fasync/fasync.go
+++ b/pkg/sentry/kernel/fasync/fasync.go
@@ -17,22 +17,45 @@ package fasync
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
-// New creates a new fs.FileAsync.
-func New() fs.FileAsync {
- return &FileAsync{}
+// bandTable maps waiter event masks to si_band siginfo codes.
+// Taken from fs/fcntl.c:band_table.
+var bandTable = map[waiter.EventMask]int64{
+ // POLL_IN
+ waiter.EventIn: linux.EPOLLIN | linux.EPOLLRDNORM,
+ // POLL_OUT
+ waiter.EventOut: linux.EPOLLOUT | linux.EPOLLWRNORM | linux.EPOLLWRBAND,
+ // POLL_ERR
+ waiter.EventErr: linux.EPOLLERR,
+ // POLL_PRI
+ waiter.EventPri: linux.EPOLLPRI | linux.EPOLLRDBAND,
+ // POLL_HUP
+ waiter.EventHUp: linux.EPOLLHUP | linux.EPOLLERR,
}
-// NewVFS2 creates a new vfs.FileAsync.
-func NewVFS2() vfs.FileAsync {
- return &FileAsync{}
+// New returns a function that creates a new fs.FileAsync with the given file
+// descriptor.
+func New(fd int) func() fs.FileAsync {
+ return func() fs.FileAsync {
+ return &FileAsync{fd: fd}
+ }
+}
+
+// NewVFS2 returns a function that creates a new vfs.FileAsync with the given
+// file descriptor.
+func NewVFS2(fd int) func() vfs.FileAsync {
+ return func() vfs.FileAsync {
+ return &FileAsync{fd: fd}
+ }
}
// FileAsync sends signals when the registered file is ready for IO.
@@ -42,6 +65,12 @@ type FileAsync struct {
// e is immutable after first use (which is protected by mu below).
e waiter.Entry
+ // fd is the file descriptor to notify about.
+ // It is immutable, set at allocation time. This matches Linux semantics in
+ // fs/fcntl.c:fasync_helper.
+ // The fd value is passed to the signal recipient in siginfo.si_fd.
+ fd int
+
// regMu protects registeration and unregistration actions on e.
//
// regMu must be held while registration decisions are being made
@@ -56,6 +85,10 @@ type FileAsync struct {
mu sync.Mutex `state:"nosave"`
requester *auth.Credentials
registered bool
+ // signal is the signal to deliver upon I/O being available.
+ // The default value ("zero signal") means the default SIGIO signal will be
+ // delivered.
+ signal linux.Signal
// Only one of the following is allowed to be non-nil.
recipientPG *kernel.ProcessGroup
@@ -64,10 +97,10 @@ type FileAsync struct {
}
// Callback sends a signal.
-func (a *FileAsync) Callback(e *waiter.Entry) {
+func (a *FileAsync) Callback(e *waiter.Entry, mask waiter.EventMask) {
a.mu.Lock()
+ defer a.mu.Unlock()
if !a.registered {
- a.mu.Unlock()
return
}
t := a.recipientT
@@ -80,19 +113,34 @@ func (a *FileAsync) Callback(e *waiter.Entry) {
}
if t == nil {
// No recipient has been registered.
- a.mu.Unlock()
return
}
c := t.Credentials()
// Logic from sigio_perm in fs/fcntl.c.
- if a.requester.EffectiveKUID == 0 ||
+ permCheck := (a.requester.EffectiveKUID == 0 ||
a.requester.EffectiveKUID == c.SavedKUID ||
a.requester.EffectiveKUID == c.RealKUID ||
a.requester.RealKUID == c.SavedKUID ||
- a.requester.RealKUID == c.RealKUID {
- t.SendSignal(kernel.SignalInfoPriv(linux.SIGIO))
+ a.requester.RealKUID == c.RealKUID)
+ if !permCheck {
+ return
}
- a.mu.Unlock()
+ signalInfo := &arch.SignalInfo{
+ Signo: int32(linux.SIGIO),
+ Code: arch.SignalInfoKernel,
+ }
+ if a.signal != 0 {
+ signalInfo.Signo = int32(a.signal)
+ signalInfo.SetFD(uint32(a.fd))
+ var band int64
+ for m, bandCode := range bandTable {
+ if m&mask != 0 {
+ band |= bandCode
+ }
+ }
+ signalInfo.SetBand(band)
+ }
+ t.SendSignal(signalInfo)
}
// Register sets the file which will be monitored for IO events.
@@ -186,3 +234,25 @@ func (a *FileAsync) ClearOwner() {
a.recipientTG = nil
a.recipientPG = nil
}
+
+// Signal returns which signal will be sent to the signal recipient.
+// A value of zero means the signal to deliver wasn't customized, which means
+// the default signal (SIGIO) will be delivered.
+func (a *FileAsync) Signal() linux.Signal {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ return a.signal
+}
+
+// SetSignal overrides which signal to send when I/O is available.
+// The default behavior can be reset by specifying signal zero, which means
+// to send SIGIO.
+func (a *FileAsync) SetSignal(signal linux.Signal) error {
+ if signal != 0 && !signal.IsValid() {
+ return syserror.EINVAL
+ }
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.signal = signal
+ return nil
+}
diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go
index 470d8bf83..f17f9c59c 100644
--- a/pkg/sentry/kernel/fd_table_unsafe.go
+++ b/pkg/sentry/kernel/fd_table_unsafe.go
@@ -121,18 +121,21 @@ func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2
panic("VFS1 and VFS2 files set")
}
- slice := *(*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice))
+ slicePtr := (*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice))
// Grow the table as required.
- if last := int32(len(slice)); fd >= last {
+ if last := int32(len(*slicePtr)); fd >= last {
end := fd + 1
if end < 2*last {
end = 2 * last
}
- slice = append(slice, make([]unsafe.Pointer, end-last)...)
- atomic.StorePointer(&f.slice, unsafe.Pointer(&slice))
+ newSlice := append(*slicePtr, make([]unsafe.Pointer, end-last)...)
+ slicePtr = &newSlice
+ atomic.StorePointer(&f.slice, unsafe.Pointer(slicePtr))
}
+ slice := *slicePtr
+
var desc *descriptor
if file != nil || fileVFS2 != nil {
desc = &descriptor{
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 2cdcdfc1f..b8627a54f 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -214,9 +214,11 @@ type Kernel struct {
// netlinkPorts manages allocation of netlink socket port IDs.
netlinkPorts *port.Manager
- // saveErr is the error causing the sandbox to exit during save, if
- // any. It is protected by extMu.
- saveErr error `state:"nosave"`
+ // saveStatus is nil if the sandbox has not been saved, errSaved or
+ // errAutoSaved if it has been saved successfully, or the error causing the
+ // sandbox to exit during save.
+ // It is protected by extMu.
+ saveStatus error `state:"nosave"`
// danglingEndpoints is used to save / restore tcpip.DanglingEndpoints.
danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"`
@@ -1481,12 +1483,42 @@ func (k *Kernel) NetlinkPorts() *port.Manager {
return k.netlinkPorts
}
-// SaveError returns the sandbox error that caused the kernel to exit during
-// save.
-func (k *Kernel) SaveError() error {
+var (
+ errSaved = errors.New("sandbox has been successfully saved")
+ errAutoSaved = errors.New("sandbox has been successfully auto-saved")
+)
+
+// SaveStatus returns the sandbox save status. If it was saved successfully,
+// autosaved indicates whether save was triggered by autosave. If it was not
+// saved successfully, err indicates the sandbox error that caused the kernel to
+// exit during save.
+func (k *Kernel) SaveStatus() (saved, autosaved bool, err error) {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ switch k.saveStatus {
+ case nil:
+ return false, false, nil
+ case errSaved:
+ return true, false, nil
+ case errAutoSaved:
+ return true, true, nil
+ default:
+ return false, false, k.saveStatus
+ }
+}
+
+// SetSaveSuccess sets the flag indicating that save completed successfully, if
+// no status was already set.
+func (k *Kernel) SetSaveSuccess(autosave bool) {
k.extMu.Lock()
defer k.extMu.Unlock()
- return k.saveErr
+ if k.saveStatus == nil {
+ if autosave {
+ k.saveStatus = errAutoSaved
+ } else {
+ k.saveStatus = errSaved
+ }
+ }
}
// SetSaveError sets the sandbox error that caused the kernel to exit during
@@ -1494,8 +1526,8 @@ func (k *Kernel) SaveError() error {
func (k *Kernel) SetSaveError(err error) {
k.extMu.Lock()
defer k.extMu.Unlock()
- if k.saveErr == nil {
- k.saveErr = err
+ if k.saveStatus == nil {
+ k.saveStatus = err
}
}
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index 7b23cbe86..2d47d2e82 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -63,10 +63,19 @@ func NewVFSPipe(isNamed bool, sizeBytes int64) *VFSPipe {
// ReaderWriterPair returns read-only and write-only FDs for vp.
//
// Preconditions: statusFlags should not contain an open access mode.
-func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription) {
+func (vp *VFSPipe) ReaderWriterPair(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription, error) {
// Connected pipes share the same locks.
locks := &vfs.FileLocks{}
- return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks)
+ r, err := vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks)
+ if err != nil {
+ return nil, nil, err
+ }
+ w, err := vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks)
+ if err != nil {
+ r.DecRef(ctx)
+ return nil, nil, err
+ }
+ return r, w, nil
}
// Allocate implements vfs.FileDescriptionImpl.Allocate.
@@ -85,7 +94,10 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s
return nil, syserror.EINVAL
}
- fd := vp.newFD(mnt, vfsd, statusFlags, locks)
+ fd, err := vp.newFD(mnt, vfsd, statusFlags, locks)
+ if err != nil {
+ return nil, err
+ }
// Named pipes have special blocking semantics during open:
//
@@ -137,16 +149,18 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s
}
// Preconditions: vp.mu must be held.
-func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) *vfs.FileDescription {
+func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) (*vfs.FileDescription, error) {
fd := &VFSPipeFD{
pipe: &vp.pipe,
}
fd.LockFD.Init(locks)
- fd.vfsfd.Init(fd, statusFlags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ if err := fd.vfsfd.Init(fd, statusFlags, mnt, vfsd, &vfs.FileDescriptionOptions{
DenyPRead: true,
DenyPWrite: true,
UseDentryMetadata: true,
- })
+ }); err != nil {
+ return nil, err
+ }
switch {
case fd.vfsfd.IsReadable() && fd.vfsfd.IsWritable():
@@ -160,7 +174,7 @@ func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, l
panic("invalid pipe flags: must be readable, writable, or both")
}
- return &fd.vfsfd
+ return &fd.vfsfd, nil
}
// VFSPipeFD implements vfs.FileDescriptionImpl for pipes. It also implements
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
index 1abfe2201..cef58a590 100644
--- a/pkg/sentry/kernel/ptrace.go
+++ b/pkg/sentry/kernel/ptrace.go
@@ -259,8 +259,8 @@ func (t *Task) ptraceTrapLocked(code int32) {
Signo: int32(linux.SIGTRAP),
Code: code,
}
- t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t]))
- t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ t.ptraceSiginfo.SetPID(int32(t.tg.pidns.tids[t]))
+ t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
if t.beginPtraceStopLocked() {
tracer := t.Tracer()
tracer.signalStop(t, arch.CLD_TRAPPED, int32(linux.SIGTRAP))
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index b99c0bffa..db01e4a97 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -29,17 +29,17 @@ import (
)
const (
- valueMax = 32767 // SEMVMX
+ // Maximum semaphore value.
+ valueMax = linux.SEMVMX
- // semaphoresMax is "maximum number of semaphores per semaphore ID" (SEMMSL).
- semaphoresMax = 32000
+ // Maximum number of semaphore sets.
+ setsMax = linux.SEMMNI
- // setMax is "system-wide limit on the number of semaphore sets" (SEMMNI).
- setsMax = 32000
+ // Maximum number of semaphores in a semaphore set.
+ semsMax = linux.SEMMSL
- // semaphoresTotalMax is "system-wide limit on the number of semaphores"
- // (SEMMNS = SEMMNI*SEMMSL).
- semaphoresTotalMax = 1024000000
+ // Maximum number of semaphores in all semaphore sets.
+ semsTotalMax = linux.SEMMNS
)
// Registry maintains a set of semaphores that can be found by key or ID.
@@ -52,6 +52,9 @@ type Registry struct {
mu sync.Mutex `state:"nosave"`
semaphores map[int32]*Set
lastIDUsed int32
+ // indexes maintains a mapping between a set's index in virtual array and
+ // its identifier.
+ indexes map[int32]int32
}
// Set represents a set of semaphores that can be operated atomically.
@@ -113,6 +116,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry {
return &Registry{
userNS: userNS,
semaphores: make(map[int32]*Set),
+ indexes: make(map[int32]int32),
}
}
@@ -122,7 +126,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry {
// be found. If exclusive is true, it fails if a set with the same key already
// exists.
func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linux.FileMode, private, create, exclusive bool) (*Set, error) {
- if nsems < 0 || nsems > semaphoresMax {
+ if nsems < 0 || nsems > semsMax {
return nil, syserror.EINVAL
}
@@ -163,10 +167,13 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu
}
// Apply system limits.
+ //
+ // Map semaphores and map indexes in a registry are of the same size,
+ // check map semaphores only here for the system limit.
if len(r.semaphores) >= setsMax {
return nil, syserror.EINVAL
}
- if r.totalSems() > int(semaphoresTotalMax-nsems) {
+ if r.totalSems() > int(semsTotalMax-nsems) {
return nil, syserror.EINVAL
}
@@ -176,6 +183,53 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu
return r.newSet(ctx, key, owner, owner, perms, nsems)
}
+// IPCInfo returns information about system-wide semaphore limits and parameters.
+func (r *Registry) IPCInfo() *linux.SemInfo {
+ return &linux.SemInfo{
+ SemMap: linux.SEMMAP,
+ SemMni: linux.SEMMNI,
+ SemMns: linux.SEMMNS,
+ SemMnu: linux.SEMMNU,
+ SemMsl: linux.SEMMSL,
+ SemOpm: linux.SEMOPM,
+ SemUme: linux.SEMUME,
+ SemUsz: linux.SEMUSZ,
+ SemVmx: linux.SEMVMX,
+ SemAem: linux.SEMAEM,
+ }
+}
+
+// SemInfo returns a seminfo structure containing the same information as
+// for IPC_INFO, except that SemUsz field returns the number of existing
+// semaphore sets, and SemAem field returns the number of existing semaphores.
+func (r *Registry) SemInfo() *linux.SemInfo {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ info := r.IPCInfo()
+ info.SemUsz = uint32(len(r.semaphores))
+ info.SemAem = uint32(r.totalSems())
+
+ return info
+}
+
+// HighestIndex returns the index of the highest used entry in
+// the kernel's array.
+func (r *Registry) HighestIndex() int32 {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ // By default, highest used index is 0 even though
+ // there is no semaphore set.
+ var highestIndex int32
+ for index := range r.indexes {
+ if index > highestIndex {
+ highestIndex = index
+ }
+ }
+ return highestIndex
+}
+
// RemoveID removes set with give 'id' from the registry and marks the set as
// dead. All waiters will be awakened and fail.
func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error {
@@ -186,6 +240,11 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error {
if set == nil {
return syserror.EINVAL
}
+ index, found := r.findIndexByID(id)
+ if !found {
+ // Inconsistent state.
+ panic(fmt.Sprintf("unable to find an index for ID: %d", id))
+ }
set.mu.Lock()
defer set.mu.Unlock()
@@ -197,6 +256,7 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error {
}
delete(r.semaphores, set.ID)
+ delete(r.indexes, index)
set.destroy()
return nil
}
@@ -220,6 +280,11 @@ func (r *Registry) newSet(ctx context.Context, key int32, owner, creator fs.File
continue
}
if r.semaphores[id] == nil {
+ index, found := r.findFirstAvailableIndex()
+ if !found {
+ panic("unable to find an available index")
+ }
+ r.indexes[index] = id
r.lastIDUsed = id
r.semaphores[id] = set
set.ID = id
@@ -238,6 +303,18 @@ func (r *Registry) FindByID(id int32) *Set {
return r.semaphores[id]
}
+// FindByIndex looks up a set given an index.
+func (r *Registry) FindByIndex(index int32) *Set {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ id, present := r.indexes[index]
+ if !present {
+ return nil
+ }
+ return r.semaphores[id]
+}
+
func (r *Registry) findByKey(key int32) *Set {
for _, v := range r.semaphores {
if v.key == key {
@@ -247,6 +324,24 @@ func (r *Registry) findByKey(key int32) *Set {
return nil
}
+func (r *Registry) findIndexByID(id int32) (int32, bool) {
+ for k, v := range r.indexes {
+ if v == id {
+ return k, true
+ }
+ }
+ return 0, false
+}
+
+func (r *Registry) findFirstAvailableIndex() (int32, bool) {
+ for index := int32(0); index < setsMax; index++ {
+ if _, present := r.indexes[index]; !present {
+ return index, true
+ }
+ }
+ return 0, false
+}
+
func (r *Registry) totalSems() int {
totalSems := 0
for _, v := range r.semaphores {
diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD
index 80a592c8f..073e14507 100644
--- a/pkg/sentry/kernel/shm/BUILD
+++ b/pkg/sentry/kernel/shm/BUILD
@@ -6,6 +6,9 @@ package(licenses = ["notice"])
go_template_instance(
name = "shm_refs",
out = "shm_refs.go",
+ consts = {
+ "enableLogging": "true",
+ },
package = "shm",
prefix = "Shm",
template = "//pkg/refsvfs2:refs_template",
diff --git a/pkg/sentry/kernel/signal.go b/pkg/sentry/kernel/signal.go
index e8cce37d0..2488ae7d5 100644
--- a/pkg/sentry/kernel/signal.go
+++ b/pkg/sentry/kernel/signal.go
@@ -73,7 +73,7 @@ func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *arch.SignalInfo
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- info.SetPid(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg)))
- info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg)))
+ info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
return info
}
diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go
index 78f718cfe..884966120 100644
--- a/pkg/sentry/kernel/signalfd/signalfd.go
+++ b/pkg/sentry/kernel/signalfd/signalfd.go
@@ -106,8 +106,8 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
Signo: uint32(info.Signo),
Errno: info.Errno,
Code: info.Code,
- PID: uint32(info.Pid()),
- UID: uint32(info.Uid()),
+ PID: uint32(info.PID()),
+ UID: uint32(info.UID()),
Status: info.Status(),
Overrun: uint32(info.Overrun()),
Addr: info.Addr(),
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index c5137c282..16986244c 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -368,8 +368,8 @@ func (t *Task) exitChildren() {
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- siginfo.SetPid(int32(c.tg.pidns.tids[t]))
- siginfo.SetUid(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow()))
+ siginfo.SetPID(int32(c.tg.pidns.tids[t]))
+ siginfo.SetUID(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow()))
c.tg.signalHandlers.mu.Lock()
c.sendSignalLocked(siginfo, true /* group */)
c.tg.signalHandlers.mu.Unlock()
@@ -698,8 +698,8 @@ func (t *Task) exitNotificationSignal(sig linux.Signal, receiver *Task) *arch.Si
info := &arch.SignalInfo{
Signo: int32(sig),
}
- info.SetPid(int32(receiver.tg.pidns.tids[t]))
- info.SetUid(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(receiver.tg.pidns.tids[t]))
+ info.SetUID(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
if t.exitStatus.Signaled() {
info.Code = arch.CLD_KILLED
info.SetStatus(int32(t.exitStatus.Signo))
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index 42dd3e278..75af3af79 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -914,8 +914,8 @@ func (t *Task) signalStop(target *Task, code int32, status int32) {
Signo: int32(linux.SIGCHLD),
Code: code,
}
- sigchld.SetPid(int32(t.tg.pidns.tids[target]))
- sigchld.SetUid(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ sigchld.SetPID(int32(t.tg.pidns.tids[target]))
+ sigchld.SetUID(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
sigchld.SetStatus(status)
// TODO(b/72102453): Set utime, stime.
t.sendSignalLocked(sigchld, true /* group */)
@@ -1022,8 +1022,8 @@ func (*runInterrupt) execute(t *Task) taskRunState {
Signo: int32(sig),
Code: t.ptraceCode,
}
- t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t]))
- t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ t.ptraceSiginfo.SetPID(int32(t.tg.pidns.tids[t]))
+ t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
} else {
t.ptraceCode = int32(sig)
t.ptraceSiginfo = nil
@@ -1114,11 +1114,11 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState {
if parent == nil {
// Tracer has detached and t was created by Kernel.CreateProcess().
// Pretend the parent is in an ancestor PID + user namespace.
- info.SetPid(0)
- info.SetUid(int32(auth.OverflowUID))
+ info.SetPID(0)
+ info.SetUID(int32(auth.OverflowUID))
} else {
- info.SetPid(int32(t.tg.pidns.tids[parent]))
- info.SetUid(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(t.tg.pidns.tids[parent]))
+ info.SetUID(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
}
}
t.tg.signalHandlers.mu.Lock()
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
index 7fd77925f..49e21026e 100644
--- a/pkg/sentry/memmap/memmap.go
+++ b/pkg/sentry/memmap/memmap.go
@@ -160,7 +160,7 @@ func CheckTranslateResult(required, optional MappableRange, at usermem.AccessTyp
// Translations must be contiguous and in increasing order of
// Translation.Source.
if i > 0 && ts[i-1].Source.End != t.Source.Start {
- return fmt.Errorf("Translations %+v and %+v are not contiguous", ts[i-1], t)
+ return fmt.Errorf("Translation %+v and Translation %+v are not contiguous", ts[i-1], t)
}
// At least part of each Translation must be required.
if t.Source.Intersect(required).Length() == 0 {
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
index 4c8cd38ed..5ab2ef79f 100644
--- a/pkg/sentry/mm/aio_context.go
+++ b/pkg/sentry/mm/aio_context.go
@@ -36,12 +36,12 @@ type aioManager struct {
contexts map[uint64]*AIOContext
}
-func (a *aioManager) destroy() {
- a.mu.Lock()
- defer a.mu.Unlock()
+func (mm *MemoryManager) destroyAIOManager(ctx context.Context) {
+ mm.aioManager.mu.Lock()
+ defer mm.aioManager.mu.Unlock()
- for _, ctx := range a.contexts {
- ctx.destroy()
+ for id := range mm.aioManager.contexts {
+ mm.destroyAIOContextLocked(ctx, id)
}
}
@@ -68,16 +68,26 @@ func (a *aioManager) newAIOContext(events uint32, id uint64) bool {
// be drained.
//
// Nil is returned if the context does not exist.
-func (a *aioManager) destroyAIOContext(id uint64) *AIOContext {
- a.mu.Lock()
- defer a.mu.Unlock()
- ctx, ok := a.contexts[id]
+//
+// Precondition: mm.aioManager.mu is locked.
+func (mm *MemoryManager) destroyAIOContextLocked(ctx context.Context, id uint64) *AIOContext {
+ aioCtx, ok := mm.aioManager.contexts[id]
if !ok {
return nil
}
- delete(a.contexts, id)
- ctx.destroy()
- return ctx
+
+ // Only unmap after it is assured that the address is a valid aio context,
+ // to prevent random memory from being unmapped.
+ //
+ // Note: It's possible to unmap this address and map something else into
+ // the same address. Then it would be unmapping memory that it doesn't own.
+ // This is, however, the way Linux implements AIO. Keeps the same [weird]
+ // semantics in case anyone relies on it.
+ mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize)
+
+ delete(mm.aioManager.contexts, id)
+ aioCtx.destroy()
+ return aioCtx
}
// lookupAIOContext looks up the given context.
@@ -140,16 +150,21 @@ func (ctx *AIOContext) checkForDone() {
}
}
-// Prepare reserves space for a new request, returning true if available.
-// Returns false if the context is busy.
-func (ctx *AIOContext) Prepare() bool {
+// Prepare reserves space for a new request, returning nil if available.
+// Returns EAGAIN if the context is busy and EINVAL if the context is dead.
+func (ctx *AIOContext) Prepare() error {
ctx.mu.Lock()
defer ctx.mu.Unlock()
+ if ctx.dead {
+ // Context died after the caller looked it up.
+ return syserror.EINVAL
+ }
if ctx.outstanding >= ctx.maxOutstanding {
- return false
+ // Context is busy.
+ return syserror.EAGAIN
}
ctx.outstanding++
- return true
+ return nil
}
// PopRequest pops a completed request if available, this function does not do
@@ -391,20 +406,13 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint
// DestroyAIOContext destroys an asynchronous I/O context. It returns the
// destroyed context. nil if the context does not exist.
func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) *AIOContext {
- if _, ok := mm.LookupAIOContext(ctx, id); !ok {
+ if !mm.isValidAddr(ctx, id) {
return nil
}
- // Only unmaps after it assured that the address is a valid aio context to
- // prevent random memory from been unmapped.
- //
- // Note: It's possible to unmap this address and map something else into
- // the same address. Then it would be unmapping memory that it doesn't own.
- // This is, however, the way Linux implements AIO. Keeps the same [weird]
- // semantics in case anyone relies on it.
- mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize)
-
- return mm.aioManager.destroyAIOContext(id)
+ mm.aioManager.mu.Lock()
+ defer mm.aioManager.mu.Unlock()
+ return mm.destroyAIOContextLocked(ctx, id)
}
// LookupAIOContext looks up the given context. It returns false if the context
@@ -415,13 +423,18 @@ func (mm *MemoryManager) LookupAIOContext(ctx context.Context, id uint64) (*AIOC
return nil, false
}
- // Protect against 'ids' that are inaccessible (Linux also reads 4 bytes
- // from id).
- var buf [4]byte
- _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{})
- if err != nil {
+ // Protect against 'id' that is inaccessible.
+ if !mm.isValidAddr(ctx, id) {
return nil, false
}
return aioCtx, true
}
+
+// isValidAddr determines if the address `id` is valid. (Linux also reads 4
+// bytes from id).
+func (mm *MemoryManager) isValidAddr(ctx context.Context, id uint64) bool {
+ var buf [4]byte
+ _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{})
+ return err == nil
+}
diff --git a/pkg/sentry/mm/aio_context_state.go b/pkg/sentry/mm/aio_context_state.go
index 3dabac1af..e8931922f 100644
--- a/pkg/sentry/mm/aio_context_state.go
+++ b/pkg/sentry/mm/aio_context_state.go
@@ -15,6 +15,6 @@
package mm
// afterLoad is invoked by stateify.
-func (a *AIOContext) afterLoad() {
- a.requestReady = make(chan struct{}, 1)
+func (ctx *AIOContext) afterLoad() {
+ ctx.requestReady = make(chan struct{}, 1)
}
diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go
index 09dbc06a4..120707429 100644
--- a/pkg/sentry/mm/lifecycle.go
+++ b/pkg/sentry/mm/lifecycle.go
@@ -253,7 +253,7 @@ func (mm *MemoryManager) DecUsers(ctx context.Context) {
panic(fmt.Sprintf("Invalid MemoryManager.users: %d", users))
}
- mm.aioManager.destroy()
+ mm.destroyAIOManager(ctx)
mm.metadataMu.Lock()
exe := mm.executable
diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go
index acac3d357..bc53bd41e 100644
--- a/pkg/sentry/mm/mm_test.go
+++ b/pkg/sentry/mm/mm_test.go
@@ -229,3 +229,46 @@ func TestIOAfterMProtect(t *testing.T) {
t.Errorf("CopyOut got %d want 1", n)
}
}
+
+// TestAIOPrepareAfterDestroy tests that AIOContext should not be able to be
+// prepared after destruction.
+func TestAIOPrepareAfterDestroy(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mm := testMemoryManager(ctx)
+ defer mm.DecUsers(ctx)
+
+ id, err := mm.NewAIOContext(ctx, 1)
+ if err != nil {
+ t.Fatalf("mm.NewAIOContext got err %v want nil", err)
+ }
+ aioCtx, ok := mm.LookupAIOContext(ctx, id)
+ if !ok {
+ t.Fatalf("AIOContext not found")
+ }
+ mm.DestroyAIOContext(ctx, id)
+
+ // Prepare should fail because aioCtx should be destroyed.
+ if err := aioCtx.Prepare(); err != syserror.EINVAL {
+ t.Errorf("aioCtx.Prepare got err %v want nil", err)
+ } else if err == nil {
+ aioCtx.CancelPendingRequest()
+ }
+}
+
+// TestAIOLookupAfterDestroy tests that AIOContext should not be able to be
+// looked up after memory manager is destroyed.
+func TestAIOLookupAfterDestroy(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mm := testMemoryManager(ctx)
+
+ id, err := mm.NewAIOContext(ctx, 1)
+ if err != nil {
+ mm.DecUsers(ctx)
+ t.Fatalf("mm.NewAIOContext got err %v want nil", err)
+ }
+ mm.DecUsers(ctx) // This destroys the AIOContext manager.
+
+ if _, ok := mm.LookupAIOContext(ctx, id); ok {
+ t.Errorf("AIOContext found even after AIOContext manager is destroyed")
+ }
+}
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index 7c297fb9e..d99be7f46 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -423,11 +423,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.File
}
if f.opts.ManualZeroing {
- if err := f.forEachMappingSlice(fr, func(bs []byte) {
- for i := range bs {
- bs[i] = 0
- }
- }); err != nil {
+ if err := f.manuallyZero(fr); err != nil {
return memmap.FileRange{}, err
}
}
@@ -560,19 +556,39 @@ func (f *MemoryFile) Decommit(fr memmap.FileRange) error {
panic(fmt.Sprintf("invalid range: %v", fr))
}
+ if f.opts.ManualZeroing {
+ // FALLOC_FL_PUNCH_HOLE may not zero pages if ManualZeroing is in
+ // effect.
+ if err := f.manuallyZero(fr); err != nil {
+ return err
+ }
+ } else {
+ if err := f.decommitFile(fr); err != nil {
+ return err
+ }
+ }
+
+ f.markDecommitted(fr)
+ return nil
+}
+
+func (f *MemoryFile) manuallyZero(fr memmap.FileRange) error {
+ return f.forEachMappingSlice(fr, func(bs []byte) {
+ for i := range bs {
+ bs[i] = 0
+ }
+ })
+}
+
+func (f *MemoryFile) decommitFile(fr memmap.FileRange) error {
// "After a successful call, subsequent reads from this range will
// return zeroes. The FALLOC_FL_PUNCH_HOLE flag must be ORed with
// FALLOC_FL_KEEP_SIZE in mode ..." - fallocate(2)
- err := syscall.Fallocate(
+ return syscall.Fallocate(
int(f.file.Fd()),
_FALLOC_FL_PUNCH_HOLE|_FALLOC_FL_KEEP_SIZE,
int64(fr.Start),
int64(fr.Length()))
- if err != nil {
- return err
- }
- f.markDecommitted(fr)
- return nil
}
func (f *MemoryFile) markDecommitted(fr memmap.FileRange) {
@@ -1044,20 +1060,20 @@ func (f *MemoryFile) runReclaim() {
break
}
- if err := f.Decommit(fr); err != nil {
- log.Warningf("Reclaim failed to decommit %v: %v", fr, err)
- // Zero the pages manually. This won't reduce memory usage, but at
- // least ensures that the pages will be zero when reallocated.
- f.forEachMappingSlice(fr, func(bs []byte) {
- for i := range bs {
- bs[i] = 0
+ // If ManualZeroing is in effect, pages will be zeroed on allocation
+ // and may not be freed by decommitFile, so calling decommitFile is
+ // unnecessary.
+ if !f.opts.ManualZeroing {
+ if err := f.decommitFile(fr); err != nil {
+ log.Warningf("Reclaim failed to decommit %v: %v", fr, err)
+ // Zero the pages manually. This won't reduce memory usage, but at
+ // least ensures that the pages will be zero when reallocated.
+ if err := f.manuallyZero(fr); err != nil {
+ panic(fmt.Sprintf("Reclaim failed to decommit or zero %v: %v", fr, err))
}
- })
- // Pretend the pages were decommitted even though they weren't,
- // since the memory accounting implementation has no idea how to
- // deal with this.
- f.markDecommitted(fr)
+ }
}
+ f.markDecommitted(fr)
f.markReclaimed(fr)
}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
index acad4c793..f8ccb7430 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -91,6 +91,13 @@ func bluepillSigBus(c *vCPU) {
}
}
+// bluepillHandleEnosys is responsible for handling enosys error.
+//
+//go:nosplit
+func bluepillHandleEnosys(c *vCPU) {
+ throw("run failed: ENOSYS")
+}
+
// bluepillReadyStopGuest checks whether the current vCPU is ready for interrupt injection.
//
//go:nosplit
@@ -126,3 +133,10 @@ func bluepillReadyStopGuest(c *vCPU) bool {
}
return true
}
+
+// bluepillArchHandleExit checks architecture specific exitcode.
+//
+//go:nosplit
+func bluepillArchHandleExit(c *vCPU, context unsafe.Pointer) {
+ c.die(bluepillArchContext(context), "unknown")
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
index 965ad66b5..1f09813ba 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -42,6 +42,13 @@ var (
sErrEsr: _ESR_ELx_SERR_NMI,
},
}
+
+ // vcpuExtDabt is the event of ext_dabt.
+ vcpuExtDabt = kvmVcpuEvents{
+ exception: exception{
+ extDabtPending: 1,
+ },
+ }
)
// getTLS returns the value of TPIDR_EL0 register.
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
index 9433d4da5..4d912769a 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
@@ -85,7 +85,7 @@ func bluepillStopGuest(c *vCPU) {
uintptr(c.fd),
_KVM_SET_VCPU_EVENTS,
uintptr(unsafe.Pointer(&vcpuSErrBounce))); errno != 0 {
- throw("sErr injection failed")
+ throw("bounce sErr injection failed")
}
}
@@ -93,18 +93,54 @@ func bluepillStopGuest(c *vCPU) {
//
//go:nosplit
func bluepillSigBus(c *vCPU) {
+ // Host must support ARM64_HAS_RAS_EXTN.
if _, _, errno := syscall.RawSyscall( // escapes: no.
syscall.SYS_IOCTL,
uintptr(c.fd),
_KVM_SET_VCPU_EVENTS,
uintptr(unsafe.Pointer(&vcpuSErrNMI))); errno != 0 {
- throw("sErr injection failed")
+ if errno == syscall.EINVAL {
+ throw("No ARM64_HAS_RAS_EXTN feature in host.")
+ }
+ throw("nmi sErr injection failed")
}
}
+// bluepillExtDabt is responsible for injecting an external data abort.
+//
+//go:nosplit
+func bluepillExtDabt(c *vCPU) {
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_VCPU_EVENTS,
+ uintptr(unsafe.Pointer(&vcpuExtDabt))); errno != 0 {
+ throw("ext_dabt injection failed")
+ }
+}
+
+// bluepillHandleEnosys is responsible for handling ENOSYS errors.
+//
+//go:nosplit
+func bluepillHandleEnosys(c *vCPU) {
+ bluepillExtDabt(c)
+}
+
// bluepillReadyStopGuest checks whether the current vCPU is ready for sError injection.
//
//go:nosplit
func bluepillReadyStopGuest(c *vCPU) bool {
return true
}
+
+// bluepillArchHandleExit handles architecture-specific exit reasons.
+//
+//go:nosplit
+func bluepillArchHandleExit(c *vCPU, context unsafe.Pointer) {
+ switch c.runData.exitReason {
+ case _KVM_EXIT_ARM_NISV:
+ bluepillExtDabt(c)
+ default:
+ c.die(bluepillArchContext(context), "unknown")
+ }
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index 75085ac6a..8c5369377 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -148,6 +148,9 @@ func bluepillHandler(context unsafe.Pointer) {
// mode and have interrupts disabled.
bluepillSigBus(c)
continue // Rerun vCPU.
+ case syscall.ENOSYS:
+ bluepillHandleEnosys(c)
+ continue
default:
throw("run failed")
}
@@ -220,7 +223,7 @@ func bluepillHandler(context unsafe.Pointer) {
c.die(bluepillArchContext(context), "entry failed")
return
default:
- c.die(bluepillArchContext(context), "unknown")
+ bluepillArchHandleExit(c, context)
return
}
}
diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go
index 0b06a923a..9db1db4e9 100644
--- a/pkg/sentry/platform/kvm/kvm_arm64.go
+++ b/pkg/sentry/platform/kvm/kvm_arm64.go
@@ -47,10 +47,11 @@ type userRegs struct {
}
type exception struct {
- sErrPending uint8
- sErrHasEsr uint8
- pad [6]uint8
- sErrEsr uint64
+ sErrPending uint8
+ sErrHasEsr uint8
+ extDabtPending uint8
+ pad [5]uint8
+ sErrEsr uint64
}
type kvmVcpuEvents struct {
diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go
index 6abaa21c4..2492d57be 100644
--- a/pkg/sentry/platform/kvm/kvm_const.go
+++ b/pkg/sentry/platform/kvm/kvm_const.go
@@ -56,6 +56,7 @@ const (
_KVM_EXIT_FAIL_ENTRY = 0x9
_KVM_EXIT_INTERNAL_ERROR = 0x11
_KVM_EXIT_SYSTEM_EVENT = 0x18
+ _KVM_EXIT_ARM_NISV = 0x1c
)
// KVM capability options.
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 54837f20c..aa2d21748 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -54,7 +54,7 @@ func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
pageTable.Map(
usermem.Addr(ring0.KernelStartAddress|pr.virtual),
pr.length,
- pagetables.MapOpts{AccessType: usermem.AnyAccess},
+ pagetables.MapOpts{AccessType: usermem.AnyAccess, Global: true},
pr.physical)
return true // Keep iterating.
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index f2459755b..a466acf4d 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -79,7 +79,7 @@ func (c *vCPU) initArchState() error {
}
// tcr_el1
- data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS | _TCR_A1
+ data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS
reg.id = _KVM_ARM64_REGS_TCR_EL1
if err := c.setOneRegister(&reg); err != nil {
return err
@@ -103,7 +103,7 @@ func (c *vCPU) initArchState() error {
c.SetTtbr0Kvm(uintptr(data))
// ttbr1_el1
- data = c.machine.kernel.PageTables.TTBR1_EL1(false, 1)
+ data = c.machine.kernel.PageTables.TTBR1_EL1(false, 0)
reg.id = _KVM_ARM64_REGS_TTBR1_EL1
if err := c.setOneRegister(&reg); err != nil {
diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go
index f56aa3b79..571bfcc2e 100644
--- a/pkg/sentry/platform/ptrace/ptrace.go
+++ b/pkg/sentry/platform/ptrace/ptrace.go
@@ -18,8 +18,8 @@
//
// In a nutshell, it works as follows:
//
-// The creation of a new address space creates a new child processes with a
-// single thread which is traced by a single goroutine.
+// The creation of a new address space creates a new child process with a single
+// thread which is traced by a single goroutine.
//
// A context is just a collection of temporary variables. Calling Switch on a
// context does the following:
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 812ab80ef..aacd7ce70 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -590,7 +590,7 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
// facilitate vsyscall emulation. See patchSignalInfo.
patchSignalInfo(regs, &c.signalInfo)
return false
- } else if c.signalInfo.Code <= 0 && c.signalInfo.Pid() == int32(os.Getpid()) {
+ } else if c.signalInfo.Code <= 0 && c.signalInfo.PID() == int32(os.Getpid()) {
// The signal was generated by this process. That means
// that it was an interrupt or something else that we
// should bail for. Note that we ignore signals
diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD
index 679b287c3..2852b7387 100644
--- a/pkg/sentry/platform/ring0/BUILD
+++ b/pkg/sentry/platform/ring0/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "arch_genrule", "go_library")
load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
@@ -39,19 +39,19 @@ go_template_instance(
template = ":defs_arm64",
)
-genrule(
+arch_genrule(
name = "entry_impl_amd64",
srcs = ["entry_amd64.s"],
outs = ["entry_impl_amd64.s"],
- cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@",
+ cmd = "(echo -e '// build +amd64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_amd64.s)) > $@",
tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
)
-genrule(
+arch_genrule(
name = "entry_impl_arm64",
srcs = ["entry_arm64.s"],
outs = ["entry_impl_arm64.s"],
- cmd = "(echo -e '// build +arm64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@",
+ cmd = "(echo -e '// build +arm64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_arm64.s)) > $@",
tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
)
@@ -72,7 +72,6 @@ go_library(
"lib_amd64.s",
"lib_arm64.go",
"lib_arm64.s",
- "lib_arm64_unsafe.go",
"ring0.go",
],
visibility = ["//pkg/sentry:internal"],
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index 155f45ad8..b2bb18257 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -132,40 +132,6 @@
MOVD offset+PTRACE_R29(reg), R29; \
MOVD offset+PTRACE_R30(reg), R30;
-// NOP-s
-#define nop31Instructions() \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f;
-
#define ESR_ELx_EC_UNKNOWN (0x00)
#define ESR_ELx_EC_WFx (0x01)
/* Unallocated EC: 0x02 */
@@ -305,24 +271,20 @@
WORD $0xd538d092; //MRS TPIDR_EL1, R18
// SWITCH_TO_APP_PAGETABLE sets a new pagetable for a container application.
-#define SWITCH_TO_APP_PAGETABLE(from) \
- MRS TTBR1_EL1, R0; \
- MOVD CPU_APP_ASID(from), R1; \
- BFI $48, R1, $16, R0; \
- MSR R0, TTBR1_EL1; \ // set the ASID in TTBR1_EL1 (since TCR.A1 is set)
- ISB $15; \
- MOVD CPU_TTBR0_APP(from), RSV_REG; \
- MSR RSV_REG, TTBR0_EL1;
+#define SWITCH_TO_APP_PAGETABLE() \
+ MOVD CPU_APP_ASID(RSV_REG), RSV_REG_APP; \
+ MOVD CPU_TTBR0_APP(RSV_REG), RSV_REG; \
+ BFI $48, RSV_REG_APP, $16, RSV_REG; \
+ MSR RSV_REG, TTBR0_EL1; \
+ ISB $15;
// SWITCH_TO_KVM_PAGETABLE sets the kvm pagetable.
-#define SWITCH_TO_KVM_PAGETABLE(from) \
- MRS TTBR1_EL1, R0; \
- MOVD $1, R1; \
- BFI $48, R1, $16, R0; \
- MSR R0, TTBR1_EL1; \
- ISB $15; \
- MOVD CPU_TTBR0_KVM(from), RSV_REG; \
- MSR RSV_REG, TTBR0_EL1;
+#define SWITCH_TO_KVM_PAGETABLE() \
+ MOVD CPU_TTBR0_KVM(RSV_REG), RSV_REG; \
+ MOVD $1, RSV_REG_APP; \
+ BFI $48, RSV_REG_APP, $16, RSV_REG; \
+ MSR RSV_REG, TTBR0_EL1; \
+ ISB $15;
TEXT ·EnableVFP(SB),NOSPLIT,$0
MOVD $FPEN_ENABLE, R0
@@ -530,7 +492,7 @@ do_exit_to_el0:
WORD $0xd538d092 //MRS TPIDR_EL1, R18
- SWITCH_TO_APP_PAGETABLE(RSV_REG)
+ SWITCH_TO_APP_PAGETABLE()
LDP 16*1(RSP), (R0, R1)
LDP 16*0(RSP), (RSV_REG, RSV_REG_APP)
@@ -555,10 +517,10 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R1
MOVD R1, RSP
- SWITCH_TO_KVM_PAGETABLE(RSV_REG)
+ REGISTERS_LOAD(RSV_REG, CPU_REGISTERS)
+ SWITCH_TO_KVM_PAGETABLE()
MRS TPIDR_EL1, RSV_REG
- REGISTERS_LOAD(RSV_REG, CPU_REGISTERS)
MOVD CPU_REGISTERS+PTRACE_R9(RSV_REG), RSV_REG_APP
ERET()
@@ -566,8 +528,16 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
// Start is the CPU entrypoint.
TEXT ·Start(SB),NOSPLIT,$0
// Init.
- MOVD $SCTLR_EL1_DEFAULT, R1
- MSR R1, SCTLR_EL1
+ WORD $0xd508871f // __tlbi(vmalle1)
+ DSB $7 // dsb(nsh)
+
+ MOVD $1<<12, R1 // Reset mdscr_el1 and disable
+ MSR R1, MDSCR_EL1 // access to the DCC from EL0
+ ISB $15
+
+ MRS TTBR1_EL1, R1
+ MSR R1, TTBR0_EL1
+ ISB $15
MOVD $CNTKCTL_EL1_DEFAULT, R1
MSR R1, CNTKCTL_EL1
@@ -576,6 +546,15 @@ TEXT ·Start(SB),NOSPLIT,$0
ORR $0xffff000000000000, RSV_REG, RSV_REG
WORD $0xd518d092 //MSR R18, TPIDR_EL1
+ // Init.
+ MOVD $SCTLR_EL1_DEFAULT, R1 // re-enable the mmu.
+ MSR R1, SCTLR_EL1
+ ISB $15
+ WORD $0xd508751f // ic iallu
+
+ DSB $7 // dsb(nsh)
+ ISB $15
+
B ·kernelExitToEl1(SB)
// El1_sync_invalid is the handler for an invalid EL1_sync.
@@ -748,79 +727,43 @@ TEXT ·El0_error_invalid(SB),NOSPLIT,$0
B ·Shutdown(SB)
// Vectors implements exception vector table.
+// The start address of the exception vector table must be 2048-byte aligned (its low 11 bits are zero).
+// For details, refer to the Arm developer documentation:
+// https://developer.arm.com/documentation/100933/0100/AArch64-exception-vector-table
+// See also the Linux kernel source: arch/arm64/kernel/entry.S
TEXT ·Vectors(SB),NOSPLIT,$0
+ PCALIGN $2048
B ·El1_sync_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_irq_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_fiq_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_error_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_sync(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_irq(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_fiq(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_error(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_sync(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_irq(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_fiq(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_error(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_sync_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_irq_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_fiq_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_error_invalid(SB)
- nop31Instructions()
-
- // The exception-vector-table is required to be 11-bits aligned.
- // Please see Linux source code as reference: arch/arm64/kernel/entry.s.
- // For gvisor, I defined it as 4K in length, filled the 2nd 2K part with NOPs.
- // So that, I can safely move the 1st 2K part into the address with 11-bits alignment.
- WORD $0xd503201f //nop
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
-
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
-
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
-
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD
index 9742308d8..a9703baf6 100644
--- a/pkg/sentry/platform/ring0/gen_offsets/BUILD
+++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD
@@ -24,6 +24,9 @@ go_binary(
"defs_impl_arm64.go",
"main.go",
],
+ # Use the libc malloc to avoid any extra dependencies. This is required to
+ # pass the sentry deps test.
+ system_malloc = True,
visibility = [
"//pkg/sentry/platform/kvm:__pkg__",
"//pkg/sentry/platform/ring0:__pkg__",
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
index 90a7b8392..c05284641 100644
--- a/pkg/sentry/platform/ring0/kernel_arm64.go
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -53,11 +53,17 @@ func IsCanonical(addr uint64) bool {
return addr <= 0x0000ffffffffffff || addr > 0xffff000000000000
}
+// SwitchToUser performs an eret.
+//
+// The return value is the exception vector.
+//
+// +checkescape:all
+//
//go:nosplit
func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
storeAppASID(uintptr(switchOpts.UserASID))
if switchOpts.Flush {
- FlushTlbAll()
+ FlushTlbByASID(uintptr(switchOpts.UserASID))
}
regs := switchOpts.Registers
diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go
index 0dffd33a3..a490bf3af 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.go
+++ b/pkg/sentry/platform/ring0/lib_arm64.go
@@ -22,19 +22,25 @@ func storeAppASID(asid uintptr)
// LocalFlushTlbAll same as FlushTlbAll, but only applies to the calling CPU.
func LocalFlushTlbAll()
-// FlushTlbAll flush all tlb.
+// FlushTlbByVA invalidates tlb by VA/Last-level/Inner-Shareable.
+func FlushTlbByVA(addr uintptr)
+
+// FlushTlbByASID invalidates tlb by ASID/Inner-Shareable.
+func FlushTlbByASID(asid uintptr)
+
+// FlushTlbAll invalidates all tlb.
func FlushTlbAll()
// CPACREL1 returns the value of the CPACR_EL1 register.
func CPACREL1() (value uintptr)
-// FPCR returns the value of FPCR register.
+// GetFPCR returns the value of FPCR register.
func GetFPCR() (value uintptr)
// SetFPCR writes the FPCR value.
func SetFPCR(value uintptr)
-// FPSR returns the value of FPSR register.
+// GetFPSR returns the value of FPSR register.
func GetFPSR() (value uintptr)
// SetFPSR writes the FPSR value.
@@ -62,6 +68,4 @@ func DisableVFP()
// Init sets function pointers based on architectural features.
//
// This must be called prior to using ring0.
-func Init() {
- rewriteVectors()
-}
+func Init() {}
diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s
index 6f4923539..e39b32841 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.s
+++ b/pkg/sentry/platform/ring0/lib_arm64.s
@@ -15,6 +15,23 @@
#include "funcdata.h"
#include "textflag.h"
+#define TLBI_ASID_SHIFT 48
+
+TEXT ·FlushTlbByVA(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R1
+ DSB $10 // dsb(ishst)
+ WORD $0xd50883a1 // tlbi vale1is, x1
+ DSB $11 // dsb(ish)
+ RET
+
+TEXT ·FlushTlbByASID(SB),NOSPLIT,$0-8
+ MOVD asid+0(FP), R1
+ LSL $TLBI_ASID_SHIFT, R1, R1
+ DSB $10 // dsb(ishst)
+ WORD $0xd5088341 // tlbi aside1is, x1
+ DSB $11 // dsb(ish)
+ RET
+
TEXT ·LocalFlushTlbAll(SB),NOSPLIT,$0
DSB $6 // dsb(nshst)
WORD $0xd508871f // __tlbi(vmalle1)
diff --git a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go b/pkg/sentry/platform/ring0/lib_arm64_unsafe.go
deleted file mode 100644
index c05166fea..000000000
--- a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go
+++ /dev/null
@@ -1,108 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build arm64
-
-package ring0
-
-import (
- "reflect"
- "syscall"
- "unsafe"
-
- "gvisor.dev/gvisor/pkg/safecopy"
- "gvisor.dev/gvisor/pkg/usermem"
-)
-
-const (
- nopInstruction = 0xd503201f
- instSize = unsafe.Sizeof(uint32(0))
- vectorsRawLen = 0x800
-)
-
-func unsafeSlice(addr uintptr, length int) (slice []uint32) {
- hdr := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
- hdr.Data = addr
- hdr.Len = length / int(instSize)
- hdr.Cap = length / int(instSize)
- return slice
-}
-
-// Work around: move ring0.Vectors() into a specific address with 11-bits alignment.
-//
-// According to the design documentation of Arm64,
-// the start address of exception vector table should be 11-bits aligned.
-// Please see the code in linux kernel as reference: arch/arm64/kernel/entry.S
-// But, we can't align a function's start address to a specific address by using golang.
-// We have raised this question in golang community:
-// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I
-// This function will be removed when golang supports this feature.
-//
-// There are 2 jobs were implemented in this function:
-// 1, move the start address of exception vector table into the specific address.
-// 2, modify the offset of each instruction.
-func rewriteVectors() {
- vectorsBegin := reflect.ValueOf(Vectors).Pointer()
-
- // The exception-vector-table is required to be 11-bits aligned.
- // And the size is 0x800.
- // Please see the documentation as reference:
- // https://developer.arm.com/docs/100933/0100/aarch64-exception-vector-table
- //
- // But, golang does not allow to set a function's address to a specific value.
- // So, for gvisor, I defined the size of exception-vector-table as 4K,
- // filled the 2nd 2K part with NOP-s.
- // So that, I can safely move the 1st 2K part into the address with 11-bits alignment.
- //
- // So, the prerequisite for this function to work correctly is:
- // vectorsSafeLen >= 0x1000
- // vectorsRawLen = 0x800
- vectorsSafeLen := int(safecopy.FindEndAddress(vectorsBegin) - vectorsBegin)
- if vectorsSafeLen < 2*vectorsRawLen {
- panic("Can't update vectors")
- }
-
- vectorsSafeTable := unsafeSlice(vectorsBegin, vectorsSafeLen) // Now a []uint32
- vectorsRawLen32 := vectorsRawLen / int(instSize)
-
- offset := vectorsBegin & (1<<11 - 1)
- if offset != 0 {
- offset = 1<<11 - offset
- }
-
- pageBegin := (vectorsBegin + offset) & ^uintptr(usermem.PageSize-1)
-
- _, _, errno := syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC))
- if errno != 0 {
- panic(errno.Error())
- }
-
- offset = offset / instSize // By index, not bytes.
- // Move exception-vector-table into the specific address, should uses memmove here.
- for i := 1; i <= vectorsRawLen32; i++ {
- vectorsSafeTable[int(offset)+vectorsRawLen32-i] = vectorsSafeTable[vectorsRawLen32-i]
- }
-
- // Adjust branch since instruction was moved forward.
- for i := 0; i < vectorsRawLen32; i++ {
- if vectorsSafeTable[int(offset)+i] != nopInstruction {
- vectorsSafeTable[int(offset)+i] -= uint32(offset)
- }
- }
-
- _, _, errno = syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_EXEC))
- if errno != 0 {
- panic(errno.Error())
- }
-}
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index a3f775d15..cc1f6bfcc 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -20,6 +20,7 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/tcpip",
+ "//pkg/tcpip/header",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
index ca16d0381..ebcc891b3 100644
--- a/pkg/sentry/socket/control/BUILD
+++ b/pkg/sentry/socket/control/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -23,7 +23,20 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
"//pkg/syserror",
- "//pkg/tcpip",
"//pkg/usermem",
],
)
+
+go_test(
+ name = "control_test",
+ size = "small",
+ srcs = ["control_test.go"],
+ library = ":control",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/sentry/socket",
+ "//pkg/usermem",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 70ccf77a7..65b556489 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -344,18 +343,42 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
}
// PackIPPacketInfo packs an IP_PKTINFO socket control message.
-func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte {
- var p linux.ControlMessageIPPacketInfo
- p.NIC = int32(packetInfo.NIC)
- copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
- copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
-
+func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketInfo, buf []byte) []byte {
return putCmsgStruct(
buf,
linux.SOL_IP,
linux.IP_PKTINFO,
t.Arch().Width(),
- p,
+ packetInfo,
+ )
+}
+
+// PackOriginalDstAddress packs an IP_RECVORIGDSTADDR or IPV6_RECVORIGDSTADDR socket control message.
+func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte {
+ var level uint32
+ var optType uint32
+ switch originalDstAddress.(type) {
+ case *linux.SockAddrInet:
+ level = linux.SOL_IP
+ optType = linux.IP_RECVORIGDSTADDR
+ case *linux.SockAddrInet6:
+ level = linux.SOL_IPV6
+ optType = linux.IPV6_RECVORIGDSTADDR
+ default:
+ panic("invalid address type, must be an IP address for IP_RECVORIGINALDSTADDR cmsg")
+ }
+ return putCmsgStruct(
+ buf, level, optType, t.Arch().Width(), originalDstAddress)
+}
+
+// PackSockExtendedErr packs an IP*_RECVERR socket control message.
+func PackSockExtendedErr(t *kernel.Task, sockErr linux.SockErrCMsg, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ sockErr.CMsgLevel(),
+ sockErr.CMsgType(),
+ t.Arch().Width(),
+ sockErr,
)
}
@@ -384,7 +407,15 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
}
if cmsgs.IP.HasIPPacketInfo {
- buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf)
+ buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf)
+ }
+
+ if cmsgs.IP.OriginalDstAddress != nil {
+ buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf)
+ }
+
+ if cmsgs.IP.SockErr != nil {
+ buf = PackSockExtendedErr(t, cmsgs.IP.SockErr, buf)
}
return buf
@@ -416,21 +447,23 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
space += cmsgSpace(t, linux.SizeOfControlMessageTClass)
}
- return space
-}
+ if cmsgs.IP.HasIPPacketInfo {
+ space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo)
+ }
-// NewIPPacketInfo returns the IPPacketInfo struct.
-func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo {
- var p tcpip.IPPacketInfo
- p.NIC = tcpip.NICID(packetInfo.NIC)
- copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:])
- copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:])
+ if cmsgs.IP.OriginalDstAddress != nil {
+ space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes())
+ }
- return p
+ if cmsgs.IP.SockErr != nil {
+ space += cmsgSpace(t, cmsgs.IP.SockErr.SizeBytes())
+ }
+
+ return space
}
// Parse parses a raw socket control message into portable objects.
-func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) {
+func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) (socket.ControlMessages, error) {
var (
cmsgs socket.ControlMessages
fds linux.ControlMessageRights
@@ -454,10 +487,6 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
i += linux.SizeOfControlMessageHeader
length := int(h.Length) - linux.SizeOfControlMessageHeader
- // The use of t.Arch().Width() is analogous to Linux's use of
- // sizeof(long) in CMSG_ALIGN.
- width := t.Arch().Width()
-
switch h.Level {
case linux.SOL_SOCKET:
switch h.Type {
@@ -489,6 +518,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
cmsgs.Unix.Credentials = scmCreds
i += binary.AlignUp(length, width)
+ case linux.SO_TIMESTAMP:
+ if length < linux.SizeOfTimeval {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ var ts linux.Timeval
+ binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], usermem.ByteOrder, &ts)
+ cmsgs.IP.Timestamp = ts.ToNsecCapped()
+ cmsgs.IP.HasTimestamp = true
+ i += binary.AlignUp(length, width)
+
default:
// Unknown message type.
return socket.ControlMessages{}, syserror.EINVAL
@@ -512,7 +551,26 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
var packetInfo linux.ControlMessageIPPacketInfo
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
- cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo)
+ cmsgs.IP.PacketInfo = packetInfo
+ i += binary.AlignUp(length, width)
+
+ case linux.IP_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet
+ if length < addr.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr)
+ cmsgs.IP.OriginalDstAddress = &addr
+ i += binary.AlignUp(length, width)
+
+ case linux.IP_RECVERR:
+ var errCmsg linux.SockErrCMsgIPv4
+ if length < errCmsg.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()])
+ cmsgs.IP.SockErr = &errCmsg
i += binary.AlignUp(length, width)
default:
@@ -528,6 +586,25 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
i += binary.AlignUp(length, width)
+ case linux.IPV6_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet6
+ if length < addr.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr)
+ cmsgs.IP.OriginalDstAddress = &addr
+ i += binary.AlignUp(length, width)
+
+ case linux.IPV6_RECVERR:
+ var errCmsg linux.SockErrCMsgIPv6
+ if length < errCmsg.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()])
+ cmsgs.IP.SockErr = &errCmsg
+ i += binary.AlignUp(length, width)
+
default:
return socket.ControlMessages{}, syserror.EINVAL
}
diff --git a/pkg/sentry/socket/control/control_test.go b/pkg/sentry/socket/control/control_test.go
new file mode 100644
index 000000000..d40a4cc85
--- /dev/null
+++ b/pkg/sentry/socket/control/control_test.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package control provides internal representations of socket control
+// messages.
+package control
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func TestParse(t *testing.T) {
+ // Craft the control message to parse.
+ length := linux.SizeOfControlMessageHeader + linux.SizeOfTimeval
+ hdr := linux.ControlMessageHeader{
+ Length: uint64(length),
+ Level: linux.SOL_SOCKET,
+ Type: linux.SO_TIMESTAMP,
+ }
+ buf := make([]byte, 0, length)
+ buf = binary.Marshal(buf, usermem.ByteOrder, &hdr)
+ ts := linux.Timeval{
+ Sec: 2401,
+ Usec: 343,
+ }
+ buf = binary.Marshal(buf, usermem.ByteOrder, &ts)
+
+ cmsg, err := Parse(nil, nil, buf, 8 /* width */)
+ if err != nil {
+ t.Fatalf("Parse(_, _, %+v, _): %v", cmsg, err)
+ }
+
+ want := socket.ControlMessages{
+ IP: socket.IPControlMessages{
+ HasTimestamp: true,
+ Timestamp: ts.ToNsecCapped(),
+ },
+ }
+ if diff := cmp.Diff(want, cmsg); diff != "" {
+ t.Errorf("unexpected message parsed, (-want, +got):\n%s", diff)
+ }
+}
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 7d3c4a01c..5b868216d 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -331,17 +331,17 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR:
optlen = sizeofInt32
}
case linux.SOL_IPV6:
switch name {
- case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR:
optlen = sizeofInt32
}
case linux.SOL_SOCKET:
switch name {
- case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
+ case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP:
optlen = sizeofInt32
case linux.SO_LINGER:
optlen = syscall.SizeofLinger
@@ -377,24 +377,24 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR:
optlen = sizeofInt32
case linux.IP_PKTINFO:
optlen = linux.SizeOfControlMessageIPPacketInfo
}
case linux.SOL_IPV6:
switch name {
- case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR:
optlen = sizeofInt32
}
case linux.SOL_SOCKET:
switch name {
- case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
+ case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP:
optlen = sizeofInt32
}
case linux.SOL_TCP:
switch name {
- case linux.TCP_NODELAY:
+ case linux.TCP_NODELAY, linux.TCP_INQ:
optlen = sizeofInt32
}
}
@@ -416,68 +416,76 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
return nil
}
-// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
- // Only allow known and safe flags.
- //
- // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary
- // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the
- // Socket interface's dependence on netstack.
- if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 {
- return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
- }
+func (s *socketOpsCommon) recvMsgFromHost(iovs []syscall.Iovec, flags int, senderRequested bool, controlLen uint64) (uint64, int, []byte, []byte, error) {
+ // We always do a non-blocking recv*().
+ sysflags := flags | syscall.MSG_DONTWAIT
- var senderAddr linux.SockAddr
+ msg := syscall.Msghdr{}
+ if len(iovs) > 0 {
+ msg.Iov = &iovs[0]
+ msg.Iovlen = uint64(len(iovs))
+ }
var senderAddrBuf []byte
if senderRequested {
senderAddrBuf = make([]byte, sizeofSockaddr)
+ msg.Name = &senderAddrBuf[0]
+ msg.Namelen = uint32(sizeofSockaddr)
}
-
var controlBuf []byte
- var msgFlags int
-
- recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
- // Refuse to do anything if any part of dst.Addrs was unusable.
- if uint64(dst.NumBytes()) != dsts.NumBytes() {
- return 0, nil
- }
- if dsts.IsEmpty() {
- return 0, nil
+ if controlLen > 0 {
+ if controlLen > maxControlLen {
+ controlLen = maxControlLen
}
+ controlBuf = make([]byte, controlLen)
+ msg.Control = &controlBuf[0]
+ msg.Controllen = controlLen
+ }
+ n, err := recvmsg(s.fd, &msg, sysflags)
+ if err != nil {
+ return 0 /* n */, 0 /* mFlags */, nil /* senderAddrBuf */, nil /* controlBuf */, err
+ }
+ return n, int(msg.Flags), senderAddrBuf[:msg.Namelen], controlBuf[:msg.Controllen], err
+}
- // We always do a non-blocking recv*().
- sysflags := flags | syscall.MSG_DONTWAIT
+// RecvMsg implements socket.Socket.RecvMsg.
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+ // Only allow known and safe flags.
+ if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC|syscall.MSG_ERRQUEUE) != 0 {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
+ }
- iovs := safemem.IovecsFromBlockSeq(dsts)
- msg := syscall.Msghdr{
- Iov: &iovs[0],
- Iovlen: uint64(len(iovs)),
- }
- if len(senderAddrBuf) != 0 {
- msg.Name = &senderAddrBuf[0]
- msg.Namelen = uint32(len(senderAddrBuf))
- }
- if controlLen > 0 {
- if controlLen > maxControlLen {
- controlLen = maxControlLen
+ var senderAddrBuf []byte
+ var controlBuf []byte
+ var msgFlags int
+ copyToDst := func() (int64, error) {
+ var n uint64
+ var err error
+ if dst.NumBytes() == 0 {
+ // We want to make the recvmsg(2) call to the host even if dst is empty
+ // to fetch control messages, sender address or errors if any occur.
+ n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(nil, flags, senderRequested, controlLen)
+ return int64(n), err
+ }
+
+ recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of dst.Addrs was unusable.
+ if uint64(dst.NumBytes()) != dsts.NumBytes() {
+ return 0, nil
}
- controlBuf = make([]byte, controlLen)
- msg.Control = &controlBuf[0]
- msg.Controllen = controlLen
- }
- n, err := recvmsg(s.fd, &msg, sysflags)
- if err != nil {
- return 0, err
- }
- senderAddrBuf = senderAddrBuf[:msg.Namelen]
- msgFlags = int(msg.Flags)
- controlLen = uint64(msg.Controllen)
- return n, nil
- })
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+
+ n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(safemem.IovecsFromBlockSeq(dsts), flags, senderRequested, controlLen)
+ return n, err
+ })
+ return dst.CopyOutFrom(t, recvmsgToBlocks)
+ }
var ch chan struct{}
- n, err := dst.CopyOutFrom(t, recvmsgToBlocks)
- if flags&syscall.MSG_DONTWAIT == 0 {
+ n, err := copyToDst()
+ // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT.
+ if flags&(syscall.MSG_DONTWAIT|syscall.MSG_ERRQUEUE) == 0 {
for err == syserror.ErrWouldBlock {
// We only expect blocking to come from the actual syscall, in which
// case it can't have returned any data.
@@ -494,48 +502,85 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
s.EventRegister(&e, waiter.EventIn)
defer s.EventUnregister(&e)
}
- n, err = dst.CopyOutFrom(t, recvmsgToBlocks)
+ n, err = copyToDst()
}
}
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
}
+ var senderAddr linux.SockAddr
if senderRequested {
senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf)
}
- unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen])
+ unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf)
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
}
+ return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), parseUnixControlMessages(unixControlMessages), nil
+}
+func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) socket.ControlMessages {
controlMessages := socket.ControlMessages{}
for _, unixCmsg := range unixControlMessages {
switch unixCmsg.Header.Level {
- case syscall.SOL_IP:
+ case linux.SOL_SOCKET:
switch unixCmsg.Header.Type {
- case syscall.IP_TOS:
+ case linux.SO_TIMESTAMP:
+ controlMessages.IP.HasTimestamp = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfTimeval], usermem.ByteOrder, &controlMessages.IP.Timestamp)
+ }
+
+ case linux.SOL_IP:
+ switch unixCmsg.Header.Type {
+ case linux.IP_TOS:
controlMessages.IP.HasTOS = true
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS)
- case syscall.IP_PKTINFO:
+ case linux.IP_PKTINFO:
controlMessages.IP.HasIPPacketInfo = true
var packetInfo linux.ControlMessageIPPacketInfo
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
- controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo)
+ controlMessages.IP.PacketInfo = packetInfo
+
+ case linux.IP_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet
+ binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr)
+ controlMessages.IP.OriginalDstAddress = &addr
+
+ case syscall.IP_RECVERR:
+ var errCmsg linux.SockErrCMsgIPv4
+ errCmsg.UnmarshalBytes(unixCmsg.Data)
+ controlMessages.IP.SockErr = &errCmsg
}
- case syscall.SOL_IPV6:
+ case linux.SOL_IPV6:
switch unixCmsg.Header.Type {
- case syscall.IPV6_TCLASS:
+ case linux.IPV6_TCLASS:
controlMessages.IP.HasTClass = true
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass)
+
+ case linux.IPV6_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet6
+ binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr)
+ controlMessages.IP.OriginalDstAddress = &addr
+
+ case syscall.IPV6_RECVERR:
+ var errCmsg linux.SockErrCMsgIPv6
+ errCmsg.UnmarshalBytes(unixCmsg.Data)
+ controlMessages.IP.SockErr = &errCmsg
+ }
+
+ case linux.SOL_TCP:
+ switch unixCmsg.Header.Type {
+ case linux.TCP_INQ:
+ controlMessages.IP.HasInq = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], usermem.ByteOrder, &controlMessages.IP.Inq)
}
}
}
-
- return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil
+ return controlMessages
}
// SendMsg implements socket.Socket.SendMsg.
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index fae3b6783..b2206900b 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -25,7 +25,6 @@ go_library(
"//pkg/marshal",
"//pkg/marshal/primitive",
"//pkg/metric",
- "//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index e8a0103bf..dcf898c0a 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -28,9 +28,9 @@ import (
"bytes"
"fmt"
"io"
+ "io/ioutil"
"math"
"reflect"
- "sync/atomic"
"syscall"
"time"
@@ -43,7 +43,6 @@ import (
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/metric"
- "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -84,69 +83,95 @@ var Metrics = tcpip.Stats{
MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."),
DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."),
ICMP: tcpip.ICMPStats{
- V4PacketsSent: tcpip.ICMPv4SentPacketStats{
- ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
- Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."),
- SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."),
- Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."),
- Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."),
- TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."),
- InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."),
- InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."),
+ V4: tcpip.ICMPv4Stats{
+ PacketsSent: tcpip.ICMPv4SentPacketStats{
+ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
+ Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."),
+ },
+ Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."),
+ },
+ PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{
+ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
+ Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."),
+ },
+ Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."),
},
- Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."),
},
- V4PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{
- ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
- Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."),
- SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."),
- Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."),
- Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."),
- TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."),
- InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."),
- InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."),
+ V6: tcpip.ICMPv6Stats{
+ PacketsSent: tcpip.ICMPv6SentPacketStats{
+ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."),
+ },
+ Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."),
+ },
+ PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{
+ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."),
+ },
+ Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."),
},
- Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."),
},
- V6PacketsSent: tcpip.ICMPv6SentPacketStats{
- ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."),
- PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."),
- RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."),
- RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."),
- NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."),
- NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."),
- RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."),
+ },
+ IGMP: tcpip.IGMPStats{
+ PacketsSent: tcpip.IGMPSentPacketStats{
+ IGMPPacketStats: tcpip.IGMPPacketStats{
+ MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Total number of IGMP Membership Query messages sent by netstack."),
+ V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Total number of IGMPv1 Membership Report messages sent by netstack."),
+ V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Total number of IGMPv2 Membership Report messages sent by netstack."),
+ LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Total number of IGMP Leave Group messages sent by netstack."),
},
- Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."),
+ Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Total number of IGMP packets dropped by netstack due to link layer errors."),
},
- V6PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{
- ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."),
- PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."),
- RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."),
- RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."),
- NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."),
- NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."),
- RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."),
+ PacketsReceived: tcpip.IGMPReceivedPacketStats{
+ IGMPPacketStats: tcpip.IGMPPacketStats{
+ MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Total number of IGMP Membership Query messages received by netstack."),
+ V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Total number of IGMPv1 Membership Report messages received by netstack."),
+ V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Total number of IGMPv2 Membership Report messages received by netstack."),
+ LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Total number of IGMP Leave Group messages received by netstack."),
},
- Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."),
+ Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Total number of IGMP packets received by netstack that could not be parsed."),
+ ChecksumErrors: mustCreateMetric("/netstack/igmp/packets_received/checksum_errors", "Total number of received IGMP packets with bad checksums."),
+ Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Total number of unrecognized IGMP packets received by netstack."),
},
},
IP: tcpip.IPStats{
@@ -209,18 +234,6 @@ const sizeOfInt32 int = 4
var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL)
-// ntohs converts a 16-bit number from network byte order to host byte order. It
-// assumes that the host is little endian.
-func ntohs(v uint16) uint16 {
- return v<<8 | v>>8
-}
-
-// htons converts a 16-bit number from host byte order to network byte order. It
-// assumes that the host is little endian.
-func htons(v uint16) uint16 {
- return ntohs(v)
-}
-
// commonEndpoint represents the intersection of a tcpip.Endpoint and a
// transport.Endpoint.
type commonEndpoint interface {
@@ -240,10 +253,6 @@ type commonEndpoint interface {
// transport.Endpoint.SetSockOpt.
SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error
- // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and
- // transport.Endpoint.SetSockOptBool.
- SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
-
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
// transport.Endpoint.SetSockOptInt.
SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
@@ -252,14 +261,14 @@ type commonEndpoint interface {
// transport.Endpoint.GetSockOpt.
GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error
- // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and
- // transport.Endpoint.GetSockOpt.
- GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and
// transport.Endpoint.GetSockOpt.
GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
+ // State returns a socket's lifecycle state. The returned value is
+ // protocol-specific and is primarily used for diagnostics.
+ State() uint32
+
// LastError implements tcpip.Endpoint.LastError and
// transport.Endpoint.LastError.
LastError() *tcpip.Error
@@ -298,19 +307,11 @@ type socketOpsCommon struct {
skType linux.SockType
protocol int
- // readViewHasData is 1 iff readView has data to be read, 0 otherwise.
- // Must be accessed using atomic operations. It must only be written
- // with readMu held but can be read without holding readMu. The latter
- // is required to avoid deadlocks in epoll Readiness checks.
- readViewHasData uint32
-
// readMu protects access to the below fields.
readMu sync.Mutex `state:"nosave"`
- // readView contains the remaining payload from the last packet.
- readView buffer.View
// readCM holds control message information for the last packet read
// from Endpoint.
- readCM tcpip.ControlMessages
+ readCM socket.IPControlMessages
sender tcpip.FullAddress
linkPacketInfo tcpip.LinkPacketInfo
@@ -326,17 +327,15 @@ type socketOpsCommon struct {
// valid when timestampValid is true. It is protected by readMu.
timestampNS int64
- // sockOptInq corresponds to TCP_INQ. It is implemented at this level
- // because it takes into account data from readView.
+ // TODO(b/153685824): Move this to SocketOptions.
+ // sockOptInq corresponds to TCP_INQ.
sockOptInq bool
}
// New creates a new endpoint socket.
func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
if skType == linux.SOCK_STREAM {
- if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ endpoint.SocketOptions().SetDelayOption(true)
}
dirent := socket.NewDirent(t, netstackDevice)
@@ -365,127 +364,27 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
return tcpip.Address(addr)
}
-// AddressAndFamily reads an sockaddr struct from the given address and
-// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
-// AF_INET6, and AF_PACKET addresses.
-//
-// AddressAndFamily returns an address and its family.
-func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
- // Make sure we have at least 2 bytes for the address family.
- if len(addr) < 2 {
- return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
- }
-
- // Get the rest of the fields based on the address family.
- switch family := usermem.ByteOrder.Uint16(addr); family {
- case linux.AF_UNIX:
- path := addr[2:]
- if len(path) > linux.UnixPathMax {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- // Drop the terminating NUL (if one exists) and everything after
- // it for filesystem (non-abstract) addresses.
- if len(path) > 0 && path[0] != 0 {
- if n := bytes.IndexByte(path[1:], 0); n >= 0 {
- path = path[:n+1]
- }
- }
- return tcpip.FullAddress{
- Addr: tcpip.Address(path),
- }, family, nil
-
- case linux.AF_INET:
- var a linux.SockAddrInet
- if len(addr) < sockAddrInetSize {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
-
- out := tcpip.FullAddress{
- Addr: bytesToIPAddress(a.Addr[:]),
- Port: ntohs(a.Port),
- }
- return out, family, nil
-
- case linux.AF_INET6:
- var a linux.SockAddrInet6
- if len(addr) < sockAddrInet6Size {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
-
- out := tcpip.FullAddress{
- Addr: bytesToIPAddress(a.Addr[:]),
- Port: ntohs(a.Port),
- }
- if isLinkLocal(out.Addr) {
- out.NIC = tcpip.NICID(a.Scope_id)
- }
- return out, family, nil
-
- case linux.AF_PACKET:
- var a linux.SockAddrLink
- if len(addr) < sockAddrLinkSize {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a)
- if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
-
- // TODO(gvisor.dev/issue/173): Return protocol too.
- return tcpip.FullAddress{
- NIC: tcpip.NICID(a.InterfaceIndex),
- Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
- }, family, nil
-
- case linux.AF_UNSPEC:
- return tcpip.FullAddress{}, family, nil
-
- default:
- return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported
- }
-}
-
func (s *socketOpsCommon) isPacketBased() bool {
return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW
}
-// fetchReadView updates the readView field of the socket if it's currently
-// empty. It assumes that the socket is locked.
-//
// Precondition: s.readMu must be held.
-func (s *socketOpsCommon) fetchReadView() *syserr.Error {
- if len(s.readView) > 0 {
- return nil
- }
- s.readView = nil
- s.sender = tcpip.FullAddress{}
- s.linkPacketInfo = tcpip.LinkPacketInfo{}
+func (s *socketOpsCommon) readLocked(dst io.Writer, count int, peek bool) (numRead, numTotal int, serr *syserr.Error) {
+ res, err := s.Endpoint.Read(dst, count, tcpip.ReadOptions{
+ Peek: peek,
+ NeedRemoteAddr: true,
+ NeedLinkPacketInfo: true,
+ })
- var v buffer.View
- var cms tcpip.ControlMessages
- var err *tcpip.Error
+ // Assign these even if the read returned an error.
+ s.readCM = socket.NewIPControlMessages(s.family, res.ControlMessages)
+ s.sender = res.RemoteAddr
+ s.linkPacketInfo = res.LinkPacketInfo
- switch e := s.Endpoint.(type) {
- // The ordering of these interfaces matters. The most specific
- // interfaces must be specified before the more generic Endpoint
- // interface.
- case tcpip.PacketEndpoint:
- v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo)
- case tcpip.Endpoint:
- v, cms, err = e.Read(&s.sender)
- }
if err != nil {
- atomic.StoreUint32(&s.readViewHasData, 0)
- return syserr.TranslateNetstackError(err)
+ return 0, 0, syserr.TranslateNetstackError(err)
}
-
- s.readView = v
- s.readCM = cms
- atomic.StoreUint32(&s.readViewHasData, 1)
-
- return nil
+ return res.Count, res.Total, nil
}
// Release implements fs.FileOperations.Release.
@@ -502,11 +401,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) {
return
}
- var v tcpip.LingerOption
- if err := s.Endpoint.GetSockOpt(&v); err != nil {
- return
- }
-
+ v := s.Endpoint.SocketOptions().GetLinger()
// The case for zero timeout is handled in tcp endpoint close function.
// Close is blocked until either:
// 1. The endpoint state is not in any of the states: FIN-WAIT1,
@@ -538,38 +433,14 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
// WriteTo implements fs.FileOperations.WriteTo.
func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
s.readMu.Lock()
+ defer s.readMu.Unlock()
- // Copy as much data as possible.
- done := int64(0)
- for count > 0 {
- // This may return a blocking error.
- if err := s.fetchReadView(); err != nil {
- s.readMu.Unlock()
- return done, err.ToError()
- }
-
- // Write to the underlying file.
- n, err := dst.Write(s.readView)
- done += int64(n)
- count -= int64(n)
- if dup {
- // That's all we support for dup. This is generally
- // supported by any Linux system calls, but the
- // expectation is that now a caller will call read to
- // actually remove these bytes from the socket.
- break
- }
-
- // Drop that part of the view.
- s.readView.TrimFront(n)
- if err != nil {
- s.readMu.Unlock()
- return done, err
- }
+ // This may return a blocking error.
+ n, _, err := s.readLocked(dst, int(count), dup /* peek */)
+ if err != nil {
+ return 0, err.ToError()
}
-
- s.readMu.Unlock()
- return done, nil
+ return int64(n), nil
}
// ioSequencePayload implements tcpip.Payload.
@@ -705,17 +576,7 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader
// Readiness returns a mask of ready events for socket s.
func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
- r := s.Endpoint.Readiness(mask)
-
- // Check our cached value iff the caller asked for readability and the
- // endpoint itself is currently not readable.
- if (mask & ^r & waiter.EventIn) != 0 {
- if atomic.LoadUint32(&s.readViewHasData) == 1 {
- r |= waiter.EventIn
- }
- }
-
- return r
+ return s.Endpoint.Readiness(mask)
}
func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error {
@@ -723,11 +584,7 @@ func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error {
return nil
}
if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 {
- v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption)
- if err != nil {
- return syserr.TranslateNetstackError(err)
- }
- if !v {
+ if !s.Endpoint.SocketOptions().GetV6Only() {
return nil
}
}
@@ -751,7 +608,7 @@ func (s *socketOpsCommon) mapFamily(addr tcpip.FullAddress, family uint16) tcpip
// Connect implements the linux syscall connect(2) for sockets backed by
// tcpip.Endpoint.
func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
- addr, family, err := AddressAndFamily(sockaddr)
+ addr, family, err := socket.AddressAndFamily(sockaddr)
if err != nil {
return err
}
@@ -832,7 +689,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
} else {
var err *syserr.Error
- addr, family, err = AddressAndFamily(sockaddr)
+ addr, family, err = socket.AddressAndFamily(sockaddr)
if err != nil {
return err
}
@@ -923,7 +780,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
var addr linux.SockAddr
var addrLen uint32
if peerAddr != nil {
- addr, addrLen = ConvertAddress(s.family, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(s.family, *peerAddr)
}
fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
@@ -1007,7 +864,7 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
case linux.SOL_TCP:
- return getSockOptTCP(t, ep, name, outLen)
+ return getSockOptTCP(t, s, ep, name, outLen)
case linux.SOL_IPV6:
return getSockOptIPv6(t, s, ep, name, outPtr, outLen)
@@ -1043,7 +900,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// Get the last error and convert it.
- err := ep.LastError()
+ err := ep.SocketOptions().GetLastError()
if err == nil {
optP := primitive.Int32(0)
return &optP, nil
@@ -1124,10 +981,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return &v, nil
case linux.SO_BINDTODEVICE:
- var v tcpip.BindToDeviceOption
- if err := ep.GetSockOpt(&v); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ v := ep.SocketOptions().GetBindToDevice()
if v == 0 {
var b primitive.ByteSlice
return &b, nil
@@ -1170,11 +1024,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.LingerOption
var linger linux.Linger
- if err := ep.GetSockOpt(&v); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ v := ep.SocketOptions().GetLinger()
if v.Enabled {
linger.OnOff = 1
@@ -1205,13 +1056,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.OutOfBandInlineOption
- if err := ep.GetSockOpt(&v); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(v)
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetOutOfBandInline()))
+ return &v, nil
case linux.SO_NO_CHECK:
if outLen < sizeOfInt32 {
@@ -1226,8 +1072,13 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- v := primitive.Int32(boolToInt32(ep.SocketOptions().GetAcceptConn()))
- return &v, nil
+ // This option is only viable for TCP endpoints.
+ var v bool
+ if _, skType, skProto := s.Type(); isTCPSocket(skType, skProto) {
+ v = tcp.EndpointState(ep.State()) == tcp.StateListen
+ }
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
@@ -1236,46 +1087,36 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// getSockOptTCP implements GetSockOpt when level is SOL_TCP.
-func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
+func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
+ if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) {
+ log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto)
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.TCP_NODELAY:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.DelayOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(!v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(!ep.SocketOptions().GetDelayOption()))
+ return &v, nil
case linux.TCP_CORK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.CorkOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetCorkOption()))
+ return &v, nil
case linux.TCP_QUICKACK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.QuickAckOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetQuickAck()))
+ return &v, nil
case linux.TCP_MAXSEG:
if outLen < sizeOfInt32 {
@@ -1449,19 +1290,24 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal
// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
+ family, skType, _ := s.Type()
+ if family != linux.AF_INET6 {
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IPV6_V6ONLY:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.V6OnlyOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetV6Only()))
+ return &v, nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1493,13 +1339,23 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass()))
+ return &v, nil
+ case linux.IPV6_RECVERR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
}
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError()))
+ return &v, nil
+
+ case linux.IPV6_RECVORIGDSTADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
+ return &v, nil
case linux.IP6T_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet6{})) {
@@ -1511,7 +1367,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.TranslateNetstackError(err)
}
- a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v))
+ a, _ := socket.ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v))
return a.(*linux.SockAddrInet6), nil
case linux.IP6T_SO_GET_INFO:
@@ -1520,7 +1376,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return nil, syserr.ErrProtocolNotAvailable
}
@@ -1540,7 +1396,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrInvalidArgument
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return nil, syserr.ErrProtocolNotAvailable
}
@@ -1560,7 +1416,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return nil, syserr.ErrProtocolNotAvailable
}
@@ -1582,6 +1438,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
// getSockOptIP implements GetSockOpt when level is SOL_IP.
func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IP_TTL:
if outLen < sizeOfInt32 {
@@ -1624,7 +1485,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.TranslateNetstackError(err)
}
- a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
+ a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
return &a.(*linux.SockAddrInet).Addr, nil
@@ -1633,13 +1494,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetMulticastLoop()))
+ return &v, nil
case linux.IP_TOS:
// Length handling for parity with Linux.
@@ -1663,26 +1519,40 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS()))
+ return &v, nil
+
+ case linux.IP_RECVERR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
}
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError()))
+ return &v, nil
case linux.IP_PKTINFO:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceivePacketInfo()))
+ return &v, nil
+
+ case linux.IP_HDRINCL:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
}
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded()))
+ return &v, nil
+
+ case linux.IP_RECVORIGDSTADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
+ return &v, nil
case linux.SO_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet{})) {
@@ -1694,7 +1564,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.TranslateNetstackError(err)
}
- a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v))
+ a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress(v))
return a.(*linux.SockAddrInet), nil
case linux.IPT_SO_GET_INFO:
@@ -1801,7 +1671,7 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int
return setSockOptSocket(t, s, ep, name, optVal)
case linux.SOL_TCP:
- return setSockOptTCP(t, ep, name, optVal)
+ return setSockOptTCP(t, s, ep, name, optVal)
case linux.SOL_IPV6:
return setSockOptIPv6(t, s, ep, name, optVal)
@@ -1870,8 +1740,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
name := string(optVal[:n])
if name == "" {
- v := tcpip.BindToDeviceOption(0)
- return syserr.TranslateNetstackError(ep.SetSockOpt(&v))
+ return syserr.TranslateNetstackError(ep.SocketOptions().SetBindToDevice(0))
}
s := t.NetworkContext()
if s == nil {
@@ -1879,8 +1748,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
for nicID, nic := range s.Interfaces() {
if nic.Name == name {
- v := tcpip.BindToDeviceOption(nicID)
- return syserr.TranslateNetstackError(ep.SetSockOpt(&v))
+ return syserr.TranslateNetstackError(ep.SocketOptions().SetBindToDevice(nicID))
}
}
return syserr.ErrUnknownDevice
@@ -1949,8 +1817,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
- opt := tcpip.OutOfBandInlineOption(v)
- return syserr.TranslateNetstackError(ep.SetSockOpt(&opt))
+ ep.SocketOptions().SetOutOfBandInline(v != 0)
+ return nil
case linux.SO_NO_CHECK:
if len(optVal) < sizeOfInt32 {
@@ -1973,10 +1841,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
- return syserr.TranslateNetstackError(
- ep.SetSockOpt(&tcpip.LingerOption{
- Enabled: v.OnOff != 0,
- Timeout: time.Second * time.Duration(v.Linger)}))
+ ep.SocketOptions().SetLinger(tcpip.LingerOption{
+ Enabled: v.OnOff != 0,
+ Timeout: time.Second * time.Duration(v.Linger),
+ })
+ return nil
case linux.SO_DETACH_FILTER:
// optval is ignored.
@@ -1991,7 +1860,12 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
// setSockOptTCP implements SetSockOpt when level is SOL_TCP.
-func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) {
+ log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto)
+ return syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.TCP_NODELAY:
if len(optVal) < sizeOfInt32 {
@@ -1999,7 +1873,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0))
+ ep.SocketOptions().SetDelayOption(v == 0)
+ return nil
case linux.TCP_CORK:
if len(optVal) < sizeOfInt32 {
@@ -2007,7 +1882,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0))
+ ep.SocketOptions().SetCorkOption(v != 0)
+ return nil
case linux.TCP_QUICKACK:
if len(optVal) < sizeOfInt32 {
@@ -2015,7 +1891,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0))
+ ep.SocketOptions().SetQuickAck(v != 0)
+ return nil
case linux.TCP_MAXSEG:
if len(optVal) < sizeOfInt32 {
@@ -2127,18 +2004,55 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
// setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6.
func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return syserr.ErrUnknownProtocolOption
+ }
+
+ family, skType, skProto := s.Type()
+ if family != linux.AF_INET6 {
+ return syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IPV6_V6ONLY:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
}
+ if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial {
+ return syserr.ErrInvalidEndpointState
+ } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial {
+ return syserr.ErrInvalidEndpointState
+ }
+
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0))
+ ep.SocketOptions().SetV6Only(v != 0)
+ return nil
+
+ case linux.IPV6_ADD_MEMBERSHIP:
+ req, err := copyInMulticastV6Request(optVal)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.AddMembershipOption{
+ NIC: tcpip.NICID(req.InterfaceIndex),
+ MulticastAddr: tcpip.Address(req.MulticastAddr[:]),
+ }))
- case linux.IPV6_ADD_MEMBERSHIP,
- linux.IPV6_DROP_MEMBERSHIP,
- linux.IPV6_IPSEC_POLICY,
+ case linux.IPV6_DROP_MEMBERSHIP:
+ req, err := copyInMulticastV6Request(optVal)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.RemoveMembershipOption{
+ NIC: tcpip.NICID(req.InterfaceIndex),
+ MulticastAddr: tcpip.Address(req.MulticastAddr[:]),
+ }))
+
+ case linux.IPV6_IPSEC_POLICY,
linux.IPV6_JOIN_ANYCAST,
linux.IPV6_LEAVE_ANYCAST,
// TODO(b/148887420): Add support for IPV6_PKTINFO.
@@ -2154,6 +2068,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
t.Kernel().EmitUnimplementedEvent(t)
+ case linux.IPV6_RECVORIGDSTADDR:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+
+ ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
+ return nil
+
case linux.IPV6_TCLASS:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -2173,7 +2096,18 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0))
+ ep.SocketOptions().SetReceiveTClass(v != 0)
+ return nil
+ case linux.IPV6_RECVERR:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ ep.SocketOptions().SetRecvError(v != 0)
+ return nil
case linux.IP6T_SO_SET_REPLACE:
if len(optVal) < linux.SizeOfIP6TReplace {
@@ -2181,7 +2115,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return syserr.ErrProtocolNotAvailable
}
@@ -2206,6 +2140,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
var (
inetMulticastRequestSize = int(binary.Size(linux.InetMulticastRequest{}))
inetMulticastRequestWithNICSize = int(binary.Size(linux.InetMulticastRequestWithNIC{}))
+ inet6MulticastRequestSize = int(binary.Size(linux.Inet6MulticastRequest{}))
)
// copyInMulticastRequest copies in a variable-size multicast request. The
@@ -2239,6 +2174,16 @@ func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastR
return req, nil
}
+func copyInMulticastV6Request(optVal []byte) (linux.Inet6MulticastRequest, *syserr.Error) {
+ if len(optVal) < inet6MulticastRequestSize {
+ return linux.Inet6MulticastRequest{}, syserr.ErrInvalidArgument
+ }
+
+ var req linux.Inet6MulticastRequest
+ binary.Unmarshal(optVal[:inet6MulticastRequestSize], usermem.ByteOrder, &req)
+ return req, nil
+}
+
// parseIntOrChar copies either a 32-bit int or an 8-bit uint out of buf.
//
// net/ipv4/ip_sockglue.c:do_ip_setsockopt does this for its socket options.
@@ -2256,6 +2201,11 @@ func parseIntOrChar(buf []byte) (int32, *syserr.Error) {
// setSockOptIP implements SetSockOpt when level is SOL_IP.
func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IP_MULTICAST_TTL:
v, err := parseIntOrChar(optVal)
@@ -2308,7 +2258,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.MulticastInterfaceOption{
NIC: tcpip.NICID(req.InterfaceIndex),
- InterfaceAddr: bytesToIPAddress(req.InterfaceAddr[:]),
+ InterfaceAddr: socket.BytesToIPAddress(req.InterfaceAddr[:]),
}))
case linux.IP_MULTICAST_LOOP:
@@ -2317,7 +2267,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0))
+ ep.SocketOptions().SetMulticastLoop(v != 0)
+ return nil
case linux.MCAST_JOIN_GROUP:
// FIXME(b/124219304): Implement MCAST_JOIN_GROUP.
@@ -2353,7 +2304,19 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0))
+ ep.SocketOptions().SetReceiveTOS(v != 0)
+ return nil
+
+ case linux.IP_RECVERR:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ ep.SocketOptions().SetRecvError(v != 0)
+ return nil
case linux.IP_PKTINFO:
if len(optVal) == 0 {
@@ -2363,7 +2326,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0))
+ ep.SocketOptions().SetReceivePacketInfo(v != 0)
+ return nil
case linux.IP_HDRINCL:
if len(optVal) == 0 {
@@ -2373,7 +2337,20 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0))
+ ep.SocketOptions().SetHeaderIncluded(v != 0)
+ return nil
+
+ case linux.IP_RECVORIGDSTADDR:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
+ return nil
case linux.IPT_SO_SET_REPLACE:
if len(optVal) < linux.SizeOfIPTReplace {
@@ -2410,10 +2387,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
linux.IP_NODEFRAG,
linux.IP_OPTIONS,
linux.IP_PASSSEC,
- linux.IP_RECVERR,
linux.IP_RECVFRAGSIZE,
linux.IP_RECVOPTS,
- linux.IP_RECVORIGDSTADDR,
linux.IP_RECVTTL,
linux.IP_RETOPTS,
linux.IP_TRANSPARENT,
@@ -2487,11 +2462,9 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_MULTICAST_IF,
linux.IPV6_MULTICAST_LOOP,
linux.IPV6_RECVDSTOPTS,
- linux.IPV6_RECVERR,
linux.IPV6_RECVFRAGSIZE,
linux.IPV6_RECVHOPLIMIT,
linux.IPV6_RECVHOPOPTS,
- linux.IPV6_RECVORIGDSTADDR,
linux.IPV6_RECVPATHMTU,
linux.IPV6_RECVPKTINFO,
linux.IPV6_RECVRTHDR,
@@ -2515,7 +2488,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
switch name {
case linux.IP_TOS,
linux.IP_TTL,
- linux.IP_HDRINCL,
linux.IP_OPTIONS,
linux.IP_ROUTER_ALERT,
linux.IP_RECVOPTS,
@@ -2523,7 +2495,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
linux.IP_PKTINFO,
linux.IP_PKTOPTIONS,
linux.IP_MTU_DISCOVER,
- linux.IP_RECVERR,
linux.IP_RECVTTL,
linux.IP_RECVTOS,
linux.IP_MTU,
@@ -2562,72 +2533,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
}
}
-// isLinkLocal determines if the given IPv6 address is link-local. This is the
-// case when it has the fe80::/10 prefix. This check is used to determine when
-// the NICID is relevant for a given IPv6 address.
-func isLinkLocal(addr tcpip.Address) bool {
- return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80
-}
-
-// ConvertAddress converts the given address to a native format.
-func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) {
- switch family {
- case linux.AF_UNIX:
- var out linux.SockAddrUnix
- out.Family = linux.AF_UNIX
- l := len([]byte(addr.Addr))
- for i := 0; i < l; i++ {
- out.Path[i] = int8(addr.Addr[i])
- }
-
- // Linux returns the used length of the address struct (including the
- // null terminator) for filesystem paths. The Family field is 2 bytes.
- // It is sometimes allowed to exclude the null terminator if the
- // address length is the max. Abstract and empty paths always return
- // the full exact length.
- if l == 0 || out.Path[0] == 0 || l == len(out.Path) {
- return &out, uint32(2 + l)
- }
- return &out, uint32(3 + l)
-
- case linux.AF_INET:
- var out linux.SockAddrInet
- copy(out.Addr[:], addr.Addr)
- out.Family = linux.AF_INET
- out.Port = htons(addr.Port)
- return &out, uint32(sockAddrInetSize)
-
- case linux.AF_INET6:
- var out linux.SockAddrInet6
- if len(addr.Addr) == header.IPv4AddressSize {
- // Copy address in v4-mapped format.
- copy(out.Addr[12:], addr.Addr)
- out.Addr[10] = 0xff
- out.Addr[11] = 0xff
- } else {
- copy(out.Addr[:], addr.Addr)
- }
- out.Family = linux.AF_INET6
- out.Port = htons(addr.Port)
- if isLinkLocal(addr.Addr) {
- out.Scope_id = uint32(addr.NIC)
- }
- return &out, uint32(sockAddrInet6Size)
-
- case linux.AF_PACKET:
- // TODO(gvisor.dev/issue/173): Return protocol too.
- var out linux.SockAddrLink
- out.Family = linux.AF_PACKET
- out.InterfaceIndex = int32(addr.NIC)
- out.HardwareAddrLen = header.EthernetAddressSize
- copy(out.HardwareAddr[:], addr.Addr)
- return &out, uint32(sockAddrLinkSize)
-
- default:
- return nil, 0
- }
-}
-
// GetSockName implements the linux syscall getsockname(2) for sockets backed by
// tcpip.Endpoint.
func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
@@ -2636,7 +2541,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := ConvertAddress(s.family, addr)
+ a, l := socket.ConvertAddress(s.family, addr)
return a, l, nil
}
@@ -2648,70 +2553,24 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := ConvertAddress(s.family, addr)
+ a, l := socket.ConvertAddress(s.family, addr)
return a, l, nil
}
-// coalescingRead is the fast path for non-blocking, non-peek, stream-based
-// case. It coalesces as many packets as possible before returning to the
-// caller.
+// streamRead is the fast path for non-blocking, non-peek reads on stream-based sockets.
//
// Precondition: s.readMu must be locked.
-func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequence, discard bool) (int, *syserr.Error) {
- var err *syserr.Error
- var copied int
-
- // Copy as many views as possible into the user-provided buffer.
- for {
- // Always do at least one fetchReadView, even if the number of bytes to
- // read is 0.
- err = s.fetchReadView()
- if err != nil || len(s.readView) == 0 {
- break
- }
- if dst.NumBytes() == 0 {
- break
- }
-
- var n int
- var e error
- if discard {
- n = len(s.readView)
- if int64(n) > dst.NumBytes() {
- n = int(dst.NumBytes())
- }
- } else {
- n, e = dst.CopyOut(ctx, s.readView)
- // Set the control message, even if 0 bytes were read.
- if e == nil {
- s.updateTimestamp()
- }
- }
- copied += n
- s.readView.TrimFront(n)
-
- dst = dst.DropFirst(n)
- if e != nil {
- err = syserr.FromError(e)
- break
- }
- // If we are done reading requested data then stop.
- if dst.NumBytes() == 0 {
- break
- }
- }
-
- if len(s.readView) == 0 {
- atomic.StoreUint32(&s.readViewHasData, 0)
+func (s *socketOpsCommon) streamRead(ctx context.Context, dst io.Writer, count int) (int, *syserr.Error) {
+ // Always do at least one read, even if the number of bytes to read is 0.
+ var n int
+ n, _, err := s.readLocked(dst, count, false /* peek */)
+ if err != nil {
+ return 0, err
}
-
- // If we managed to copy something, we must deliver it.
- if copied > 0 {
- s.Endpoint.ModerateRecvBuf(copied)
- return copied, nil
+ if n > 0 {
+ s.Endpoint.ModerateRecvBuf(n)
}
-
- return 0, err
+ return n, nil
}
func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
@@ -2723,7 +2582,7 @@ func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
return
}
cmsg.IP.HasInq = true
- cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed)
+ cmsg.IP.Inq = int32(rcvBufUsed)
}
func toLinuxPacketType(pktType tcpip.PacketType) uint8 {
@@ -2760,7 +2619,21 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
// bytes of data to be discarded, rather than passed back in a
// caller-supplied buffer.
s.readMu.Lock()
- n, err := s.coalescingRead(ctx, dst, trunc)
+
+ var w io.Writer
+ if trunc {
+ w = ioutil.Discard
+ } else {
+ w = dst.Writer(ctx)
+ }
+
+ n, err := s.streamRead(ctx, w, int(dst.NumBytes()))
+
+ if err == nil && !trunc {
+ // Set the control message, even if 0 bytes were read.
+ s.updateTimestamp()
+ }
+
cmsg := s.controlMessages()
s.fillCmsgInq(&cmsg)
s.readMu.Unlock()
@@ -2770,18 +2643,32 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
s.readMu.Lock()
defer s.readMu.Unlock()
- if err := s.fetchReadView(); err != nil {
+ // MSG_TRUNC with MSG_PEEK on a TCP socket returns the
+ // amount that could be read, and does not write to the buffer.
+ isTCPPeekTrunc := !isPacket && peek && trunc
+
+ var w io.Writer
+ if isTCPPeekTrunc {
+ w = ioutil.Discard
+ } else {
+ w = dst.Writer(ctx)
+ }
+
+ var numRead, numTotal int
+ var err *syserr.Error
+ numRead, numTotal, err = s.readLocked(w, int(dst.NumBytes()), peek)
+ if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, err
}
- if !isPacket && peek && trunc {
- // MSG_TRUNC with MSG_PEEK on a TCP socket returns the
- // amount that could be read.
+ if isTCPPeekTrunc {
+ // The TCP endpoint does not return the total bytes in the buffer as
+ // numTotal, so we need to query it via the socket option.
rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
}
- available := len(s.readView) + int(rql)
+ available := int(rql)
bufLen := int(dst.NumBytes())
if available < bufLen {
return available, 0, nil, 0, socket.ControlMessages{}, nil
@@ -2789,88 +2676,65 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
return bufLen, 0, nil, 0, socket.ControlMessages{}, nil
}
- n, err := dst.CopyOut(ctx, s.readView)
// Set the control message, even if 0 bytes were read.
- if err == nil {
- s.updateTimestamp()
- }
+ s.updateTimestamp()
+
var addr linux.SockAddr
var addrLen uint32
if isPacket && senderRequested {
- addr, addrLen = ConvertAddress(s.family, s.sender)
+ addr, addrLen = socket.ConvertAddress(s.family, s.sender)
switch v := addr.(type) {
case *linux.SockAddrLink:
- v.Protocol = htons(uint16(s.linkPacketInfo.Protocol))
+ v.Protocol = socket.Htons(uint16(s.linkPacketInfo.Protocol))
v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType)
}
}
if peek {
- if l := len(s.readView); trunc && l > n {
+ if trunc && numTotal > numRead {
// isPacket must be true.
- return l, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ return numTotal, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), nil
}
-
- if isPacket || err != nil {
- return n, 0, addr, addrLen, s.controlMessages(), syserr.FromError(err)
- }
-
- // We need to peek beyond the first message.
- dst = dst.DropFirst(n)
- num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) {
- n, _, err := s.Endpoint.Peek(dsts)
- // TODO(b/78348848): Handle peek timestamp.
- if err != nil {
- return int64(n), syserr.TranslateNetstackError(err).ToError()
- }
- return int64(n), nil
- }})
- n += int(num)
- if err == syserror.ErrWouldBlock && n > 0 {
- // We got some data, so no need to return an error.
- err = nil
- }
- return n, 0, nil, 0, s.controlMessages(), syserr.FromError(err)
+ return numRead, 0, nil, 0, s.controlMessages(), nil
}
var msgLen int
if isPacket {
- msgLen = len(s.readView)
- s.readView = nil
+ msgLen = numTotal
} else {
- msgLen = int(n)
- s.readView.TrimFront(int(n))
- }
-
- if len(s.readView) == 0 {
- atomic.StoreUint32(&s.readViewHasData, 0)
+ msgLen = numRead
}
var flags int
- if msgLen > int(n) {
+ if msgLen > numRead {
flags |= linux.MSG_TRUNC
}
+ n := numRead
if trunc {
n = msgLen
}
cmsg := s.controlMessages()
s.fillCmsgInq(&cmsg)
- return n, flags, addr, addrLen, cmsg, syserr.FromError(err)
+ return n, flags, addr, addrLen, cmsg, nil
}
func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
return socket.ControlMessages{
- IP: tcpip.ControlMessages{
- HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
- Timestamp: s.readCM.Timestamp,
- HasTOS: s.readCM.HasTOS,
- TOS: s.readCM.TOS,
- HasTClass: s.readCM.HasTClass,
- TClass: s.readCM.TClass,
- HasIPPacketInfo: s.readCM.HasIPPacketInfo,
- PacketInfo: s.readCM.PacketInfo,
+ IP: socket.IPControlMessages{
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasInq: s.readCM.HasInq,
+ Inq: s.readCM.Inq,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ HasTClass: s.readCM.HasTClass,
+ TClass: s.readCM.TClass,
+ HasIPPacketInfo: s.readCM.HasIPPacketInfo,
+ PacketInfo: s.readCM.PacketInfo,
+ OriginalDstAddress: s.readCM.OriginalDstAddress,
+ SockErr: s.readCM.SockErr,
},
}
}
@@ -2887,9 +2751,66 @@ func (s *socketOpsCommon) updateTimestamp() {
}
}
+// dequeueErr is analogous to net/core/skbuff.c:sock_dequeue_err_skb().
+func (s *socketOpsCommon) dequeueErr() *tcpip.SockError {
+ so := s.Endpoint.SocketOptions()
+ err := so.DequeueErr()
+ if err == nil {
+ return nil
+ }
+
+ // Update socket error to reflect ICMP errors in queue.
+ if nextErr := so.PeekErr(); nextErr != nil && nextErr.ErrOrigin.IsICMPErr() {
+ so.SetLastError(nextErr.Err)
+ } else if err.ErrOrigin.IsICMPErr() {
+ so.SetLastError(nil)
+ }
+ return err
+}
+
+// addrFamilyFromNetProto returns the address family identifier for the given
+// network protocol.
+func addrFamilyFromNetProto(net tcpip.NetworkProtocolNumber) int {
+ switch net {
+ case header.IPv4ProtocolNumber:
+ return linux.AF_INET
+ case header.IPv6ProtocolNumber:
+ return linux.AF_INET6
+ default:
+ panic(fmt.Sprintf("invalid net proto for addr family inference: %d", net))
+ }
+}
+
+// recvErr handles MSG_ERRQUEUE for recvmsg(2).
+// This is analogous to net/ipv4/ip_sockglue.c:ip_recv_error().
+func (s *socketOpsCommon) recvErr(t *kernel.Task, dst usermem.IOSequence) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+ sockErr := s.dequeueErr()
+ if sockErr == nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ }
+
+ // The payload of the original packet that caused the error is passed as
+ // normal data via msg_iovec. -- recvmsg(2)
+ msgFlags := linux.MSG_ERRQUEUE
+ if int(dst.NumBytes()) < len(sockErr.Payload) {
+ msgFlags |= linux.MSG_TRUNC
+ }
+ n, err := dst.CopyOut(t, sockErr.Payload)
+
+ // The original destination address of the datagram that caused the error is
+ // supplied via msg_name. -- recvmsg(2)
+ dstAddr, dstAddrLen := socket.ConvertAddress(addrFamilyFromNetProto(sockErr.NetProto), sockErr.Dst)
+ cmgs := socket.ControlMessages{IP: socket.NewIPControlMessages(s.family, tcpip.ControlMessages{SockErr: sockErr})}
+ return n, msgFlags, dstAddr, dstAddrLen, cmgs, syserr.FromError(err)
+}
+
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// tcpip.Endpoint.
func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+ if flags&linux.MSG_ERRQUEUE != 0 {
+ return s.recvErr(t, dst)
+ }
+
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
@@ -2965,7 +2886,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
var addr *tcpip.FullAddress
if len(to) > 0 {
- addrBuf, family, err := AddressAndFamily(to)
+ addrBuf, family, err := socket.AddressAndFamily(to)
if err != nil {
return 0, err
}
@@ -3063,11 +2984,6 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
return 0, syserr.TranslateNetstackError(terr).ToError()
}
- // Add bytes removed from the endpoint but not yet sent to the caller.
- s.readMu.Lock()
- v += len(s.readView)
- s.readMu.Unlock()
-
if v > math.MaxInt32 {
v = math.MaxInt32
}
@@ -3384,6 +3300,18 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 {
return rv
}
+func isTCPSocket(skType linux.SockType, skProto int) bool {
+ return skType == linux.SOCK_STREAM && (skProto == 0 || skProto == syscall.IPPROTO_TCP)
+}
+
+func isUDPSocket(skType linux.SockType, skProto int) bool {
+ return skType == linux.SOCK_DGRAM && (skProto == 0 || skProto == syscall.IPPROTO_UDP)
+}
+
+func isICMPSocket(skType linux.SockType, skProto int) bool {
+ return skType == linux.SOCK_DGRAM && (skProto == syscall.IPPROTO_ICMP || skProto == syscall.IPPROTO_ICMPV6)
+}
+
// State implements socket.Socket.State. State translates the internal state
// returned by netstack to values defined by Linux.
func (s *socketOpsCommon) State() uint32 {
@@ -3393,7 +3321,7 @@ func (s *socketOpsCommon) State() uint32 {
}
switch {
- case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP:
+ case isTCPSocket(s.skType, s.protocol):
// TCP socket.
switch tcp.EndpointState(s.Endpoint.State()) {
case tcp.StateEstablished:
@@ -3422,7 +3350,7 @@ func (s *socketOpsCommon) State() uint32 {
// Internal or unknown state.
return 0
}
- case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP:
+ case isUDPSocket(s.skType, s.protocol):
// UDP socket.
switch udp.EndpointState(s.Endpoint.State()) {
case udp.StateInitial, udp.StateBound, udp.StateClosed:
@@ -3432,7 +3360,7 @@ func (s *socketOpsCommon) State() uint32 {
default:
return 0
}
- case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6:
+ case isICMPSocket(s.skType, s.protocol):
// TODO(b/112063468): Export states for ICMP sockets.
case s.skType == linux.SOCK_RAW:
// TODO(b/112063468): Export states for raw sockets.
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index b0d9e4d9e..b756bfca0 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -51,9 +51,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{})
// NewVFS2 creates a new endpoint socket.
func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*vfs.FileDescription, *syserr.Error) {
if skType == linux.SOCK_STREAM {
- if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ endpoint.SocketOptions().SetDelayOption(true)
}
mnt := t.Kernel().SocketMount()
@@ -191,7 +189,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
var addrLen uint32
if peerAddr != nil {
// Get address of the peer and write it to peer slice.
- addr, addrLen = ConvertAddress(s.family, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(s.family, *peerAddr)
}
fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index ead3b2b79..c847ff1c7 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -158,7 +158,7 @@ func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol
// protocol is passed in network byte order, but netstack wants it in
// host order.
- netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol)))
+ netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol)))
wq := &waiter.Queue{}
ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq)
diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go
index 2a01143f6..0af805246 100644
--- a/pkg/sentry/socket/netstack/provider_vfs2.go
+++ b/pkg/sentry/socket/netstack/provider_vfs2.go
@@ -102,7 +102,7 @@ func packetSocketVFS2(t *kernel.Task, epStack *Stack, stype linux.SockType, prot
// protocol is passed in network byte order, but netstack wants it in
// host order.
- netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol)))
+ netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol)))
wq := &waiter.Queue{}
ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq)
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index fa9ac9059..cc0fadeb5 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -324,12 +324,12 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
0, // Support Ip/FragCreates.
}
case *inet.StatSNMPICMP:
- in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats
- out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats
+ in := Metrics.ICMP.V4.PacketsReceived.ICMPv4PacketStats
+ out := Metrics.ICMP.V4.PacketsSent.ICMPv4PacketStats
// TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPICMP{
0, // Icmp/InMsgs.
- Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors.
+ Metrics.ICMP.V4.PacketsSent.Dropped.Value(), // InErrors.
0, // Icmp/InCsumErrors.
in.DstUnreachable.Value(), // InDestUnreachs.
in.TimeExceeded.Value(), // InTimeExcds.
@@ -343,18 +343,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
in.InfoRequest.Value(), // InAddrMasks.
in.InfoReply.Value(), // InAddrMaskReps.
0, // Icmp/OutMsgs.
- Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors.
- out.DstUnreachable.Value(), // OutDestUnreachs.
- out.TimeExceeded.Value(), // OutTimeExcds.
- out.ParamProblem.Value(), // OutParmProbs.
- out.SrcQuench.Value(), // OutSrcQuenchs.
- out.Redirect.Value(), // OutRedirects.
- out.Echo.Value(), // OutEchos.
- out.EchoReply.Value(), // OutEchoReps.
- out.Timestamp.Value(), // OutTimestamps.
- out.TimestampReply.Value(), // OutTimestampReps.
- out.InfoRequest.Value(), // OutAddrMasks.
- out.InfoReply.Value(), // OutAddrMaskReps.
+ Metrics.ICMP.V4.PacketsReceived.Invalid.Value(), // OutErrors.
+ out.DstUnreachable.Value(), // OutDestUnreachs.
+ out.TimeExceeded.Value(), // OutTimeExcds.
+ out.ParamProblem.Value(), // OutParmProbs.
+ out.SrcQuench.Value(), // OutSrcQuenchs.
+ out.Redirect.Value(), // OutRedirects.
+ out.Echo.Value(), // OutEchos.
+ out.EchoReply.Value(), // OutEchoReps.
+ out.Timestamp.Value(), // OutTimestamps.
+ out.TimestampReply.Value(), // OutTimestampReps.
+ out.InfoRequest.Value(), // OutAddrMasks.
+ out.InfoReply.Value(), // OutAddrMaskReps.
}
case *inet.StatSNMPTCP:
tcp := Metrics.TCP
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index fd31479e5..97729dacc 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -18,6 +18,7 @@
package socket
import (
+ "bytes"
"fmt"
"sync/atomic"
"syscall"
@@ -35,6 +36,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -42,7 +44,134 @@ import (
// control messages.
type ControlMessages struct {
Unix transport.ControlMessages
- IP tcpip.ControlMessages
+ IP IPControlMessages
+}
+
+// packetInfoToLinux converts IPPacketInfo from tcpip format to Linux format.
+func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPacketInfo {
+ var p linux.ControlMessageIPPacketInfo
+ p.NIC = int32(packetInfo.NIC)
+ copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
+ copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
+ return p
+}
+
+// errOriginToLinux maps tcpip socket error origin to Linux SO_EE_ORIGIN_* constants.
+func errOriginToLinux(origin tcpip.SockErrOrigin) uint8 {
+ switch origin {
+ case tcpip.SockExtErrorOriginNone:
+ return linux.SO_EE_ORIGIN_NONE
+ case tcpip.SockExtErrorOriginLocal:
+ return linux.SO_EE_ORIGIN_LOCAL
+ case tcpip.SockExtErrorOriginICMP:
+ return linux.SO_EE_ORIGIN_ICMP
+ case tcpip.SockExtErrorOriginICMP6:
+ return linux.SO_EE_ORIGIN_ICMP6
+ default:
+ panic(fmt.Sprintf("unknown socket origin: %d", origin))
+ }
+}
+
+// sockErrCmsgToLinux converts SockError control message from tcpip format to
+// Linux format.
+func sockErrCmsgToLinux(sockErr *tcpip.SockError) linux.SockErrCMsg {
+ if sockErr == nil {
+ return nil
+ }
+
+ ee := linux.SockExtendedErr{
+ Errno: uint32(syserr.TranslateNetstackError(sockErr.Err).ToLinux().Number()),
+ Origin: errOriginToLinux(sockErr.ErrOrigin),
+ Type: sockErr.ErrType,
+ Code: sockErr.ErrCode,
+ Info: sockErr.ErrInfo,
+ }
+
+ switch sockErr.NetProto {
+ case header.IPv4ProtocolNumber:
+ errMsg := &linux.SockErrCMsgIPv4{SockExtendedErr: ee}
+ if len(sockErr.Offender.Addr) > 0 {
+ addr, _ := ConvertAddress(linux.AF_INET, sockErr.Offender)
+ errMsg.Offender = *addr.(*linux.SockAddrInet)
+ }
+ return errMsg
+ case header.IPv6ProtocolNumber:
+ errMsg := &linux.SockErrCMsgIPv6{SockExtendedErr: ee}
+ if len(sockErr.Offender.Addr) > 0 {
+ addr, _ := ConvertAddress(linux.AF_INET6, sockErr.Offender)
+ errMsg.Offender = *addr.(*linux.SockAddrInet6)
+ }
+ return errMsg
+ default:
+ panic(fmt.Sprintf("invalid net proto for creating SockErrCMsg: %d", sockErr.NetProto))
+ }
+}
+
+// NewIPControlMessages converts the tcpip ControlMessages (which does not
+// have a Linux-specific format) to Linux format.
+func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessages {
+ var orgDstAddr linux.SockAddr
+ if cmgs.HasOriginalDstAddress {
+ orgDstAddr, _ = ConvertAddress(family, cmgs.OriginalDstAddress)
+ }
+ return IPControlMessages{
+ HasTimestamp: cmgs.HasTimestamp,
+ Timestamp: cmgs.Timestamp,
+ HasInq: cmgs.HasInq,
+ Inq: cmgs.Inq,
+ HasTOS: cmgs.HasTOS,
+ TOS: cmgs.TOS,
+ HasTClass: cmgs.HasTClass,
+ TClass: cmgs.TClass,
+ HasIPPacketInfo: cmgs.HasIPPacketInfo,
+ PacketInfo: packetInfoToLinux(cmgs.PacketInfo),
+ OriginalDstAddress: orgDstAddr,
+ SockErr: sockErrCmsgToLinux(cmgs.SockErr),
+ }
+}
+
+// IPControlMessages contains socket control messages for IP sockets.
+// This can contain Linux specific structures unlike tcpip.ControlMessages.
+//
+// +stateify savable
+type IPControlMessages struct {
+ // HasTimestamp indicates whether Timestamp is valid/set.
+ HasTimestamp bool
+
+ // Timestamp is the time (in ns) that the last packet used to create
+ // the read data was received.
+ Timestamp int64
+
+ // HasInq indicates whether Inq is valid/set.
+ HasInq bool
+
+ // Inq is the number of bytes ready to be received.
+ Inq int32
+
+ // HasTOS indicates whether TOS is valid/set.
+ HasTOS bool
+
+ // TOS is the IPv4 type of service of the associated packet.
+ TOS uint8
+
+ // HasTClass indicates whether TClass is valid/set.
+ HasTClass bool
+
+ // TClass is the IPv6 traffic class of the associated packet.
+ TClass uint32
+
+ // HasIPPacketInfo indicates whether PacketInfo is set.
+ HasIPPacketInfo bool
+
+ // PacketInfo holds interface and address data on an incoming packet.
+ PacketInfo linux.ControlMessageIPPacketInfo
+
+ // OriginalDstAddress holds the original destination address
+ // and port of the incoming packet.
+ OriginalDstAddress linux.SockAddr
+
+ // SockErr is the dequeued socket error on recvmsg(MSG_ERRQUEUE).
+ SockErr linux.SockErrCMsg
}
// Release releases Unix domain socket credentials and rights.
@@ -460,3 +589,176 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr {
panic(fmt.Sprintf("Unsupported socket family %v", family))
}
}
+
+var sockAddrLinkSize = (&linux.SockAddrLink{}).SizeBytes()
+var sockAddrInetSize = (&linux.SockAddrInet{}).SizeBytes()
+var sockAddrInet6Size = (&linux.SockAddrInet6{}).SizeBytes()
+
+// Ntohs converts a 16-bit number from network byte order to host byte order. It
+// assumes that the host is little endian.
+func Ntohs(v uint16) uint16 {
+ return v<<8 | v>>8
+}
+
+// Htons converts a 16-bit number from host byte order to network byte order. It
+// assumes that the host is little endian.
+func Htons(v uint16) uint16 {
+ return Ntohs(v)
+}
+
+// isLinkLocal determines if the given IPv6 address is link-local. This is the
+// case when it has the fe80::/10 prefix. This check is used to determine when
+// the NICID is relevant for a given IPv6 address.
+func isLinkLocal(addr tcpip.Address) bool {
+ return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80
+}
+
+// ConvertAddress converts the given address to a native format.
+func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) {
+ switch family {
+ case linux.AF_UNIX:
+ var out linux.SockAddrUnix
+ out.Family = linux.AF_UNIX
+ l := len([]byte(addr.Addr))
+ for i := 0; i < l; i++ {
+ out.Path[i] = int8(addr.Addr[i])
+ }
+
+ // Linux returns the used length of the address struct (including the
+ // null terminator) for filesystem paths. The Family field is 2 bytes.
+ // It is sometimes allowed to exclude the null terminator if the
+ // address length is the max. Abstract and empty paths always return
+ // the full exact length.
+ if l == 0 || out.Path[0] == 0 || l == len(out.Path) {
+ return &out, uint32(2 + l)
+ }
+ return &out, uint32(3 + l)
+
+ case linux.AF_INET:
+ var out linux.SockAddrInet
+ copy(out.Addr[:], addr.Addr)
+ out.Family = linux.AF_INET
+ out.Port = Htons(addr.Port)
+ return &out, uint32(sockAddrInetSize)
+
+ case linux.AF_INET6:
+ var out linux.SockAddrInet6
+ if len(addr.Addr) == header.IPv4AddressSize {
+ // Copy address in v4-mapped format.
+ copy(out.Addr[12:], addr.Addr)
+ out.Addr[10] = 0xff
+ out.Addr[11] = 0xff
+ } else {
+ copy(out.Addr[:], addr.Addr)
+ }
+ out.Family = linux.AF_INET6
+ out.Port = Htons(addr.Port)
+ if isLinkLocal(addr.Addr) {
+ out.Scope_id = uint32(addr.NIC)
+ }
+ return &out, uint32(sockAddrInet6Size)
+
+ case linux.AF_PACKET:
+ // TODO(gvisor.dev/issue/173): Return protocol too.
+ var out linux.SockAddrLink
+ out.Family = linux.AF_PACKET
+ out.InterfaceIndex = int32(addr.NIC)
+ out.HardwareAddrLen = header.EthernetAddressSize
+ copy(out.HardwareAddr[:], addr.Addr)
+ return &out, uint32(sockAddrLinkSize)
+
+ default:
+ return nil, 0
+ }
+}
+
+// BytesToIPAddress converts an IPv4 or IPv6 address from the user to the
+// netstack representation, mapping the all-zero ("any") address to "".
+func BytesToIPAddress(addr []byte) tcpip.Address {
+ if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) {
+ return ""
+ }
+ return tcpip.Address(addr)
+}
+
+// AddressAndFamily reads a sockaddr struct from the given address and
+// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
+// AF_INET6, and AF_PACKET addresses.
+//
+// AddressAndFamily returns an address and its family.
+func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
+ // Make sure we have at least 2 bytes for the address family.
+ if len(addr) < 2 {
+ return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
+ }
+
+ // Get the rest of the fields based on the address family.
+ switch family := usermem.ByteOrder.Uint16(addr); family {
+ case linux.AF_UNIX:
+ path := addr[2:]
+ if len(path) > linux.UnixPathMax {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ // Drop the terminating NUL (if one exists) and everything after
+ // it for filesystem (non-abstract) addresses.
+ if len(path) > 0 && path[0] != 0 {
+ if n := bytes.IndexByte(path[1:], 0); n >= 0 {
+ path = path[:n+1]
+ }
+ }
+ return tcpip.FullAddress{
+ Addr: tcpip.Address(path),
+ }, family, nil
+
+ case linux.AF_INET:
+ var a linux.SockAddrInet
+ if len(addr) < sockAddrInetSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
+
+ out := tcpip.FullAddress{
+ Addr: BytesToIPAddress(a.Addr[:]),
+ Port: Ntohs(a.Port),
+ }
+ return out, family, nil
+
+ case linux.AF_INET6:
+ var a linux.SockAddrInet6
+ if len(addr) < sockAddrInet6Size {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
+
+ out := tcpip.FullAddress{
+ Addr: BytesToIPAddress(a.Addr[:]),
+ Port: Ntohs(a.Port),
+ }
+ if isLinkLocal(out.Addr) {
+ out.NIC = tcpip.NICID(a.Scope_id)
+ }
+ return out, family, nil
+
+ case linux.AF_PACKET:
+ var a linux.SockAddrLink
+ if len(addr) < sockAddrLinkSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+ if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/173): Return protocol too.
+ return tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }, family, nil
+
+ case linux.AF_UNSPEC:
+ return tcpip.FullAddress{}, family, nil
+
+ default:
+ return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported
+ }
+}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 4abea90cc..099a56281 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -178,10 +178,6 @@ type Endpoint interface {
// SetSockOpt sets a socket option.
SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error
- // SetSockOptBool sets a socket option for simple cases when a value has
- // the int type.
- SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
-
// SetSockOptInt sets a socket option for simple cases when a value has
// the int type.
SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
@@ -189,10 +185,6 @@ type Endpoint interface {
// GetSockOpt gets a socket option.
GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error
- // GetSockOptBool gets a socket option for simple cases when a return
- // value has the int type.
- GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
-
// GetSockOptInt gets a socket option for simple cases when a return
// value has the int type.
GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
@@ -754,9 +746,6 @@ type baseEndpoint struct {
// or may be used if the endpoint is connected.
path string
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
-
// ops is used to get socket level options.
ops tcpip.SocketOptions
}
@@ -848,17 +837,6 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess
// SetSockOpt sets a socket option.
func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
- case *tcpip.LingerOption:
- e.Lock()
- e.linger = *v
- e.Unlock()
- }
- return nil
-}
-
-func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- log.Warningf("Unsupported socket option: %d", opt)
return nil
}
@@ -872,11 +850,6 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
-func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- log.Warningf("Unsupported socket option: %d", opt)
- return false, tcpip.ErrUnknownProtocolOption
-}
-
func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -940,17 +913,8 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- e.Lock()
- *o = e.linger
- e.Unlock()
- return nil
-
- default:
- log.Warningf("Unsupported socket option: %T", opt)
- return tcpip.ErrUnknownProtocolOption
- }
+ log.Warningf("Unsupported socket option: %T", opt)
+ return tcpip.ErrUnknownProtocolOption
}
// LastError implements Endpoint.LastError.
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index b32bb7ba8..c59297c80 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -136,7 +136,7 @@ func (s *socketOpsCommon) Endpoint() transport.Endpoint {
// extractPath extracts and validates the address.
func extractPath(sockaddr []byte) (string, *syserr.Error) {
- addr, family, err := netstack.AddressAndFamily(sockaddr)
+ addr, family, err := socket.AddressAndFamily(sockaddr)
if err != nil {
if err == syserr.ErrAddressFamilyNotSupported {
err = syserr.ErrInvalidArgument
@@ -169,7 +169,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
+ a, l := socket.ConvertAddress(linux.AF_UNIX, addr)
return a, l, nil
}
@@ -181,7 +181,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
+ a, l := socket.ConvertAddress(linux.AF_UNIX, addr)
return a, l, nil
}
@@ -255,7 +255,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
var addr linux.SockAddr
var addrLen uint32
if peerAddr != nil {
- addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr)
}
fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
@@ -647,7 +647,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
var from linux.SockAddr
var fromLen uint32
if r.From != nil && len([]byte(r.From.Addr)) != 0 {
- from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
+ from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From)
}
if r.ControlTrunc {
@@ -682,7 +682,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
var from linux.SockAddr
var fromLen uint32
if r.From != nil {
- from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
+ from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From)
}
if r.ControlTrunc {
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index eaf0b0d26..27f705bb2 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -172,7 +172,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
var addr linux.SockAddr
var addrLen uint32
if peerAddr != nil {
- addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr)
}
fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index a920180d3..d36a64ffc 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -32,8 +32,8 @@ go_library(
"//pkg/seccomp",
"//pkg/sentry/arch",
"//pkg/sentry/kernel",
+ "//pkg/sentry/socket",
"//pkg/sentry/socket/netlink",
- "//pkg/sentry/socket/netstack",
"//pkg/sentry/syscalls/linux",
"//pkg/usermem",
],
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index cc5f70cd4..d943a7cb1 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -23,8 +23,8 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
- "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -341,7 +341,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string {
switch family {
case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX:
- fa, _, err := netstack.AddressAndFamily(b)
+ fa, _, err := socket.AddressAndFamily(b)
if err != nil {
return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err)
}
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index bb1f715e2..a72df62f6 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{
63: syscalls.Supported("uname", Uname),
64: syscalls.Supported("semget", Semget),
65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
- 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil),
+ 66: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_STAT_ANY not supported.", nil),
67: syscalls.Supported("shmdt", Shmdt),
68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
@@ -272,7 +272,7 @@ var AMD64 = &kernel.SyscallTable{
217: syscalls.Supported("getdents64", Getdents64),
218: syscalls.Supported("set_tid_address", SetTidAddress),
219: syscalls.Supported("restart_syscall", RestartSyscall),
- 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}),
+ 220: syscalls.PartiallySupported("semtimedop", Semtimedop, "A non-zero timeout argument isn't supported.", []string{"gvisor.dev/issue/137"}),
221: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil),
222: syscalls.Supported("timer_create", TimerCreate),
223: syscalls.Supported("timer_settime", TimerSettime),
@@ -619,8 +619,8 @@ var ARM64 = &kernel.SyscallTable{
188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
190: syscalls.Supported("semget", Semget),
- 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil),
- 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}),
+ 191: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_STAT_ANY not supported.", nil),
+ 192: syscalls.PartiallySupported("semtimedop", Semtimedop, "A non-zero timeout argument isn't supported.", []string{"gvisor.dev/issue/137"}),
193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil),
195: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil),
diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go
index 0bf313a13..c2285f796 100644
--- a/pkg/sentry/syscalls/linux/sys_aio.go
+++ b/pkg/sentry/syscalls/linux/sys_aio.go
@@ -307,9 +307,8 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user
if !ok {
return syserror.EINVAL
}
- if ready := ctx.Prepare(); !ready {
- // Context is busy.
- return syserror.EAGAIN
+ if err := ctx.Prepare(); err != nil {
+ return err
}
if eventFile != nil {
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 519066a47..c33571f43 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -175,6 +175,12 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint
}
}
+ file, err := d.Inode.GetFile(t, d, fileFlags)
+ if err != nil {
+ return syserror.ConvertIntr(err, syserror.ERESTARTSYS)
+ }
+ defer file.DecRef(t)
+
// Truncate is called when O_TRUNC is specified for any kind of
// existing Dirent. Behavior is delegated to the entry's Truncate
// implementation.
@@ -184,12 +190,6 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint
}
}
- file, err := d.Inode.GetFile(t, d, fileFlags)
- if err != nil {
- return syserror.ConvertIntr(err, syserror.ERESTARTSYS)
- }
- defer file.DecRef(t)
-
// Success.
newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{
CloseOnExec: flags&linux.O_CLOEXEC != 0,
@@ -646,7 +646,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil {
return 0, nil, err
}
- fSetOwn(t, file, set)
+ fSetOwn(t, int(fd), file, set)
return 0, nil, nil
case linux.FIOGETOWN, linux.SIOCGPGRP:
@@ -901,8 +901,8 @@ func fGetOwn(t *kernel.Task, file *fs.File) int32 {
//
// If who is positive, it represents a PID. If negative, it represents a PGID.
// If the PID or PGID is invalid, the owner is silently unset.
-func fSetOwn(t *kernel.Task, file *fs.File, who int32) error {
- a := file.Async(fasync.New).(*fasync.FileAsync)
+func fSetOwn(t *kernel.Task, fd int, file *fs.File, who int32) error {
+ a := file.Async(fasync.New(fd)).(*fasync.FileAsync)
if who < 0 {
// Check for overflow before flipping the sign.
if who-1 > who {
@@ -1049,7 +1049,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.F_GETOWN:
return uintptr(fGetOwn(t, file)), nil, nil
case linux.F_SETOWN:
- return 0, nil, fSetOwn(t, file, args[2].Int())
+ return 0, nil, fSetOwn(t, int(fd), file, args[2].Int())
case linux.F_GETOWN_EX:
addr := args[2].Pointer()
owner := fGetOwnEx(t, file)
@@ -1062,7 +1062,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, err
}
- a := file.Async(fasync.New).(*fasync.FileAsync)
+ a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync)
switch owner.Type {
case linux.F_OWNER_TID:
task := t.PIDNamespace().TaskWithID(kernel.ThreadID(owner.PID))
@@ -1111,6 +1111,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
n, err := sz.SetFifoSize(int64(args[2].Int()))
return uintptr(n), nil, err
+ case linux.F_GETSIG:
+ a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync)
+ return uintptr(a.Signal()), nil, nil
+ case linux.F_SETSIG:
+ a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync)
+ return 0, nil, a.SetSignal(linux.Signal(args[2].Int()))
default:
// Everything else is not yet supported.
return 0, nil, syserror.EINVAL
diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go
index e383a0a87..d324461a3 100644
--- a/pkg/sentry/syscalls/linux/sys_sem.go
+++ b/pkg/sentry/syscalls/linux/sys_sem.go
@@ -48,6 +48,15 @@ func Semget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return uintptr(set.ID), nil, nil
}
+// Semtimedop handles: semtimedop(int semid, struct sembuf *sops, size_t nsops, const struct timespec *timeout)
+func Semtimedop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // TODO(gvisor.dev/issue/137): A non-zero timeout isn't supported.
+ if args[3].Pointer() != 0 {
+ return 0, nil, syserror.ENOSYS
+ }
+ return Semop(t, args)
+}
+
// Semop handles: semop(int semid, struct sembuf *sops, size_t nsops)
func Semop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
id := args[0].Int()
@@ -146,11 +155,37 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
v, err := getNCnt(t, id, num)
return uintptr(v), nil, err
- case linux.IPC_INFO,
- linux.SEM_INFO,
- linux.SEM_STAT,
- linux.SEM_STAT_ANY:
+ case linux.IPC_INFO:
+ buf := args[3].Pointer()
+ r := t.IPCNamespace().SemaphoreRegistry()
+ info := r.IPCInfo()
+ if _, err := info.CopyOut(t, buf); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(r.HighestIndex()), nil, nil
+ case linux.SEM_INFO:
+ buf := args[3].Pointer()
+ r := t.IPCNamespace().SemaphoreRegistry()
+ info := r.SemInfo()
+ if _, err := info.CopyOut(t, buf); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(r.HighestIndex()), nil, nil
+
+ case linux.SEM_STAT:
+ arg := args[3].Pointer()
+ // id is an index in SEM_STAT.
+ semid, ds, err := semStat(t, id)
+ if err != nil {
+ return 0, nil, err
+ }
+ if _, err := ds.CopyOut(t, arg); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(semid), nil, err
+
+ case linux.SEM_STAT_ANY:
t.Kernel().EmitUnimplementedEvent(t)
fallthrough
@@ -195,6 +230,17 @@ func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) {
return set.GetStat(creds)
}
+func semStat(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByIndex(index)
+ if set == nil {
+ return 0, nil, syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ ds, err := set.GetStat(creds)
+ return set.ID, ds, err
+}
+
func setVal(t *kernel.Task, id int32, num int32, val int16) error {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go
index e748d33d8..d639c9bf7 100644
--- a/pkg/sentry/syscalls/linux/sys_signal.go
+++ b/pkg/sentry/syscalls/linux/sys_signal.go
@@ -88,8 +88,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- info.SetPid(int32(target.PIDNamespace().IDOfTask(t)))
- info.SetUid(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(target.PIDNamespace().IDOfTask(t)))
+ info.SetUID(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow()))
if err := target.SendGroupSignal(info); err != syserror.ESRCH {
return 0, nil, err
}
@@ -127,8 +127,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- info.SetPid(int32(tg.PIDNamespace().IDOfTask(t)))
- info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
+ info.SetPID(int32(tg.PIDNamespace().IDOfTask(t)))
+ info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
err := tg.SendSignal(info)
if err == syserror.ESRCH {
// ESRCH is ignored because it means the task
@@ -171,8 +171,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- info.SetPid(int32(tg.PIDNamespace().IDOfTask(t)))
- info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
+ info.SetPID(int32(tg.PIDNamespace().IDOfTask(t)))
+ info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
// See note above regarding ESRCH race above.
if err := tg.SendSignal(info); err != syserror.ESRCH {
lastErr = err
@@ -189,8 +189,8 @@ func tkillSigInfo(sender, receiver *kernel.Task, sig linux.Signal) *arch.SignalI
Signo: int32(sig),
Code: arch.SignalInfoTkill,
}
- info.SetPid(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup())))
- info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup())))
+ info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
return info
}
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 9cd052c3d..fe45225c1 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -749,11 +749,6 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
return 0, err
}
- // FIXME(b/63594852): Pretend we have an empty error queue.
- if flags&linux.MSG_ERRQUEUE != 0 {
- return 0, syserror.EAGAIN
- }
-
// Fast path when no control message nor name buffers are provided.
if msg.ControlLen == 0 && msg.NameLen == 0 {
n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0)
@@ -1035,7 +1030,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme
return 0, err
}
- controlMessages, err := control.Parse(t, s, controlData)
+ controlMessages, err := control.Parse(t, s, controlData, t.Arch().Width())
if err != nil {
return 0, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index 983f8d396..8e7ac0ffe 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -413,8 +413,8 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
si := arch.SignalInfo{
Signo: int32(linux.SIGCHLD),
}
- si.SetPid(int32(wr.TID))
- si.SetUid(int32(wr.UID))
+ si.SetPID(int32(wr.TID))
+ si.SetUID(int32(wr.UID))
// TODO(b/73541790): convert kernel.ExitStatus to functions and make
// WaitResult.Status a linux.WaitStatus.
s := syscall.WaitStatus(wr.Status)
diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go
index 6d0a38330..1365a5a62 100644
--- a/pkg/sentry/syscalls/linux/vfs2/aio.go
+++ b/pkg/sentry/syscalls/linux/vfs2/aio.go
@@ -130,9 +130,8 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user
if !ok {
return syserror.EINVAL
}
- if ready := aioCtx.Prepare(); !ready {
- // Context is busy.
- return syserror.EAGAIN
+ if err := aioCtx.Prepare(); err != nil {
+ return err
}
if eventFD != nil {
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
index 36e89700e..7dd9ef857 100644
--- a/pkg/sentry/syscalls/linux/vfs2/fd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -165,7 +165,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
ownerType = linux.F_OWNER_PGRP
who = -who
}
- return 0, nil, setAsyncOwner(t, file, ownerType, who)
+ return 0, nil, setAsyncOwner(t, int(fd), file, ownerType, who)
case linux.F_GETOWN_EX:
owner, hasOwner := getAsyncOwner(t, file)
if !hasOwner {
@@ -179,7 +179,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, err
}
- return 0, nil, setAsyncOwner(t, file, owner.Type, owner.PID)
+ return 0, nil, setAsyncOwner(t, int(fd), file, owner.Type, owner.PID)
case linux.F_SETPIPE_SZ:
pipefile, ok := file.Impl().(*pipe.VFSPipeFD)
if !ok {
@@ -207,6 +207,16 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, err
case linux.F_SETLK, linux.F_SETLKW:
return 0, nil, posixLock(t, args, file, cmd)
+ case linux.F_GETSIG:
+ a := file.AsyncHandler()
+ if a == nil {
+ // Default behavior aka SIGIO.
+ return 0, nil, nil
+ }
+ return uintptr(a.(*fasync.FileAsync).Signal()), nil, nil
+ case linux.F_SETSIG:
+ a := file.SetAsyncHandler(fasync.NewVFS2(int(fd))).(*fasync.FileAsync)
+ return 0, nil, a.SetSignal(linux.Signal(args[2].Int()))
default:
// Everything else is not yet supported.
return 0, nil, syserror.EINVAL
@@ -241,7 +251,7 @@ func getAsyncOwner(t *kernel.Task, fd *vfs.FileDescription) (ownerEx linux.FOwne
}
}
-func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32) error {
+func setAsyncOwner(t *kernel.Task, fd int, file *vfs.FileDescription, ownerType, pid int32) error {
switch ownerType {
case linux.F_OWNER_TID, linux.F_OWNER_PID, linux.F_OWNER_PGRP:
// Acceptable type.
@@ -249,7 +259,7 @@ func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32
return syserror.EINVAL
}
- a := fd.SetAsyncHandler(fasync.NewVFS2).(*fasync.FileAsync)
+ a := file.SetAsyncHandler(fasync.NewVFS2(fd)).(*fasync.FileAsync)
if pid == 0 {
a.ClearOwner()
return nil
diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
index 2806c3f6f..20c264fef 100644
--- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go
+++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
@@ -100,7 +100,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
ownerType = linux.F_OWNER_PGRP
who = -who
}
- return 0, nil, setAsyncOwner(t, file, ownerType, who)
+ return 0, nil, setAsyncOwner(t, int(fd), file, ownerType, who)
}
ret, err := file.Ioctl(t, t.MemoryManager(), args)
diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go
index ee38fdca0..6986e39fe 100644
--- a/pkg/sentry/syscalls/linux/vfs2/pipe.go
+++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go
@@ -42,7 +42,10 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error {
if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 {
return syserror.EINVAL
}
- r, w := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK))
+ r, w, err := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK))
+ if err != nil {
+ return err
+ }
defer r.DecRef(t)
defer w.DecRef(t)
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index 7b33b3f59..f5795b4a8 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -752,11 +752,6 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla
return 0, err
}
- // FIXME(b/63594852): Pretend we have an empty error queue.
- if flags&linux.MSG_ERRQUEUE != 0 {
- return 0, syserror.EAGAIN
- }
-
// Fast path when no control message nor name buffers are provided.
if msg.ControlLen == 0 && msg.NameLen == 0 {
n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0)
@@ -1038,7 +1033,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio
return 0, err
}
- controlMessages, err := control.Parse(t, s, controlData)
+ controlMessages, err := control.Parse(t, s, controlData, t.Arch().Width())
if err != nil {
return 0, err
}
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index a98aac52b..072655fe8 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -204,8 +204,8 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event lin
file.EventRegister(&epi.waiter, wmask)
// Check if the file is already ready.
- if file.Readiness(wmask)&wmask != 0 {
- epi.Callback(nil)
+ if m := file.Readiness(wmask) & wmask; m != 0 {
+ epi.Callback(nil, m)
}
// Add epi to file.epolls so that it is removed when the last
@@ -274,8 +274,8 @@ func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, event
file.EventRegister(&epi.waiter, wmask)
// Check if the file is already ready with the new mask.
- if file.Readiness(wmask)&wmask != 0 {
- epi.Callback(nil)
+ if m := file.Readiness(wmask) & wmask; m != 0 {
+ epi.Callback(nil, m)
}
return nil
@@ -311,7 +311,7 @@ func (ep *EpollInstance) DeleteInterest(file *FileDescription, num int32) error
}
// Callback implements waiter.EntryCallback.Callback.
-func (epi *epollInterest) Callback(*waiter.Entry) {
+func (epi *epollInterest) Callback(*waiter.Entry, waiter.EventMask) {
newReady := false
epi.epoll.mu.Lock()
if !epi.ready {
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index f9e39a94c..5321ac80a 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -15,6 +15,7 @@
package vfs
import (
+ "io"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -43,7 +44,7 @@ import (
type FileDescription struct {
FileDescriptionRefs
- // flagsMu protects statusFlags and asyncHandler below.
+ // flagsMu protects `statusFlags`, `saved`, and `asyncHandler` below.
flagsMu sync.Mutex `state:"nosave"`
// statusFlags contains status flags, "initialized by open(2) and possibly
@@ -52,6 +53,11 @@ type FileDescription struct {
// access to asyncHandler.
statusFlags uint32
+ // saved is true after beforeSave is called. This is used to prevent
+ // double-unregistration of asyncHandler. This does not work properly for
+ // save-resume, which is not currently supported in gVisor (see b/26588733).
+ saved bool `state:"nosave"`
+
// asyncHandler handles O_ASYNC signal generation. It is set with the
// F_SETOWN or F_SETOWN_EX fcntls. For asyncHandler to be used, O_ASYNC must
// also be set by fcntl(2).
@@ -184,7 +190,7 @@ func (fd *FileDescription) DecRef(ctx context.Context) {
}
fd.vd.DecRef(ctx)
fd.flagsMu.Lock()
- if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
+ if !fd.saved && fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
fd.asyncHandler.Unregister(fd)
}
fd.asyncHandler = nil
@@ -834,44 +840,27 @@ func (fd *FileDescription) SetAsyncHandler(newHandler func() FileAsync) FileAsyn
return fd.asyncHandler
}
-// FileReadWriteSeeker is a helper struct to pass a FileDescription as
-// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc.
-type FileReadWriteSeeker struct {
- FD *FileDescription
- Ctx context.Context
- ROpts ReadOptions
- WOpts WriteOptions
-}
-
-// ReadAt implements io.ReaderAt.ReadAt.
-func (f *FileReadWriteSeeker) ReadAt(p []byte, off int64) (int, error) {
- dst := usermem.BytesIOSequence(p)
- n, err := f.FD.PRead(f.Ctx, dst, off, f.ROpts)
- return int(n), err
-}
-
-// Read implements io.ReadWriteSeeker.Read.
-func (f *FileReadWriteSeeker) Read(p []byte) (int, error) {
- dst := usermem.BytesIOSequence(p)
- n, err := f.FD.Read(f.Ctx, dst, f.ROpts)
- return int(n), err
-}
-
-// Seek implements io.ReadWriteSeeker.Seek.
-func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) {
- return f.FD.Seek(f.Ctx, offset, int32(whence))
-}
-
-// WriteAt implements io.WriterAt.WriteAt.
-func (f *FileReadWriteSeeker) WriteAt(p []byte, off int64) (int, error) {
- dst := usermem.BytesIOSequence(p)
- n, err := f.FD.PWrite(f.Ctx, dst, off, f.WOpts)
- return int(n), err
-}
-
-// Write implements io.ReadWriteSeeker.Write.
-func (f *FileReadWriteSeeker) Write(p []byte) (int, error) {
- buf := usermem.BytesIOSequence(p)
- n, err := f.FD.Write(f.Ctx, buf, f.WOpts)
- return int(n), err
+// CopyRegularFileData copies data from srcFD to dstFD until reading from srcFD
+// returns EOF or an error. It returns the number of bytes copied.
+func CopyRegularFileData(ctx context.Context, dstFD, srcFD *FileDescription) (int64, error) {
+ done := int64(0)
+ buf := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size
+ for {
+ readN, readErr := srcFD.Read(ctx, buf, ReadOptions{})
+ if readErr != nil && readErr != io.EOF {
+ return done, readErr
+ }
+ src := buf.TakeFirst64(readN)
+ for src.NumBytes() != 0 {
+ writeN, writeErr := dstFD.Write(ctx, src, WriteOptions{})
+ done += writeN
+ src = src.DropFirst64(writeN)
+ if writeErr != nil {
+ return done, writeErr
+ }
+ }
+ if readErr == io.EOF {
+ return done, nil
+ }
+ }
}
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index cb48c37a1..0df023713 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -12,11 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build go1.12
-// +build !go1.17
-
-// Check go:linkname function signatures when updating Go version.
-
package vfs
import (
@@ -41,6 +36,15 @@ type mountKey struct {
point unsafe.Pointer // *Dentry
}
+var (
+ mountKeyHasher = sync.MapKeyHasher(map[mountKey]struct{}(nil))
+ mountKeySeed = sync.RandUintptr()
+)
+
+func (k *mountKey) hash() uintptr {
+ return mountKeyHasher(gohacks.Noescape(unsafe.Pointer(k)), mountKeySeed)
+}
+
func (mnt *Mount) parent() *Mount {
return (*Mount)(atomic.LoadPointer(&mnt.key.parent))
}
@@ -56,23 +60,17 @@ func (mnt *Mount) getKey() VirtualDentry {
}
}
-func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() }
-
// Invariant: mnt.key.parent == nil. vd.Ok().
func (mnt *Mount) setKey(vd VirtualDentry) {
atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount))
atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry))
}
-func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) }
-
// mountTable maps (mount parent, mount point) pairs to mounts. It supports
// efficient concurrent lookup, even in the presence of concurrent mutators
// (provided mutation is sufficiently uncommon).
//
// mountTable.Init() must be called on new mountTables before use.
-//
-// +stateify savable
type mountTable struct {
// mountTable is implemented as a seqcount-protected hash table that
// resolves collisions with linear probing, featuring Robin Hood insertion
@@ -84,8 +82,7 @@ type mountTable struct {
// intrinsics and inline assembly, limiting the performance of this
// approach.)
- seq sync.SeqCount `state:"nosave"`
- seed uint32 // for hashing keys
+ seq sync.SeqCount `state:"nosave"`
// size holds both length (number of elements) and capacity (number of
// slots): capacity is stored as its base-2 log (referred to as order) in
@@ -150,7 +147,6 @@ func init() {
// Init must be called exactly once on each mountTable before use.
func (mt *mountTable) Init() {
- mt.seed = rand32()
mt.size = mtInitOrder
mt.slots = newMountTableSlots(mtInitCap)
}
@@ -167,7 +163,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer {
// Lookup may be called even if there are concurrent mutators of mt.
func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount {
key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)}
- hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes)
+ hash := key.hash()
loop:
for {
@@ -247,7 +243,7 @@ func (mt *mountTable) Insert(mount *Mount) {
// * mt.seq must be in a writer critical section.
// * mt must not already contain a Mount with the same mount point and parent.
func (mt *mountTable) insertSeqed(mount *Mount) {
- hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
+ hash := mount.key.hash()
// We're under the maximum load factor if:
//
@@ -346,7 +342,7 @@ func (mt *mountTable) Remove(mount *Mount) {
// * mt.seq must be in a writer critical section.
// * mt must contain mount.
func (mt *mountTable) removeSeqed(mount *Mount) {
- hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
+ hash := mount.key.hash()
tcap := uintptr(1) << (mt.size & mtSizeOrderMask)
mask := tcap - 1
slots := mt.slots
@@ -386,9 +382,3 @@ func (mt *mountTable) removeSeqed(mount *Mount) {
off = (off + mountSlotBytes) & offmask
}
}
-
-//go:linkname memhash runtime.memhash
-func memhash(p unsafe.Pointer, seed, s uintptr) uintptr
-
-//go:linkname rand32 runtime.fastrand
-func rand32() uint32
diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go
index 7723ed643..8998a82dd 100644
--- a/pkg/sentry/vfs/save_restore.go
+++ b/pkg/sentry/vfs/save_restore.go
@@ -18,8 +18,10 @@ import (
"fmt"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refsvfs2"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// FilesystemImplSaveRestoreExtension is an optional extension to
@@ -99,6 +101,9 @@ func (vfs *VirtualFilesystem) saveMounts() []*Mount {
return mounts
}
+// saveKey is called by stateify.
+func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() }
+
// loadMounts is called by stateify.
func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) {
if mounts == nil {
@@ -110,6 +115,9 @@ func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) {
}
}
+// loadKey is called by stateify.
+func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) }
+
func (mnt *Mount) afterLoad() {
if atomic.LoadInt64(&mnt.refs) != 0 {
refsvfs2.Register(mnt)
@@ -120,5 +128,20 @@ func (mnt *Mount) afterLoad() {
func (epi *epollInterest) afterLoad() {
// Mark all epollInterests as ready after restore so that the next call to
// EpollInstance.ReadEvents() rechecks their readiness.
- epi.Callback(nil)
+ epi.Callback(nil, waiter.EventMaskFromLinux(epi.mask))
+}
+
+// beforeSave is called by stateify.
+func (fd *FileDescription) beforeSave() {
+ fd.saved = true
+ if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
+ fd.asyncHandler.Unregister(fd)
+ }
+}
+
+// afterLoad is called by stateify.
+func (fd *FileDescription) afterLoad() {
+ if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
+ fd.asyncHandler.Register(fd)
+ }
}
diff --git a/pkg/shim/v1/proc/process.go b/pkg/shim/v1/proc/process.go
index d462c3eef..e8315326d 100644
--- a/pkg/shim/v1/proc/process.go
+++ b/pkg/shim/v1/proc/process.go
@@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package proc contains process-related utilities.
package proc
import (
diff --git a/pkg/shim/v1/shim/BUILD b/pkg/shim/v1/shim/BUILD
index 05c595bc9..e5b6bf186 100644
--- a/pkg/shim/v1/shim/BUILD
+++ b/pkg/shim/v1/shim/BUILD
@@ -8,6 +8,7 @@ go_library(
"api.go",
"platform.go",
"service.go",
+ "shim.go",
],
visibility = [
"//pkg/shim:__subpackages__",
diff --git a/pkg/sleep/commit_asm.go b/pkg/shim/v1/shim/shim.go
index 75728a97d..1855a8769 100644
--- a/pkg/sleep/commit_asm.go
+++ b/pkg/shim/v1/shim/shim.go
@@ -1,10 +1,11 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2018 The containerd Authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
-// http://www.apache.org/licenses/LICENSE-2.0
+// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,9 +13,5 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64 arm64
-
-package sleep
-
-// See commit_noasm.go for a description of commitSleep.
-func commitSleep(g uintptr, waitingG *uintptr) bool
+// Package shim contains the core containerd shim implementation.
+package shim
diff --git a/pkg/shim/v1/utils/utils.go b/pkg/shim/v1/utils/utils.go
index 07e346654..21e75d16d 100644
--- a/pkg/shim/v1/utils/utils.go
+++ b/pkg/shim/v1/utils/utils.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package utils contains utility functions.
package utils
import (
diff --git a/pkg/shim/v2/BUILD b/pkg/shim/v2/BUILD
index f37fefddc..b0e8daa51 100644
--- a/pkg/shim/v2/BUILD
+++ b/pkg/shim/v2/BUILD
@@ -22,6 +22,7 @@ go_library(
"//runsc/specutils",
"@com_github_burntsushi_toml//:go_default_library",
"@com_github_containerd_cgroups//:go_default_library",
+ "@com_github_containerd_cgroups//stats/v1:go_default_library",
"@com_github_containerd_console//:go_default_library",
"@com_github_containerd_containerd//api/events:go_default_library",
"@com_github_containerd_containerd//api/types/task:go_default_library",
diff --git a/pkg/shim/v2/service.go b/pkg/shim/v2/service.go
index 2e39d2c4a..6aaf5fab8 100644
--- a/pkg/shim/v2/service.go
+++ b/pkg/shim/v2/service.go
@@ -28,6 +28,7 @@ import (
"github.com/BurntSushi/toml"
"github.com/containerd/cgroups"
+ cgroupsstats "github.com/containerd/cgroups/stats/v1"
"github.com/containerd/console"
"github.com/containerd/containerd/api/events"
"github.com/containerd/containerd/api/types/task"
@@ -67,9 +68,15 @@ var (
var _ = (taskAPI.TaskService)(&service{})
-// configFile is the default config file name. For containerd 1.2,
-// we assume that a config.toml should exist in the runtime root.
-const configFile = "config.toml"
+const (
+ // configFile is the default config file name. For containerd 1.2,
+ // we assume that a config.toml should exist in the runtime root.
+ configFile = "config.toml"
+
+ // shimAddressPath is the relative path to a file that contains the address
+ // to the shim UDS. See service.shimAddress.
+ shimAddressPath = "address"
+)
// New returns a new shim service that can be used via GRPC.
func New(ctx context.Context, id string, publisher shim.Publisher, cancel func()) (shim.Shim, error) {
@@ -101,6 +108,11 @@ func New(ctx context.Context, id string, publisher shim.Publisher, cancel func()
return nil, fmt.Errorf("failed to initialized platform behavior: %w", err)
}
go s.forward(ctx, publisher)
+
+ if address, err := shim.ReadAddress(shimAddressPath); err == nil {
+ s.shimAddress = address
+ }
+
return s, nil
}
@@ -152,6 +164,9 @@ type service struct {
// cancel is a function that needs to be called before the shim stops. The
// function is provided by the caller to New().
cancel func()
+
+ // shimAddress is the location of the UDS used to communicate with containerd.
+ shimAddress string
}
func (s *service) newCommand(ctx context.Context, containerdBinary, containerdAddress string) (*exec.Cmd, error) {
@@ -191,38 +206,58 @@ func (s *service) StartShim(ctx context.Context, id, containerdBinary, container
if err != nil {
return "", err
}
- address, err := shim.SocketAddress(ctx, id)
+ address, err := shim.SocketAddress(ctx, containerdAddress, id)
if err != nil {
return "", err
}
socket, err := shim.NewSocket(address)
if err != nil {
- return "", err
+ // The only time this would happen is if there is a bug and the socket
+ // was not cleaned up in the cleanup method of the shim, or if we are using
+ // the grouping functionality where the new process should be run with the
+ // same shim as an existing container.
+ if !shim.SocketEaddrinuse(err) {
+ return "", fmt.Errorf("create new shim socket: %w", err)
+ }
+ if shim.CanConnect(address) {
+ if err := shim.WriteAddress(shimAddressPath, address); err != nil {
+ return "", fmt.Errorf("write existing socket for shim: %w", err)
+ }
+ return address, nil
+ }
+ if err := shim.RemoveSocket(address); err != nil {
+ return "", fmt.Errorf("remove pre-existing socket: %w", err)
+ }
+ if socket, err = shim.NewSocket(address); err != nil {
+ return "", fmt.Errorf("try create new shim socket 2x: %w", err)
+ }
}
- defer socket.Close()
+ cu := cleanup.Make(func() {
+ socket.Close()
+ _ = shim.RemoveSocket(address)
+ })
+ defer cu.Clean()
+
f, err := socket.File()
if err != nil {
return "", err
}
- defer f.Close()
cmd.ExtraFiles = append(cmd.ExtraFiles, f)
log.L.Debugf("Executing: %q %s", cmd.Path, cmd.Args)
if err := cmd.Start(); err != nil {
+ f.Close()
return "", err
}
- cu := cleanup.Make(func() {
- cmd.Process.Kill()
- })
- defer cu.Clean()
+ cu.Add(func() { cmd.Process.Kill() })
// make sure to wait after start
go cmd.Wait()
if err := shim.WritePidFile("shim.pid", cmd.Process.Pid); err != nil {
return "", err
}
- if err := shim.WriteAddress("address", address); err != nil {
+ if err := shim.WriteAddress(shimAddressPath, address); err != nil {
return "", err
}
if err := shim.SetScore(cmd.Process.Pid); err != nil {
@@ -675,8 +710,11 @@ func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*task
func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*types.Empty, error) {
log.L.Debugf("Shutdown, id: %s", r.ID)
s.cancel()
+ if s.shimAddress != "" {
+ _ = shim.RemoveSocket(s.shimAddress)
+ }
os.Exit(0)
- return empty, nil
+ panic("Should not get here")
}
func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI.StatsResponse, error) {
@@ -698,48 +736,48 @@ func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI.
// as runc.
//
// [0]: https://github.com/google/gvisor/blob/277a0d5a1fbe8272d4729c01ee4c6e374d047ebc/runsc/boot/events.go#L61-L81
- metrics := &cgroups.Metrics{
- CPU: &cgroups.CPUStat{
- Usage: &cgroups.CPUUsage{
+ metrics := &cgroupsstats.Metrics{
+ CPU: &cgroupsstats.CPUStat{
+ Usage: &cgroupsstats.CPUUsage{
Total: stats.Cpu.Usage.Total,
Kernel: stats.Cpu.Usage.Kernel,
User: stats.Cpu.Usage.User,
PerCPU: stats.Cpu.Usage.Percpu,
},
- Throttling: &cgroups.Throttle{
+ Throttling: &cgroupsstats.Throttle{
Periods: stats.Cpu.Throttling.Periods,
ThrottledPeriods: stats.Cpu.Throttling.ThrottledPeriods,
ThrottledTime: stats.Cpu.Throttling.ThrottledTime,
},
},
- Memory: &cgroups.MemoryStat{
+ Memory: &cgroupsstats.MemoryStat{
Cache: stats.Memory.Cache,
- Usage: &cgroups.MemoryEntry{
+ Usage: &cgroupsstats.MemoryEntry{
Limit: stats.Memory.Usage.Limit,
Usage: stats.Memory.Usage.Usage,
Max: stats.Memory.Usage.Max,
Failcnt: stats.Memory.Usage.Failcnt,
},
- Swap: &cgroups.MemoryEntry{
+ Swap: &cgroupsstats.MemoryEntry{
Limit: stats.Memory.Swap.Limit,
Usage: stats.Memory.Swap.Usage,
Max: stats.Memory.Swap.Max,
Failcnt: stats.Memory.Swap.Failcnt,
},
- Kernel: &cgroups.MemoryEntry{
+ Kernel: &cgroupsstats.MemoryEntry{
Limit: stats.Memory.Kernel.Limit,
Usage: stats.Memory.Kernel.Usage,
Max: stats.Memory.Kernel.Max,
Failcnt: stats.Memory.Kernel.Failcnt,
},
- KernelTCP: &cgroups.MemoryEntry{
+ KernelTCP: &cgroupsstats.MemoryEntry{
Limit: stats.Memory.KernelTCP.Limit,
Usage: stats.Memory.KernelTCP.Usage,
Max: stats.Memory.KernelTCP.Max,
Failcnt: stats.Memory.KernelTCP.Failcnt,
},
},
- Pids: &cgroups.PidsStat{
+ Pids: &cgroupsstats.PidsStat{
Current: stats.Pids.Current,
Limit: stats.Pids.Limit,
},
@@ -843,9 +881,7 @@ func (s *service) getContainerPids(ctx context.Context, id string) ([]uint32, er
func (s *service) forward(ctx context.Context, publisher shim.Publisher) {
for e := range s.events {
- ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
err := publisher.Publish(ctx, getTopic(e), e)
- cancel()
if err != nil {
// Should not happen.
panic(fmt.Errorf("post event: %w", err))
diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD
index ae0fe1522..48bcdd62b 100644
--- a/pkg/sleep/BUILD
+++ b/pkg/sleep/BUILD
@@ -5,10 +5,6 @@ package(licenses = ["notice"])
go_library(
name = "sleep",
srcs = [
- "commit_amd64.s",
- "commit_arm64.s",
- "commit_asm.go",
- "commit_noasm.go",
"sleep_unsafe.go",
],
visibility = ["//:sandbox"],
diff --git a/pkg/sleep/commit_amd64.s b/pkg/sleep/commit_amd64.s
deleted file mode 100644
index bc4ac2c3c..000000000
--- a/pkg/sleep/commit_amd64.s
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "textflag.h"
-
-#define preparingG 1
-
-// See commit_noasm.go for a description of commitSleep.
-//
-// func commitSleep(g uintptr, waitingG *uintptr) bool
-TEXT ·commitSleep(SB),NOSPLIT,$0-24
- MOVQ waitingG+8(FP), CX
- MOVQ g+0(FP), DX
-
- // Store the G in waitingG if it's still preparingG. If it's anything
- // else it means a waker has aborted the sleep.
- MOVQ $preparingG, AX
- LOCK
- CMPXCHGQ DX, 0(CX)
-
- SETEQ AX
- MOVB AX, ret+16(FP)
-
- RET
diff --git a/pkg/sleep/commit_arm64.s b/pkg/sleep/commit_arm64.s
deleted file mode 100644
index d0ef15b20..000000000
--- a/pkg/sleep/commit_arm64.s
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "textflag.h"
-
-#define preparingG 1
-
-// See commit_noasm.go for a description of commitSleep.
-//
-// func commitSleep(g uintptr, waitingG *uintptr) bool
-TEXT ·commitSleep(SB),NOSPLIT,$0-24
- MOVD waitingG+8(FP), R0
- MOVD $preparingG, R1
- MOVD G+0(FP), R2
-
- // Store the G in waitingG if it's still preparingG. If it's anything
- // else it means a waker has aborted the sleep.
-again:
- LDAXR (R0), R3
- CMP R1, R3
- BNE ok
- STLXR R2, (R0), R3
- CBNZ R3, again
-ok:
- CSET EQ, R0
- MOVB R0, ret+16(FP)
- RET
diff --git a/pkg/sleep/commit_noasm.go b/pkg/sleep/commit_noasm.go
deleted file mode 100644
index f59061f37..000000000
--- a/pkg/sleep/commit_noasm.go
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build !race
-// +build !amd64,!arm64
-
-package sleep
-
-import "sync/atomic"
-
-// commitSleep signals to wakers that the given g is now sleeping. Wakers can
-// then fetch it and wake it.
-//
-// The commit may fail if wakers have been asserted after our last check, in
-// which case they will have set s.waitingG to zero.
-//
-// It is written in assembly because it is called from g0, so it doesn't have
-// a race context.
-func commitSleep(g uintptr, waitingG *uintptr) bool {
- // Try to store the G so that wakers know who to wake.
- return atomic.CompareAndSwapUintptr(waitingG, preparingG, g)
-}
diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go
index 19bce2afb..c44206b1e 100644
--- a/pkg/sleep/sleep_unsafe.go
+++ b/pkg/sleep/sleep_unsafe.go
@@ -12,11 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build go1.11
-// +build !go1.17
-
-// Check go:linkname function signatures when updating Go version.
-
// Package sleep allows goroutines to efficiently sleep on multiple sources of
// notifications (wakers). It offers O(1) complexity, which is different from
// multi-channel selects which have O(n) complexity (where n is the number of
@@ -91,12 +86,6 @@ var (
assertedSleeper Sleeper
)
-//go:linkname gopark runtime.gopark
-func gopark(unlockf func(uintptr, *uintptr) bool, wg *uintptr, reason uint8, traceEv byte, traceskip int)
-
-//go:linkname goready runtime.goready
-func goready(g uintptr, traceskip int)
-
// Sleeper allows a goroutine to sleep and receive wake up notifications from
// Wakers in an efficient way.
//
@@ -189,7 +178,7 @@ func (s *Sleeper) nextWaker(block bool) *Waker {
// See:runtime2.go in the go runtime package for
// the values to pass as the waitReason here.
const waitReasonSelect = 9
- gopark(commitSleep, &s.waitingG, waitReasonSelect, traceEvGoBlockSelect, 0)
+ sync.Gopark(commitSleep, unsafe.Pointer(&s.waitingG), sync.WaitReasonSelect, sync.TraceEvGoBlockSelect, 0)
}
// Pull the shared list out and reverse it in the local
@@ -212,6 +201,18 @@ func (s *Sleeper) nextWaker(block bool) *Waker {
return w
}
+// commitSleep signals to wakers that the given g is now sleeping. Wakers can
+// then fetch it and wake it.
+//
+// The commit may fail if wakers have been asserted after our last check, in
+// which case they will have set s.waitingG to zero.
+//
+//go:norace
+//go:nosplit
+func commitSleep(g uintptr, waitingG unsafe.Pointer) bool {
+ return sync.RaceUncheckedAtomicCompareAndSwapUintptr((*uintptr)(waitingG), preparingG, g)
+}
+
// Fetch fetches the next wake-up notification. If a notification is immediately
// available, it is returned right away. Otherwise, the behavior depends on the
// value of 'block': if true, the current goroutine blocks until a notification
@@ -311,7 +312,7 @@ func (s *Sleeper) enqueueAssertedWaker(w *Waker) {
case 0, preparingG:
default:
// We managed to get a G. Wake it up.
- goready(g, 0)
+ sync.Goready(g, 0)
}
}
diff --git a/pkg/state/tests/integer_test.go b/pkg/state/tests/integer_test.go
index d3931c952..2b1609af0 100644
--- a/pkg/state/tests/integer_test.go
+++ b/pkg/state/tests/integer_test.go
@@ -20,21 +20,21 @@ import (
)
var (
- allIntTs = []int{-1, 0, 1}
- allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8}
- allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16}
- allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32}
- allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64}
- allUintTs = []uint{0, 1}
- allUintptrs = []uintptr{0, 1, ^uintptr(0)}
- allUint8s = []uint8{0, 1, math.MaxUint8}
- allUint16s = []uint16{0, 1, math.MaxUint16}
- allUint32s = []uint32{0, 1, math.MaxUint32}
- allUint64s = []uint64{0, 1, math.MaxUint64}
+ allBasicInts = []int{-1, 0, 1}
+ allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8}
+ allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16}
+ allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32}
+ allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64}
+ allBasicUints = []uint{0, 1}
+ allUintptrs = []uintptr{0, 1, ^uintptr(0)}
+ allUint8s = []uint8{0, 1, math.MaxUint8}
+ allUint16s = []uint16{0, 1, math.MaxUint16}
+ allUint32s = []uint32{0, 1, math.MaxUint32}
+ allUint64s = []uint64{0, 1, math.MaxUint64}
)
var allInts = flatten(
- allIntTs,
+ allBasicInts,
allInt8s,
allInt16s,
allInt32s,
@@ -42,7 +42,7 @@ var allInts = flatten(
)
var allUints = flatten(
- allUintTs,
+ allBasicUints,
allUintptrs,
allUint8s,
allUint16s,
diff --git a/pkg/state/tests/register_test.go b/pkg/state/tests/register_test.go
index c829753cc..75bdbfc6e 100644
--- a/pkg/state/tests/register_test.go
+++ b/pkg/state/tests/register_test.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build race
+
package tests
import (
@@ -165,3 +167,12 @@ func TestRegisterBad(t *testing.T) {
}
}
+
+func TestRegisterTypeOnlyStruct(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Register did not panic")
+ }
+ }()
+ state.Register((*typeOnlyEmptyStruct)(nil))
+}
diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go
index c91c2c032..9826f1ee9 100644
--- a/pkg/state/tests/struct_test.go
+++ b/pkg/state/tests/struct_test.go
@@ -17,8 +17,6 @@ package tests
import (
"math/rand"
"testing"
-
- "gvisor.dev/gvisor/pkg/state"
)
func TestEmptyStruct(t *testing.T) {
@@ -58,15 +56,6 @@ func TestEmptyStruct(t *testing.T) {
})
}
-func TestRegisterTypeOnlyStruct(t *testing.T) {
- defer func() {
- if r := recover(); r == nil {
- t.Errorf("Register did not panic")
- }
- }()
- state.Register((*typeOnlyEmptyStruct)(nil))
-}
-
func TestEmbeddedPointers(t *testing.T) {
// Give each int64 a random value to prevent Go from using
// runtime.staticuint64s, which confounds tests for struct duplication.
diff --git a/pkg/state/types.go b/pkg/state/types.go
index 84aed8732..420675880 100644
--- a/pkg/state/types.go
+++ b/pkg/state/types.go
@@ -329,47 +329,48 @@ var reverseTypeDatabase = map[reflect.Type]string{}
// This must be called on init and only done once.
func Register(t Type) {
name := t.StateTypeName()
- fields := t.StateFields()
- assertValidType(name, fields)
- // Register must always be called on pointers.
typ := reflect.TypeOf(t)
- if typ.Kind() != reflect.Ptr {
- Failf("Register must be called on pointers")
+ if raceEnabled {
+ assertValidType(name, t.StateFields())
+ // Register must always be called on pointers.
+ if typ.Kind() != reflect.Ptr {
+ Failf("Register must be called on pointers")
+ }
}
typ = typ.Elem()
- if typ.Kind() == reflect.Struct {
- // All registered structs must implement SaverLoader. We allow
- // the registration is non-struct types with just the Type
- // interface, but we need to call StateSave/StateLoad methods
- // on aggregate types.
- if _, ok := t.(SaverLoader); !ok {
- Failf("struct %T does not implement SaverLoader", t)
+ if raceEnabled {
+ if typ.Kind() == reflect.Struct {
+ // All registered structs must implement SaverLoader. We allow
+ // the registration of non-struct types with just the Type
+ // interface, but we need to call StateSave/StateLoad methods
+ // on aggregate types.
+ if _, ok := t.(SaverLoader); !ok {
+ Failf("struct %T does not implement SaverLoader", t)
+ }
+ } else {
+ // Non-structs must not have any fields. We don't support
+ // calling StateSave/StateLoad methods on any non-struct types.
+ // If custom behavior is required, these types should be
+ // wrapped in a structure of some kind.
+ if fields := t.StateFields(); len(fields) != 0 {
+ Failf("non-struct %T has non-zero fields %v", t, fields)
+ }
+ // We don't allow non-structs to implement StateSave/StateLoad
+ // methods, because they won't be called and it's confusing.
+ if _, ok := t.(SaverLoader); ok {
+ Failf("non-struct %T implements SaverLoader", t)
+ }
}
- } else {
- // Non-structs must not have any fields. We don't support
- // calling StateSave/StateLoad methods on any non-struct types.
- // If custom behavior is required, these types should be
- // wrapped in a structure of some kind.
- if len(fields) != 0 {
- Failf("non-struct %T has non-zero fields %v", t, fields)
+ if _, ok := primitiveTypeDatabase[name]; ok {
+ Failf("conflicting primitiveTypeDatabase entry for %T: used by primitive", t)
}
- // We don't allow non-structs to implement StateSave/StateLoad
- // methods, because they won't be called and it's confusing.
- if _, ok := t.(SaverLoader); ok {
- Failf("non-struct %T implements SaverLoader", t)
+ if _, ok := globalTypeDatabase[name]; ok {
+ Failf("conflicting globalTypeDatabase entries for %T: name conflict", t)
+ }
+ if name == interfaceType {
+ Failf("conflicting name for %T: matches interfaceType", t)
}
- }
- if _, ok := primitiveTypeDatabase[name]; ok {
- Failf("conflicting primitiveTypeDatabase entry for %T: used by primitive", t)
- }
- if _, ok := globalTypeDatabase[name]; ok {
- Failf("conflicting globalTypeDatabase entries for %T: name conflict", t)
- }
- if name == interfaceType {
- Failf("conflicting name for %T: matches interfaceType", t)
- }
- globalTypeDatabase[name] = typ
- if raceEnabled {
reverseTypeDatabase[typ] = name
}
+ globalTypeDatabase[name] = typ
}
diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD
index 5bd4d09d1..28e62abbb 100644
--- a/pkg/sync/BUILD
+++ b/pkg/sync/BUILD
@@ -10,15 +10,34 @@ exports_files(["LICENSE"])
go_template(
name = "generic_atomicptr",
- srcs = ["atomicptr_unsafe.go"],
+ srcs = ["generic_atomicptr_unsafe.go"],
types = [
"Value",
],
)
go_template(
+ name = "generic_atomicptrmap",
+ srcs = ["generic_atomicptrmap_unsafe.go"],
+ opt_consts = [
+ "ShardOrder",
+ ],
+ opt_types = [
+ "Hasher",
+ ],
+ types = [
+ "Key",
+ "Value",
+ ],
+ deps = [
+ ":sync",
+ "//pkg/gohacks",
+ ],
+)
+
+go_template(
name = "generic_seqatomic",
- srcs = ["seqatomic_unsafe.go"],
+ srcs = ["generic_seqatomic_unsafe.go"],
types = [
"Value",
],
@@ -33,15 +52,17 @@ go_library(
"aliases.go",
"checklocks_off_unsafe.go",
"checklocks_on_unsafe.go",
- "memmove_unsafe.go",
+ "goyield_go113_unsafe.go",
+ "goyield_unsafe.go",
"mutex_unsafe.go",
"nocopy.go",
"norace_unsafe.go",
+ "race_amd64.s",
+ "race_arm64.s",
"race_unsafe.go",
+ "runtime_unsafe.go",
"rwmutex_unsafe.go",
"seqcount.go",
- "spin_legacy_unsafe.go",
- "spin_unsafe.go",
"sync.go",
],
marshal = False,
diff --git a/pkg/sync/atomicptrmaptest/BUILD b/pkg/sync/atomicptrmaptest/BUILD
new file mode 100644
index 000000000..3f71ae97d
--- /dev/null
+++ b/pkg/sync/atomicptrmaptest/BUILD
@@ -0,0 +1,57 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(
+ default_visibility = ["//visibility:private"],
+ licenses = ["notice"],
+)
+
+go_template_instance(
+ name = "test_atomicptrmap",
+ out = "test_atomicptrmap_unsafe.go",
+ package = "atomicptrmap",
+ prefix = "test",
+ template = "//pkg/sync:generic_atomicptrmap",
+ types = {
+ "Key": "int64",
+ "Value": "testValue",
+ },
+)
+
+go_template_instance(
+ name = "test_atomicptrmap_sharded",
+ out = "test_atomicptrmap_sharded_unsafe.go",
+ consts = {
+ "ShardOrder": "4",
+ },
+ package = "atomicptrmap",
+ prefix = "test",
+ suffix = "Sharded",
+ template = "//pkg/sync:generic_atomicptrmap",
+ types = {
+ "Key": "int64",
+ "Value": "testValue",
+ },
+)
+
+go_library(
+ name = "atomicptrmap",
+ testonly = 1,
+ srcs = [
+ "atomicptrmap.go",
+ "test_atomicptrmap_sharded_unsafe.go",
+ "test_atomicptrmap_unsafe.go",
+ ],
+ deps = [
+ "//pkg/gohacks",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "atomicptrmap_test",
+ size = "small",
+ srcs = ["atomicptrmap_test.go"],
+ library = ":atomicptrmap",
+ deps = ["//pkg/sync"],
+)
diff --git a/tools/vm/test.cc b/pkg/sync/atomicptrmaptest/atomicptrmap.go
index c0ceacda1..867821ce9 100644
--- a/tools/vm/test.cc
+++ b/pkg/sync/atomicptrmaptest/atomicptrmap.go
@@ -12,16 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "gtest/gtest.h"
+// Package atomicptrmap instantiates generic_atomicptrmap for testing.
+package atomicptrmap
-namespace {
-
-TEST(Image, Sanity0) {
- // Do nothing (in shard 0).
-}
-
-TEST(Image, Sanity1) {
- // Do nothing (in shard 1).
+type testValue struct {
+ val int
}
-
-} // namespace
diff --git a/pkg/sync/atomicptrmaptest/atomicptrmap_test.go b/pkg/sync/atomicptrmaptest/atomicptrmap_test.go
new file mode 100644
index 000000000..75a9997ef
--- /dev/null
+++ b/pkg/sync/atomicptrmaptest/atomicptrmap_test.go
@@ -0,0 +1,635 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package atomicptrmap
+
+import (
+ "context"
+ "fmt"
+ "math/rand"
+ "reflect"
+ "runtime"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func TestConsistencyWithGoMap(t *testing.T) {
+ const maxKey = 16
+ var vals [4]*testValue
+ for i := 1; /* leave vals[0] nil */ i < len(vals); i++ {
+ vals[i] = new(testValue)
+ }
+ var (
+ m = make(map[int64]*testValue)
+ apm testAtomicPtrMap
+ )
+ for i := 0; i < 100000; i++ {
+ // Apply a random operation to both m and apm and expect them to have
+ // the same result. Bias toward CompareAndSwap, which has the most
+ // cases; bias away from Range and RangeRepeatable, which are
+ // relatively expensive.
+ switch rand.Intn(10) {
+ case 0, 1: // Load
+ key := rand.Int63n(maxKey)
+ want := m[key]
+ got := apm.Load(key)
+ t.Logf("Load(%d) = %p", key, got)
+ if got != want {
+ t.Fatalf("got %p, wanted %p", got, want)
+ }
+ case 2, 3: // Swap
+ key := rand.Int63n(maxKey)
+ val := vals[rand.Intn(len(vals))]
+ want := m[key]
+ if val != nil {
+ m[key] = val
+ } else {
+ delete(m, key)
+ }
+ got := apm.Swap(key, val)
+ t.Logf("Swap(%d, %p) = %p", key, val, got)
+ if got != want {
+ t.Fatalf("got %p, wanted %p", got, want)
+ }
+ case 4, 5, 6, 7: // CompareAndSwap
+ key := rand.Int63n(maxKey)
+ oldVal := vals[rand.Intn(len(vals))]
+ newVal := vals[rand.Intn(len(vals))]
+ want := m[key]
+ if want == oldVal {
+ if newVal != nil {
+ m[key] = newVal
+ } else {
+ delete(m, key)
+ }
+ }
+ got := apm.CompareAndSwap(key, oldVal, newVal)
+ t.Logf("CompareAndSwap(%d, %p, %p) = %p", key, oldVal, newVal, got)
+ if got != want {
+ t.Fatalf("got %p, wanted %p", got, want)
+ }
+ case 8: // Range
+ got := make(map[int64]*testValue)
+ var (
+ haveDup = false
+ dup int64
+ )
+ apm.Range(func(key int64, val *testValue) bool {
+ if _, ok := got[key]; ok && !haveDup {
+ haveDup = true
+ dup = key
+ }
+ got[key] = val
+ return true
+ })
+ t.Logf("Range() = %v", got)
+ if !reflect.DeepEqual(got, m) {
+ t.Fatalf("got %v, wanted %v", got, m)
+ }
+ if haveDup {
+ t.Fatalf("got duplicate key %d", dup)
+ }
+ case 9: // RangeRepeatable
+ got := make(map[int64]*testValue)
+ apm.RangeRepeatable(func(key int64, val *testValue) bool {
+ got[key] = val
+ return true
+ })
+ t.Logf("RangeRepeatable() = %v", got)
+ if !reflect.DeepEqual(got, m) {
+ t.Fatalf("got %v, wanted %v", got, m)
+ }
+ }
+ }
+}
+
+func TestConcurrentHeterogeneous(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ var (
+ apm testAtomicPtrMap
+ wg sync.WaitGroup
+ )
+ defer func() {
+ cancel()
+ wg.Wait()
+ }()
+
+ possibleKeyValuePairs := make(map[int64]map[*testValue]struct{})
+ addKeyValuePair := func(key int64, val *testValue) {
+ values := possibleKeyValuePairs[key]
+ if values == nil {
+ values = make(map[*testValue]struct{})
+ possibleKeyValuePairs[key] = values
+ }
+ values[val] = struct{}{}
+ }
+
+ const numValuesPerKey = 4
+
+ // These goroutines use keys not used by any other goroutine.
+ const numPrivateKeys = 3
+ for i := 0; i < numPrivateKeys; i++ {
+ key := int64(i)
+ var vals [numValuesPerKey]*testValue
+ for i := 1; /* leave vals[0] nil */ i < len(vals); i++ {
+ val := new(testValue)
+ vals[i] = val
+ addKeyValuePair(key, val)
+ }
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ r := rand.New(rand.NewSource(rand.Int63()))
+ var stored *testValue
+ for ctx.Err() == nil {
+ switch r.Intn(4) {
+ case 0:
+ got := apm.Load(key)
+ if got != stored {
+ t.Errorf("Load(%d): got %p, wanted %p", key, got, stored)
+ return
+ }
+ case 1:
+ val := vals[r.Intn(len(vals))]
+ want := stored
+ stored = val
+ got := apm.Swap(key, val)
+ if got != want {
+ t.Errorf("Swap(%d, %p): got %p, wanted %p", key, val, got, want)
+ return
+ }
+ case 2, 3:
+ oldVal := vals[r.Intn(len(vals))]
+ newVal := vals[r.Intn(len(vals))]
+ want := stored
+ if stored == oldVal {
+ stored = newVal
+ }
+ got := apm.CompareAndSwap(key, oldVal, newVal)
+ if got != want {
+ t.Errorf("CompareAndSwap(%d, %p, %p): got %p, wanted %p", key, oldVal, newVal, got, want)
+ return
+ }
+ }
+ }
+ }()
+ }
+
+ // These goroutines share a small set of keys.
+ const numSharedKeys = 2
+ var (
+ sharedKeys [numSharedKeys]int64
+ sharedValues = make(map[int64][]*testValue)
+ sharedValuesSet = make(map[int64]map[*testValue]struct{})
+ )
+ for i := range sharedKeys {
+ key := int64(numPrivateKeys + i)
+ sharedKeys[i] = key
+ vals := make([]*testValue, numValuesPerKey)
+ valsSet := make(map[*testValue]struct{})
+ for j := range vals {
+ val := new(testValue)
+ vals[j] = val
+ valsSet[val] = struct{}{}
+ addKeyValuePair(key, val)
+ }
+ sharedValues[key] = vals
+ sharedValuesSet[key] = valsSet
+ }
+ randSharedValue := func(r *rand.Rand, key int64) *testValue {
+ vals := sharedValues[key]
+ return vals[r.Intn(len(vals))]
+ }
+ for i := 0; i < 3; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ r := rand.New(rand.NewSource(rand.Int63()))
+ for ctx.Err() == nil {
+ keyIndex := r.Intn(len(sharedKeys))
+ key := sharedKeys[keyIndex]
+ var (
+ op string
+ got *testValue
+ )
+ switch r.Intn(4) {
+ case 0:
+ op = "Load"
+ got = apm.Load(key)
+ case 1:
+ op = "Swap"
+ got = apm.Swap(key, randSharedValue(r, key))
+ case 2, 3:
+ op = "CompareAndSwap"
+ got = apm.CompareAndSwap(key, randSharedValue(r, key), randSharedValue(r, key))
+ }
+ if got != nil {
+ valsSet := sharedValuesSet[key]
+ if _, ok := valsSet[got]; !ok {
+ t.Errorf("%s: got key %d, value %p; expected value in %v", op, key, got, valsSet)
+ return
+ }
+ }
+ }
+ }()
+ }
+
+ // This goroutine repeatedly searches for unused keys.
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ r := rand.New(rand.NewSource(rand.Int63()))
+ for ctx.Err() == nil {
+ key := -1 - r.Int63()
+ if got := apm.Load(key); got != nil {
+ t.Errorf("Load(%d): got %p, wanted nil", key, got)
+ }
+ }
+ }()
+
+ // This goroutine repeatedly calls RangeRepeatable() and checks that each
+ // key corresponds to an expected value.
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ abort := false
+ for !abort && ctx.Err() == nil {
+ apm.RangeRepeatable(func(key int64, val *testValue) bool {
+ values, ok := possibleKeyValuePairs[key]
+ if !ok {
+ t.Errorf("RangeRepeatable: got invalid key %d", key)
+ abort = true
+ return false
+ }
+ if _, ok := values[val]; !ok {
+ t.Errorf("RangeRepeatable: got key %d, value %p; expected one of %v", key, val, values)
+ abort = true
+ return false
+ }
+ return true
+ })
+ }
+ }()
+
+ // Finally, the main goroutine spins for the length of the test calling
+ // Range() and checking that each key that it observes is unique and
+ // corresponds to an expected value.
+ seenKeys := make(map[int64]struct{})
+ const testDuration = 5 * time.Second
+ end := time.Now().Add(testDuration)
+ abort := false
+ for time.Now().Before(end) {
+ apm.Range(func(key int64, val *testValue) bool {
+ values, ok := possibleKeyValuePairs[key]
+ if !ok {
+ t.Errorf("Range: got invalid key %d", key)
+ abort = true
+ return false
+ }
+ if _, ok := values[val]; !ok {
+ t.Errorf("Range: got key %d, value %p; expected one of %v", key, val, values)
+ abort = true
+ return false
+ }
+ if _, ok := seenKeys[key]; ok {
+ t.Errorf("Range: got duplicate key %d", key)
+ abort = true
+ return false
+ }
+ seenKeys[key] = struct{}{}
+ return true
+ })
+ if abort {
+ break
+ }
+ for k := range seenKeys {
+ delete(seenKeys, k)
+ }
+ }
+}
+
+type benchmarkableMap interface {
+ Load(key int64) *testValue
+ Store(key int64, val *testValue)
+ LoadOrStore(key int64, val *testValue) (*testValue, bool)
+ Delete(key int64)
+}
+
+// rwMutexMap implements benchmarkableMap for a RWMutex-protected Go map.
+type rwMutexMap struct {
+ mu sync.RWMutex
+ m map[int64]*testValue
+}
+
+func (m *rwMutexMap) Load(key int64) *testValue {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ return m.m[key]
+}
+
+func (m *rwMutexMap) Store(key int64, val *testValue) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if m.m == nil {
+ m.m = make(map[int64]*testValue)
+ }
+ m.m[key] = val
+}
+
+func (m *rwMutexMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if m.m == nil {
+ m.m = make(map[int64]*testValue)
+ }
+ if oldVal, ok := m.m[key]; ok {
+ return oldVal, true
+ }
+ m.m[key] = val
+ return val, false
+}
+
+func (m *rwMutexMap) Delete(key int64) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ delete(m.m, key)
+}
+
+// syncMap implements benchmarkableMap for a sync.Map.
+type syncMap struct {
+ m sync.Map
+}
+
+func (m *syncMap) Load(key int64) *testValue {
+ val, ok := m.m.Load(key)
+ if !ok {
+ return nil
+ }
+ return val.(*testValue)
+}
+
+func (m *syncMap) Store(key int64, val *testValue) {
+ m.m.Store(key, val)
+}
+
+func (m *syncMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
+ actual, loaded := m.m.LoadOrStore(key, val)
+ return actual.(*testValue), loaded
+}
+
+func (m *syncMap) Delete(key int64) {
+ m.m.Delete(key)
+}
+
+// benchmarkableAtomicPtrMap implements benchmarkableMap for testAtomicPtrMap.
+type benchmarkableAtomicPtrMap struct {
+ m testAtomicPtrMap
+}
+
+func (m *benchmarkableAtomicPtrMap) Load(key int64) *testValue {
+ return m.m.Load(key)
+}
+
+func (m *benchmarkableAtomicPtrMap) Store(key int64, val *testValue) {
+ m.m.Store(key, val)
+}
+
+func (m *benchmarkableAtomicPtrMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
+ if prev := m.m.CompareAndSwap(key, nil, val); prev != nil {
+ return prev, true
+ }
+ return val, false
+}
+
+func (m *benchmarkableAtomicPtrMap) Delete(key int64) {
+ m.m.Store(key, nil)
+}
+
+// benchmarkableAtomicPtrMapSharded adapts testAtomicPtrMapSharded to the
+// benchmarkableMap interface.
+type benchmarkableAtomicPtrMapSharded struct {
+	m testAtomicPtrMapSharded
+}
+
+// Load returns the value stored for key, or nil if there is none.
+func (m *benchmarkableAtomicPtrMapSharded) Load(key int64) *testValue {
+	return m.m.Load(key)
+}
+
+// Store maps key to val.
+func (m *benchmarkableAtomicPtrMapSharded) Store(key int64, val *testValue) {
+	m.m.Store(key, val)
+}
+
+// LoadOrStore installs val for key only if key currently has no value. It
+// returns the value now associated with key and whether it was already
+// present.
+func (m *benchmarkableAtomicPtrMapSharded) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
+	prev := m.m.CompareAndSwap(key, nil, val)
+	if prev == nil {
+		return val, false
+	}
+	return prev, true
+}
+
+// Delete removes any value stored for key by storing nil.
+func (m *benchmarkableAtomicPtrMapSharded) Delete(key int64) {
+	m.m.Store(key, nil)
+}
+
+// mapImpls lists every benchmarkableMap implementation under test, paired
+// with a constructor that returns a fresh, empty instance.
+var mapImpls = [...]struct {
+	name string
+	ctor func() benchmarkableMap
+}{
+	{"RWMutexMap", func() benchmarkableMap { return &rwMutexMap{} }},
+	{"SyncMap", func() benchmarkableMap { return &syncMap{} }},
+	{"AtomicPtrMap", func() benchmarkableMap { return &benchmarkableAtomicPtrMap{} }},
+	{"AtomicPtrMapSharded", func() benchmarkableMap { return &benchmarkableAtomicPtrMapSharded{} }},
+}
+
+// benchmarkStoreDelete stores b.N distinct keys into a fresh map and then
+// deletes each of them. The timer is never reset, so both phases are timed.
+func benchmarkStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) {
+	var (
+		m     = mapCtor()
+		val   = new(testValue)
+		limit = int64(b.N)
+	)
+	for key := int64(0); key < limit; key++ {
+		m.Store(key, val)
+	}
+	for key := int64(0); key < limit; key++ {
+		m.Delete(key)
+	}
+}
+
+// BenchmarkStoreDelete runs benchmarkStoreDelete as a sub-benchmark for each
+// entry in mapImpls.
+func BenchmarkStoreDelete(b *testing.B) {
+	for i := range mapImpls {
+		impl := mapImpls[i]
+		b.Run(impl.name, func(b *testing.B) {
+			benchmarkStoreDelete(b, impl.ctor)
+		})
+	}
+}
+
+// benchmarkLoadOrStoreDelete inserts b.N distinct keys into a fresh map using
+// LoadOrStore and then deletes each of them. The timer is never reset, so
+// both phases are timed.
+func benchmarkLoadOrStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) {
+	var (
+		m     = mapCtor()
+		val   = new(testValue)
+		limit = int64(b.N)
+	)
+	for key := int64(0); key < limit; key++ {
+		m.LoadOrStore(key, val)
+	}
+	for key := int64(0); key < limit; key++ {
+		m.Delete(key)
+	}
+}
+
+// BenchmarkLoadOrStoreDelete runs benchmarkLoadOrStoreDelete as a
+// sub-benchmark for each entry in mapImpls.
+func BenchmarkLoadOrStoreDelete(b *testing.B) {
+	for i := range mapImpls {
+		impl := mapImpls[i]
+		b.Run(impl.name, func(b *testing.B) {
+			benchmarkLoadOrStoreDelete(b, impl.ctor)
+		})
+	}
+}
+
+// benchmarkLookupPositive measures Load for keys that are present. The map is
+// populated with keys [0, b.N) before the timer is reset, so only the lookups
+// are timed.
+func benchmarkLookupPositive(b *testing.B, mapCtor func() benchmarkableMap) {
+	var (
+		m     = mapCtor()
+		val   = new(testValue)
+		limit = int64(b.N)
+	)
+	for key := int64(0); key < limit; key++ {
+		m.Store(key, val)
+	}
+	b.ResetTimer()
+	for key := int64(0); key < limit; key++ {
+		m.Load(key)
+	}
+}
+
+// BenchmarkLookupPositive runs benchmarkLookupPositive as a sub-benchmark for
+// each entry in mapImpls.
+func BenchmarkLookupPositive(b *testing.B) {
+	for i := range mapImpls {
+		impl := mapImpls[i]
+		b.Run(impl.name, func(b *testing.B) {
+			benchmarkLookupPositive(b, impl.ctor)
+		})
+	}
+}
+
+// benchmarkLookupNegative measures Load for keys that are absent. The map is
+// populated with keys [0, b.N) before the timer is reset; the timed loop then
+// looks up the negative keys -1, -2, ..., -b.N, none of which were stored.
+func benchmarkLookupNegative(b *testing.B, mapCtor func() benchmarkableMap) {
+	var (
+		m     = mapCtor()
+		val   = new(testValue)
+		limit = int64(b.N)
+	)
+	for key := int64(0); key < limit; key++ {
+		m.Store(key, val)
+	}
+	b.ResetTimer()
+	for i := int64(0); i < limit; i++ {
+		m.Load(-1 - i)
+	}
+}
+
+// BenchmarkLookupNegative runs benchmarkLookupNegative as a sub-benchmark for
+// each entry in mapImpls.
+func BenchmarkLookupNegative(b *testing.B) {
+	for i := range mapImpls {
+		impl := mapImpls[i]
+		b.Run(impl.name, func(b *testing.B) {
+			benchmarkLookupNegative(b, impl.ctor)
+		})
+	}
+}
+
+// benchmarkConcurrentOptions holds the tunable parameters of
+// benchmarkConcurrent.
+type benchmarkConcurrentOptions struct {
+ // loadsPerMutationPair is the number of map lookups between each
+ // insertion/deletion pair.
+ loadsPerMutationPair int
+
+ // If changeKeys is true, the keys used by each goroutine change between
+ // iterations of the test.
+ changeKeys bool
+}
+
+// benchmarkConcurrent exercises a single shared map from GOMAXPROCS worker
+// goroutines at once. Each worker iteration performs one LoadOrStore,
+// opts.loadsPerMutationPair Loads and one Delete on the worker's key, and
+// each worker runs enough iterations to perform at least b.N map operations.
+//
+// Workers are created up front but block on started until the timer has been
+// reset, so goroutine creation and map pre-population are not measured.
+func benchmarkConcurrent(b *testing.B, mapCtor func() benchmarkableMap, opts benchmarkConcurrentOptions) {
+ var (
+ started sync.WaitGroup
+ workers sync.WaitGroup
+ )
+ started.Add(1)
+
+ m := mapCtor()
+ val := &testValue{}
+ // Insert a large number of unused elements into the map so that used
+ // elements are distributed throughout memory.
+ for i := 0; i < 10000; i++ {
+ m.Store(int64(-1-i), val)
+ }
+ // n := ceil(b.N / (opts.loadsPerMutationPair + 2))
+ n := (b.N + opts.loadsPerMutationPair + 1) / (opts.loadsPerMutationPair + 2)
+ for i, procs := 0, runtime.GOMAXPROCS(0); i < procs; i++ {
+ workerID := i
+ workers.Add(1)
+ go func() {
+ defer workers.Done()
+ // Wait for the main goroutine to reset the timer.
+ started.Wait()
+ for i := 0; i < n; i++ {
+ var key int64
+ if opts.changeKeys {
+ // Each worker walks its own disjoint range of n keys.
+ key = int64(workerID*n + i)
+ } else {
+ // Each worker reuses one fixed key for every iteration.
+ key = int64(workerID)
+ }
+ m.LoadOrStore(key, val)
+ for j := 0; j < opts.loadsPerMutationPair; j++ {
+ m.Load(key)
+ }
+ m.Delete(key)
+ }
+ }()
+ }
+
+ b.ResetTimer()
+ // Release all workers at once and wait for them to finish.
+ started.Done()
+ workers.Wait()
+}
+
+// BenchmarkConcurrent runs benchmarkConcurrent as a sub-benchmark for every
+// combination of key-selection mode, write percentage and map implementation.
+// Sub-benchmark names have the form <keys>_<writes>_<impl>.
+func BenchmarkConcurrent(b *testing.B) {
+ changeKeysChoices := [...]struct {
+ name string
+ val bool
+ }{
+ {"FixedKeys", false},
+ {"ChangingKeys", true},
+ }
+ // Each worker iteration performs 2 mutations (one insertion, one deletion)
+ // plus loadsPerMutationPair lookups, so the write fraction is
+ // 2 / (loadsPerMutationPair + 2): 2/200, 2/20 and 2/4 respectively.
+ writePcts := [...]struct {
+ name string
+ loadsPerMutationPair int
+ }{
+ {"1PercentWrites", 198},
+ {"10PercentWrites", 18},
+ {"50PercentWrites", 2},
+ }
+ for _, changeKeys := range changeKeysChoices {
+ for _, writePct := range writePcts {
+ for _, mapImpl := range mapImpls {
+ name := fmt.Sprintf("%s_%s_%s", changeKeys.name, writePct.name, mapImpl.name)
+ b.Run(name, func(b *testing.B) {
+ benchmarkConcurrent(b, mapImpl.ctor, benchmarkConcurrentOptions{
+ loadsPerMutationPair: writePct.loadsPerMutationPair,
+ changeKeys: changeKeys.val,
+ })
+ })
+ }
+ }
+ }
+}
diff --git a/pkg/sync/atomicptr_unsafe.go b/pkg/sync/generic_atomicptr_unsafe.go
index 525c4beed..82b6df18c 100644
--- a/pkg/sync/atomicptr_unsafe.go
+++ b/pkg/sync/generic_atomicptr_unsafe.go
@@ -3,9 +3,9 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Package template doesn't exist. This file must be instantiated using the
+// Package seqatomic doesn't exist. This file must be instantiated using the
// go_template_instance rule in tools/go_generics/defs.bzl.
-package template
+package seqatomic
import (
"sync/atomic"
diff --git a/pkg/sync/generic_atomicptrmap_unsafe.go b/pkg/sync/generic_atomicptrmap_unsafe.go
new file mode 100644
index 000000000..c70dda6dd
--- /dev/null
+++ b/pkg/sync/generic_atomicptrmap_unsafe.go
@@ -0,0 +1,503 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package atomicptrmap doesn't exist. This file must be instantiated using the
+// go_template_instance rule in tools/go_generics/defs.bzl.
+package atomicptrmap
+
+import (
+ "reflect"
+ "runtime"
+ "sync/atomic"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Key is a required type parameter.
+type Key struct{}
+
+// Value is a required type parameter.
+type Value struct{}
+
+const (
+ // ShardOrder is an optional parameter specifying the base-2 log of the
+ // number of shards per AtomicPtrMap. Higher values of ShardOrder reduce
+ // unnecessary synchronization between unrelated concurrent operations,
+ // improving performance for write-heavy workloads, but increase memory
+ // usage for small maps.
+ ShardOrder = 0
+)
+
+// Hasher is an optional type parameter. If Hasher is provided, it must define
+// the Init and Hash methods. One Hasher will be shared by all AtomicPtrMaps.
+type Hasher struct {
+ defaultHasher
+}
+
+// defaultHasher is the default Hasher. This indirection exists because
+// defaultHasher must exist even if a custom Hasher is provided, to prevent the
+// Go compiler from complaining about defaultHasher's unused imports.
+type defaultHasher struct {
+ // fn is the hash function returned by sync.MapKeyHasher for
+ // map[Key]*Value; it takes a pointer to a key and a seed.
+ fn func(unsafe.Pointer, uintptr) uintptr
+ // seed is the random value passed to fn on every call; it is chosen once
+ // in Init.
+ seed uintptr
+}
+
+// Init initializes the Hasher.
+func (h *defaultHasher) Init() {
+ h.fn = sync.MapKeyHasher(map[Key]*Value(nil))
+ h.seed = sync.RandUintptr()
+}
+
+// Hash returns the hash value for the given Key.
+func (h *defaultHasher) Hash(key Key) uintptr {
+ // gohacks.Noescape wraps the pointer to the local copy of key before it is
+ // handed to fn (presumably so that key does not escape to the heap).
+ return h.fn(gohacks.Noescape(unsafe.Pointer(&key)), h.seed)
+}
+
+var hasher Hasher
+
+func init() {
+ hasher.Init()
+}
+
+// An AtomicPtrMap maps Keys to non-nil pointers to Values. AtomicPtrMaps are
+// safe for concurrent use from multiple goroutines without additional
+// synchronization.
+//
+// The zero value of AtomicPtrMap is empty (maps all Keys to nil) and ready for
+// use. AtomicPtrMaps must not be copied after first use.
+//
+// sync.Map may be faster than AtomicPtrMap if most operations on the map are
+// concurrent writes to a fixed set of keys. AtomicPtrMap is usually faster in
+// other circumstances.
+type AtomicPtrMap struct {
+ // AtomicPtrMap is implemented as a hash table with the following
+ // properties:
+ //
+ // * Collisions are resolved with quadratic probing. Of the two major
+ // alternatives, Robin Hood linear probing makes it difficult for writers
+ // to execute in parallel, and bucketing is less effective in Go due to
+ // lack of SIMD.
+ //
+ // * The table is optionally divided into shards indexed by hash to further
+ // reduce unnecessary synchronization.
+
+ shards [1 << ShardOrder]apmShard
+}
+
+// shard returns the shard of m that is responsible for keys hashing to hash.
+// The shard index is taken from the most significant ShardOrder bits of hash.
+func (m *AtomicPtrMap) shard(hash uintptr) *apmShard {
+ // Go defines right shifts >= width of shifted unsigned operand as 0, so
+ // this is correct even if ShardOrder is 0 (although nogo complains because
+ // nogo is dumb).
+ const indexLSB = unsafe.Sizeof(uintptr(0))*8 - ShardOrder
+ index := hash >> indexLSB
+ // Equivalent to &m.shards[index], computed with pointer arithmetic
+ // (presumably to avoid a bounds check).
+ return (*apmShard)(unsafe.Pointer(uintptr(unsafe.Pointer(&m.shards)) + (index * unsafe.Sizeof(apmShard{}))))
+}
+
+// apmShard is one independent hash table within an AtomicPtrMap. Its state is
+// split into data touched by mutations and data read by lookups, each followed
+// by a padding array sized by the apmShard*Padding constants.
+type apmShard struct {
+ apmShardMutationData
+ _ [apmShardMutationDataPadding]byte
+ apmShardLookupData
+ _ [apmShardLookupDataPadding]byte
+}
+
+// apmShardMutationData holds the locks and counters used when inserting into
+// or rehashing a shard.
+type apmShardMutationData struct {
+ dirtyMu sync.Mutex // serializes slot transitions out of empty
+ dirty uintptr // # slots with val != nil
+ count uintptr // # slots with val != nil and val != tombstone()
+ rehashMu sync.Mutex // serializes rehashing
+}
+
+// apmShardLookupData holds the current slot table and the sequence counter
+// that lets readers obtain a consistent (slots, mask) pair.
+type apmShardLookupData struct {
+ seq sync.SeqCount // allows atomic reads of slots+mask
+ slots unsafe.Pointer // [mask+1]slot or nil; protected by rehashMu/seq
+ mask uintptr // always (a power of 2) - 1; protected by rehashMu/seq
+}
+
+const (
+ cacheLineBytes = 64
+ // Cache line padding is enabled if sharding is.
+ apmEnablePadding = (ShardOrder + 63) >> 6 // 0 if ShardOrder == 0, 1 otherwise
+ // The -1 and +1 below are required to ensure that if unsafe.Sizeof(T) %
+ // cacheLineBytes == 0, then padding is 0 (rather than cacheLineBytes).
+ apmShardMutationDataRequiredPadding = cacheLineBytes - (((unsafe.Sizeof(apmShardMutationData{}) - 1) % cacheLineBytes) + 1)
+ apmShardMutationDataPadding = apmEnablePadding * apmShardMutationDataRequiredPadding
+ apmShardLookupDataRequiredPadding = cacheLineBytes - (((unsafe.Sizeof(apmShardLookupData{}) - 1) % cacheLineBytes) + 1)
+ apmShardLookupDataPadding = apmEnablePadding * apmShardLookupDataRequiredPadding
+
+ // These define fractional thresholds for when apmShard.rehash() is called
+ // (i.e. the load factor) and when it rehashes to a larger table
+ // respectively. They are chosen such that the rehash threshold = the
+ // expansion threshold + 1/2, so that when reuse of deleted slots is rare
+ // or non-existent, rehashing occurs after the insertion of at least 1/2
+ // the table's size in new entries, which is acceptably infrequent.
+ apmRehashThresholdNum = 2
+ apmRehashThresholdDen = 3
+ apmExpansionThresholdNum = 1
+ apmExpansionThresholdDen = 6
+)
+
+type apmSlot struct {
+ // slot states are indicated by val:
+ //
+ // * Empty: val == nil; key is meaningless. May transition to full or
+ // evacuated with dirtyMu locked.
+ //
+ // * Full: val != nil, tombstone(), or evacuated(); key is immutable. val
+ // is the Value mapped to key. May transition to deleted or evacuated.
+ //
+ // * Deleted: val == tombstone(); key is still immutable. key is mapped to
+ // no Value. May transition to full or evacuated.
+ //
+ // * Evacuated: val == evacuated(); key is immutable. Set by rehashing on
+ // slots that have already been moved, requiring readers to wait for
+ // rehashing to complete and use the new table. Terminal state.
+ //
+ // Note that once val is non-nil, it cannot become nil again. That is, the
+ // transition from empty to non-empty is irreversible for a given slot;
+ // the only way to create more empty slots is by rehashing.
+ val unsafe.Pointer
+ key Key
+}
+
+// apmSlotAt returns a pointer to the pos'th apmSlot of the table that starts
+// at slots. It performs no bounds checking.
+func apmSlotAt(slots unsafe.Pointer, pos uintptr) *apmSlot {
+ return (*apmSlot)(unsafe.Pointer(uintptr(slots) + pos*unsafe.Sizeof(apmSlot{})))
+}
+
+// tombstoneObj exists only so that its address can serve as a sentinel.
+var tombstoneObj byte
+
+// tombstone returns the sentinel pointer stored in apmSlot.val to mark a slot
+// whose key is currently mapped to no Value (the "deleted" state).
+func tombstone() unsafe.Pointer {
+ return unsafe.Pointer(&tombstoneObj)
+}
+
+// evacuatedObj exists only so that its address can serve as a sentinel.
+var evacuatedObj byte
+
+// evacuated returns the sentinel pointer that rehashing stores in apmSlot.val
+// once a slot of the old table has been processed; users that observe it
+// retry against the new table.
+func evacuated() unsafe.Pointer {
+ return unsafe.Pointer(&evacuatedObj)
+}
+
+// Load returns the Value stored in m for key. It returns nil if no Value is
+// stored for key.
+func (m *AtomicPtrMap) Load(key Key) *Value {
+ hash := hasher.Hash(key)
+ shard := m.shard(hash)
+
+retry:
+ // Read slots and mask inside a seqcount read section so that both belong
+ // to the same table.
+ epoch := shard.seq.BeginRead()
+ slots := atomic.LoadPointer(&shard.slots)
+ mask := atomic.LoadUintptr(&shard.mask)
+ if !shard.seq.ReadOk(epoch) {
+ goto retry
+ }
+ if slots == nil {
+ // No table has been allocated yet, so the shard is empty.
+ return nil
+ }
+
+ // Probe starting at hash & mask; the step grows by one on every miss.
+ i := hash & mask
+ inc := uintptr(1)
+ for {
+ slot := apmSlotAt(slots, i)
+ slotVal := atomic.LoadPointer(&slot.val)
+ if slotVal == nil {
+ // Empty slot; end of probe sequence.
+ return nil
+ }
+ if slotVal == evacuated() {
+ // Racing with rehashing.
+ goto retry
+ }
+ if slot.key == key {
+ if slotVal == tombstone() {
+ // key has a slot, but its Value has been deleted.
+ return nil
+ }
+ return (*Value)(slotVal)
+ }
+ i = (i + inc) & mask
+ inc++
+ }
+}
+
+// Store stores the Value val for key. If val is nil, Store instead removes
+// any Value stored for key.
+func (m *AtomicPtrMap) Store(key Key, val *Value) {
+ m.maybeCompareAndSwap(key, false, nil, val)
+}
+
+// Swap stores the Value val for key and returns the previously-mapped Value,
+// or nil if there was none. If val is nil, Swap removes any Value stored for
+// key.
+func (m *AtomicPtrMap) Swap(key Key, val *Value) *Value {
+ return m.maybeCompareAndSwap(key, false, nil, val)
+}
+
+// CompareAndSwap checks that the Value stored for key is oldVal; if it is, it
+// stores the Value newVal for key. CompareAndSwap returns the previous Value
+// stored for key, whether or not it stores newVal.
+//
+// A nil oldVal matches a key that has no Value stored, and a nil newVal
+// removes the Value stored for key. A nil return value means that key had no
+// Value stored.
+func (m *AtomicPtrMap) CompareAndSwap(key Key, oldVal, newVal *Value) *Value {
+ return m.maybeCompareAndSwap(key, true, oldVal, newVal)
+}
+
+// maybeCompareAndSwap implements Store, Swap and CompareAndSwap.
+//
+// If compare is false, typedNewVal is stored for key unconditionally and
+// typedOldVal is ignored. If compare is true, typedNewVal is stored only if
+// the Value currently stored for key is typedOldVal. A nil typedOldVal or
+// typedNewVal stands for "no Value" and is represented internally by
+// tombstone().
+//
+// maybeCompareAndSwap returns the Value that was stored for key before the
+// call, or nil if there was none.
+func (m *AtomicPtrMap) maybeCompareAndSwap(key Key, compare bool, typedOldVal, typedNewVal *Value) *Value {
+ hash := hasher.Hash(key)
+ shard := m.shard(hash)
+ oldVal := tombstone()
+ if typedOldVal != nil {
+ oldVal = unsafe.Pointer(typedOldVal)
+ }
+ newVal := tombstone()
+ if typedNewVal != nil {
+ newVal = unsafe.Pointer(typedNewVal)
+ }
+
+retry:
+ // Read slots and mask inside a seqcount read section so that both belong
+ // to the same table.
+ epoch := shard.seq.BeginRead()
+ slots := atomic.LoadPointer(&shard.slots)
+ mask := atomic.LoadUintptr(&shard.mask)
+ if !shard.seq.ReadOk(epoch) {
+ goto retry
+ }
+ if slots == nil {
+ // The shard is empty: a compare against a real Value fails, and a
+ // deletion has nothing to delete.
+ if (compare && oldVal != tombstone()) || newVal == tombstone() {
+ return nil
+ }
+ // Need to allocate a table before insertion.
+ shard.rehash(nil)
+ goto retry
+ }
+
+ i := hash & mask
+ inc := uintptr(1)
+ for {
+ slot := apmSlotAt(slots, i)
+ slotVal := atomic.LoadPointer(&slot.val)
+ if slotVal == nil {
+ // key is not in the table: a compare against a real Value fails, and
+ // a deletion has nothing to delete.
+ if (compare && oldVal != tombstone()) || newVal == tombstone() {
+ return nil
+ }
+ // Try to grab this slot for ourselves.
+ shard.dirtyMu.Lock()
+ slotVal = atomic.LoadPointer(&slot.val)
+ if slotVal == nil {
+ // Check if we need to rehash before dirtying a slot.
+ if dirty, capacity := shard.dirty+1, mask+1; dirty*apmRehashThresholdDen >= capacity*apmRehashThresholdNum {
+ shard.dirtyMu.Unlock()
+ shard.rehash(slots)
+ goto retry
+ }
+ // Write the key before publishing val, since a non-nil val is what
+ // makes lookups examine key.
+ slot.key = key
+ atomic.StorePointer(&slot.val, newVal) // transitions slot to full
+ shard.dirty++
+ atomic.AddUintptr(&shard.count, 1)
+ shard.dirtyMu.Unlock()
+ return nil
+ }
+ // Raced with another store; the slot is no longer empty. Continue
+ // with the new value of slotVal since we may have raced with
+ // another store of key.
+ shard.dirtyMu.Unlock()
+ }
+ if slotVal == evacuated() {
+ // Racing with rehashing.
+ goto retry
+ }
+ if slot.key == key {
+ // We're reusing an existing slot, so rehashing isn't necessary.
+ for {
+ // Nothing to do if the comparison fails or the slot already holds
+ // newVal; report the current mapping.
+ if (compare && oldVal != slotVal) || newVal == slotVal {
+ if slotVal == tombstone() {
+ return nil
+ }
+ return (*Value)(slotVal)
+ }
+ if atomic.CompareAndSwapPointer(&slot.val, slotVal, newVal) {
+ if slotVal == tombstone() {
+ // Deleted -> full: one more live entry.
+ atomic.AddUintptr(&shard.count, 1)
+ return nil
+ }
+ if newVal == tombstone() {
+ // Full -> deleted: one fewer live entry.
+ atomic.AddUintptr(&shard.count, ^uintptr(0) /* -1 */)
+ }
+ return (*Value)(slotVal)
+ }
+ // Lost a race with another writer of this slot; reload and try
+ // again.
+ slotVal = atomic.LoadPointer(&slot.val)
+ if slotVal == evacuated() {
+ goto retry
+ }
+ }
+ }
+ // This produces a triangular number sequence of offsets from the
+ // initially-probed position.
+ i = (i + inc) & mask
+ inc++
+ }
+}
+
+// rehash replaces shard's slot table with a newly-allocated one and copies the
+// live entries of the old table into it; empty and deleted slots are dropped.
+//
+// oldSlots is the table the caller observed, or nil if no table had been
+// allocated yet. If shard.slots no longer equals oldSlots, another call has
+// already replaced the table and rehash returns without doing anything.
+//
+// rehash is marked nosplit to avoid preemption during table copying.
+//go:nosplit
+func (shard *apmShard) rehash(oldSlots unsafe.Pointer) {
+ shard.rehashMu.Lock()
+ defer shard.rehashMu.Unlock()
+
+ if shard.slots != oldSlots {
+ // Raced with another call to rehash().
+ return
+ }
+
+ // Determine the size of the new table. Constraints:
+ //
+ // * The size of the table must be a power of two to ensure that every slot
+ // is visitable by every probe sequence under quadratic probing with
+ // triangular numbers.
+ //
+ // * The size of the table cannot decrease because even if shard.count is
+ // currently smaller than shard.dirty, concurrent stores that reuse
+ // existing slots can drive shard.count back up to a maximum of
+ // shard.dirty.
+ newSize := uintptr(8) // arbitrary initial size
+ if oldSlots != nil {
+ oldSize := shard.mask + 1
+ newSize = oldSize
+ // Double the table only if the live entry count (plus the pending
+ // insertion) exceeds the expansion threshold; otherwise rehash in a
+ // table of the same size just to discard tombstones.
+ if count := atomic.LoadUintptr(&shard.count) + 1; count*apmExpansionThresholdDen > oldSize*apmExpansionThresholdNum {
+ newSize *= 2
+ }
+ }
+
+ // Allocate the new table.
+ newSlotsSlice := make([]apmSlot, newSize)
+ newSlotsReflect := (*reflect.SliceHeader)(unsafe.Pointer(&newSlotsSlice))
+ newSlots := unsafe.Pointer(newSlotsReflect.Data)
+ runtime.KeepAlive(newSlotsSlice)
+ newMask := newSize - 1
+
+ // Start a writer critical section now so that racing users of the old
+ // table that observe evacuated() wait for the new table. (But lock dirtyMu
+ // first since doing so may block, which we don't want to do during the
+ // writer critical section.)
+ shard.dirtyMu.Lock()
+ shard.seq.BeginWrite()
+
+ if oldSlots != nil {
+ realCount := uintptr(0)
+ // Copy old entries to the new table.
+ oldMask := shard.mask
+ for i := uintptr(0); i <= oldMask; i++ {
+ oldSlot := apmSlotAt(oldSlots, i)
+ // Atomically take the slot's value and mark the slot evacuated so
+ // that concurrent users of the old table retry.
+ val := atomic.SwapPointer(&oldSlot.val, evacuated())
+ if val == nil || val == tombstone() {
+ continue
+ }
+ // Find the first empty slot on the key's probe sequence in the new
+ // table. The new table is not yet published, so plain (non-atomic)
+ // accesses are used.
+ hash := hasher.Hash(oldSlot.key)
+ j := hash & newMask
+ inc := uintptr(1)
+ for {
+ newSlot := apmSlotAt(newSlots, j)
+ if newSlot.val == nil {
+ newSlot.val = val
+ newSlot.key = oldSlot.key
+ break
+ }
+ j = (j + inc) & newMask
+ inc++
+ }
+ realCount++
+ }
+ // Update dirty to reflect that tombstones were not copied to the new
+ // table. Use realCount since a concurrent mutator may not have updated
+ // shard.count yet.
+ shard.dirty = realCount
+ }
+
+ // Switch to the new table.
+ atomic.StorePointer(&shard.slots, newSlots)
+ atomic.StoreUintptr(&shard.mask, newMask)
+
+ shard.seq.EndWrite()
+ shard.dirtyMu.Unlock()
+}
+
+// Range invokes f on each Key-Value pair stored in m. If any call to f returns
+// false, Range stops iteration and returns.
+//
+// Range does not necessarily correspond to any consistent snapshot of the
+// Map's contents: no Key will be visited more than once, but if the Value for
+// any Key is stored or deleted concurrently, Range may reflect any mapping for
+// that Key from any point during the Range call.
+//
+// f must not call other methods on m.
+func (m *AtomicPtrMap) Range(f func(key Key, val *Value) bool) {
+ for si := 0; si < len(m.shards); si++ {
+ shard := &m.shards[si]
+ if !shard.doRange(f) {
+ return
+ }
+ }
+}
+
+// doRange invokes f on each Key-Value pair stored in shard, skipping empty
+// and deleted slots. It returns false as soon as a call to f returns false,
+// and true if every slot was visited. rehashMu is held for the duration of
+// the call, including while f runs.
+func (shard *apmShard) doRange(f func(key Key, val *Value) bool) bool {
+ // We have to lock rehashMu because if we handled races with rehashing by
+ // retrying, f could see the same key twice.
+ shard.rehashMu.Lock()
+ defer shard.rehashMu.Unlock()
+ slots := shard.slots
+ if slots == nil {
+ // No table has been allocated, so there is nothing to visit.
+ return true
+ }
+ mask := shard.mask
+ for i := uintptr(0); i <= mask; i++ {
+ slot := apmSlotAt(slots, i)
+ slotVal := atomic.LoadPointer(&slot.val)
+ if slotVal == nil || slotVal == tombstone() {
+ continue
+ }
+ if !f(slot.key, (*Value)(slotVal)) {
+ return false
+ }
+ }
+ return true
+}
+
+// RangeRepeatable is like Range, but:
+//
+// * RangeRepeatable may visit the same Key multiple times in the presence of
+// concurrent mutators, possibly passing different Values to f in different
+// calls.
+//
+// * It is safe for f to call other methods on m.
+func (m *AtomicPtrMap) RangeRepeatable(f func(key Key, val *Value) bool) {
+ for si := 0; si < len(m.shards); si++ {
+ shard := &m.shards[si]
+
+ retry:
+ epoch := shard.seq.BeginRead()
+ slots := atomic.LoadPointer(&shard.slots)
+ mask := atomic.LoadUintptr(&shard.mask)
+ if !shard.seq.ReadOk(epoch) {
+ goto retry
+ }
+ if slots == nil {
+ continue
+ }
+
+ for i := uintptr(0); i <= mask; i++ {
+ slot := apmSlotAt(slots, i)
+ slotVal := atomic.LoadPointer(&slot.val)
+ if slotVal == evacuated() {
+ goto retry
+ }
+ if slotVal == nil || slotVal == tombstone() {
+ continue
+ }
+ if !f(slot.key, (*Value)(slotVal)) {
+ return
+ }
+ }
+ }
+}
diff --git a/pkg/sync/seqatomic_unsafe.go b/pkg/sync/generic_seqatomic_unsafe.go
index 2184cb5ab..82b676abf 100644
--- a/pkg/sync/seqatomic_unsafe.go
+++ b/pkg/sync/generic_seqatomic_unsafe.go
@@ -3,25 +3,17 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Package template doesn't exist. This file must be instantiated using the
+// Package seqatomic doesn't exist. This file must be instantiated using the
// go_template_instance rule in tools/go_generics/defs.bzl.
-package template
+package seqatomic
import (
- "fmt"
- "reflect"
- "strings"
"unsafe"
"gvisor.dev/gvisor/pkg/sync"
)
// Value is a required type parameter.
-//
-// Value must not contain any pointers, including interface objects, function
-// objects, slices, maps, channels, unsafe.Pointer, and arrays or structs
-// containing any of the above. An init() function will panic if this property
-// does not hold.
type Value struct{}
// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race
@@ -55,12 +47,3 @@ func SeqAtomicTryLoad(seq *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value)
ok = seq.ReadOk(epoch)
return
}
-
-func init() {
- var val Value
- typ := reflect.TypeOf(val)
- name := typ.Name()
- if ptrs := sync.PointersInType(typ, name); len(ptrs) != 0 {
- panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n")))
- }
-}
diff --git a/pkg/sync/spin_legacy_unsafe.go b/pkg/sync/goyield_go113_unsafe.go
index 61fc7320e..8aee0d455 100644
--- a/pkg/sync/spin_legacy_unsafe.go
+++ b/pkg/sync/goyield_go113_unsafe.go
@@ -10,15 +10,8 @@ package sync
import (
"runtime"
- _ "unsafe" // for go:linkname
)
-//go:linkname canSpin sync.runtime_canSpin
-func canSpin(i int) bool
-
-//go:linkname doSpin sync.runtime_doSpin
-func doSpin()
-
func goyield() {
// goyield is not available until Go 1.14.
runtime.Gosched()
diff --git a/pkg/sync/spin_unsafe.go b/pkg/sync/goyield_unsafe.go
index 18e8fc743..672ee274d 100644
--- a/pkg/sync/spin_unsafe.go
+++ b/pkg/sync/goyield_unsafe.go
@@ -14,11 +14,5 @@ import (
_ "unsafe" // for go:linkname
)
-//go:linkname canSpin sync.runtime_canSpin
-func canSpin(i int) bool
-
-//go:linkname doSpin sync.runtime_doSpin
-func doSpin()
-
//go:linkname goyield runtime.goyield
func goyield()
diff --git a/pkg/sync/memmove_unsafe.go b/pkg/sync/memmove_unsafe.go
deleted file mode 100644
index f5e630009..000000000
--- a/pkg/sync/memmove_unsafe.go
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build go1.12
-// +build !go1.17
-
-// Check go:linkname function signatures when updating Go version.
-
-package sync
-
-import (
- "unsafe"
-)
-
-//go:linkname memmove runtime.memmove
-//go:noescape
-func memmove(to, from unsafe.Pointer, n uintptr)
-
-// Memmove is exported for SeqAtomicLoad/SeqAtomicTryLoad<T>, which can't
-// define it because go_generics can't update the go:linkname annotation.
-// Furthermore, go:linkname silently doesn't work if the local name is exported
-// (this is of course undocumented), which is why this indirection is
-// necessary.
-func Memmove(to, from unsafe.Pointer, n uintptr) {
- memmove(to, from, n)
-}
diff --git a/pkg/sync/norace_unsafe.go b/pkg/sync/norace_unsafe.go
index 006055dd6..70b5f3a5e 100644
--- a/pkg/sync/norace_unsafe.go
+++ b/pkg/sync/norace_unsafe.go
@@ -8,6 +8,7 @@
package sync
import (
+ "sync/atomic"
"unsafe"
)
@@ -33,3 +34,13 @@ func RaceRelease(addr unsafe.Pointer) {
// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge.
func RaceReleaseMerge(addr unsafe.Pointer) {
}
+
+// RaceUncheckedAtomicCompareAndSwapUintptr is equivalent to
+// sync/atomic.CompareAndSwapUintptr, but is not checked by the race detector.
+// This is necessary when implementing gopark callbacks, since no race context
+// is available during their execution.
+func RaceUncheckedAtomicCompareAndSwapUintptr(ptr *uintptr, old, new uintptr) bool {
+ // Use atomic.CompareAndSwapUintptr outside of race builds for
+ // inlinability.
+ return atomic.CompareAndSwapUintptr(ptr, old, new)
+}
diff --git a/pkg/syncevent/waiter_amd64.s b/pkg/sync/race_amd64.s
index 5e216b045..57bc0ec79 100644
--- a/pkg/syncevent/waiter_amd64.s
+++ b/pkg/sync/race_amd64.s
@@ -12,21 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build race
+// +build amd64
+
#include "textflag.h"
-// See waiter_noasm_unsafe.go for a description of waiterUnlock.
-//
-// func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool
-TEXT ·waiterUnlock(SB),NOSPLIT,$0-24
+// func RaceUncheckedAtomicCompareAndSwapUintptr(ptr *uintptr, old, new uintptr) bool
+TEXT ·RaceUncheckedAtomicCompareAndSwapUintptr(SB),NOSPLIT,$0-25
MOVQ ptr+0(FP), DI
- MOVQ wg+8(FP), SI
+ MOVQ old+8(FP), AX
+ MOVQ new+16(FP), SI
- MOVQ $·preparingG(SB), AX
LOCK
- CMPXCHGQ DI, 0(SI)
+ CMPXCHGQ SI, 0(DI)
SETEQ AX
- MOVB AX, ret+16(FP)
+ MOVB AX, ret+24(FP)
RET
diff --git a/pkg/syncevent/waiter_arm64.s b/pkg/sync/race_arm64.s
index f4c06f194..88f091fda 100644
--- a/pkg/syncevent/waiter_arm64.s
+++ b/pkg/sync/race_arm64.s
@@ -12,15 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build race
+// +build arm64
+
#include "textflag.h"
-// See waiter_noasm_unsafe.go for a description of waiterUnlock.
-//
-// func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool
-TEXT ·waiterUnlock(SB),NOSPLIT,$0-24
- MOVD wg+8(FP), R0
- MOVD $·preparingG(SB), R1
- MOVD ptr+0(FP), R2
+// func RaceUncheckedAtomicCompareAndSwapUintptr(ptr *uintptr, old, new uintptr) bool
+//
+// Register use: R0 = ptr, R1 = old (the expected value, compared against the
+// loaded word below), R2 = new (the value to store). old and new must live in
+// different registers; loading new into R1 would overwrite the expected
+// value before the comparison.
+TEXT ·RaceUncheckedAtomicCompareAndSwapUintptr(SB),NOSPLIT,$0-25
+ MOVD ptr+0(FP), R0
+ MOVD old+8(FP), R1
+ MOVD new+16(FP), R2
again:
LDAXR (R0), R3
CMP R1, R3
@@ -29,6 +30,6 @@ again:
CBNZ R3, again
ok:
CSET EQ, R0
- MOVB R0, ret+16(FP)
+ MOVB R0, ret+24(FP)
RET
diff --git a/pkg/sync/race_unsafe.go b/pkg/sync/race_unsafe.go
index 31d8fa9a6..59985c270 100644
--- a/pkg/sync/race_unsafe.go
+++ b/pkg/sync/race_unsafe.go
@@ -39,3 +39,9 @@ func RaceRelease(addr unsafe.Pointer) {
func RaceReleaseMerge(addr unsafe.Pointer) {
runtime.RaceReleaseMerge(addr)
}
+
+// RaceUncheckedAtomicCompareAndSwapUintptr is equivalent to
+// sync/atomic.CompareAndSwapUintptr, but is not checked by the race detector.
+// This is necessary when implementing gopark callbacks, since no race context
+// is available during their execution.
+func RaceUncheckedAtomicCompareAndSwapUintptr(ptr *uintptr, old, new uintptr) bool
diff --git a/pkg/sync/runtime_unsafe.go b/pkg/sync/runtime_unsafe.go
new file mode 100644
index 000000000..e925e2e5b
--- /dev/null
+++ b/pkg/sync/runtime_unsafe.go
@@ -0,0 +1,129 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.13
+// +build !go1.17
+
+// Check function signatures and constants when updating Go version.
+
+package sync
+
+import (
+ "fmt"
+ "reflect"
+ "unsafe"
+)
+
+// Note that go:linkname silently doesn't work if the local name is exported,
+// necessitating an indirection for exported functions.
+
+// Memmove is runtime.memmove, exported for SeqAtomicLoad/SeqAtomicTryLoad<T>.
+//
+//go:nosplit
+func Memmove(to, from unsafe.Pointer, n uintptr) {
+ memmove(to, from, n)
+}
+
+//go:linkname memmove runtime.memmove
+//go:noescape
+func memmove(to, from unsafe.Pointer, n uintptr)
+
+// Gopark is runtime.gopark. Gopark calls unlockf(pointer to runtime.g, lock);
+// if unlockf returns true, Gopark blocks until Goready(pointer to runtime.g)
+// is called. unlockf and its callees must be nosplit and norace, since stack
+// splitting and race context are not available where it is called.
+//
+//go:nosplit
+func Gopark(unlockf func(uintptr, unsafe.Pointer) bool, lock unsafe.Pointer, reason uint8, traceEv byte, traceskip int) {
+ gopark(unlockf, lock, reason, traceEv, traceskip)
+}
+
+//go:linkname gopark runtime.gopark
+func gopark(unlockf func(uintptr, unsafe.Pointer) bool, lock unsafe.Pointer, reason uint8, traceEv byte, traceskip int)
+
+// Goready is runtime.goready.
+//
+//go:nosplit
+func Goready(gp uintptr, traceskip int) {
+ goready(gp, traceskip)
+}
+
+//go:linkname goready runtime.goready
+func goready(gp uintptr, traceskip int)
+
+// Values for the reason argument to gopark, from Go's src/runtime/runtime2.go.
+const (
+ WaitReasonSelect uint8 = 9
+)
+
+// Values for the traceEv argument to gopark, from Go's src/runtime/trace.go.
+const (
+ TraceEvGoBlockSelect byte = 24
+)
+
+// Rand32 returns a non-cryptographically-secure random uint32.
+func Rand32() uint32 {
+ return fastrand()
+}
+
+// Rand64 returns a non-cryptographically-secure random uint64, assembled from
+// two consecutive 32-bit outputs of fastrand.
+func Rand64() uint64 {
+	hi := uint64(fastrand())
+	lo := uint64(fastrand())
+	return hi<<32 | lo
+}
+
+//go:linkname fastrand runtime.fastrand
+func fastrand() uint32
+
+// RandUintptr returns a non-cryptographically-secure random uintptr, drawing
+// 32 random bits when uintptr is 4 bytes wide and 64 bits otherwise.
+func RandUintptr() uintptr {
+	if unsafe.Sizeof(uintptr(0)) != 4 {
+		return uintptr(Rand64())
+	}
+	return uintptr(Rand32())
+}
+
+// MapKeyHasher returns a hash function for pointers of m's key type. The
+// returned function takes a pointer to a key and a seed, and returns the hash.
+//
+// MapKeyHasher panics if m is not a map.
+//
+// Preconditions: m must be a map.
+func MapKeyHasher(m interface{}) func(unsafe.Pointer, uintptr) uintptr {
+ if rtyp := reflect.TypeOf(m); rtyp.Kind() != reflect.Map {
+ panic(fmt.Sprintf("sync.MapKeyHasher: m is %v, not map", rtyp))
+ }
+ // Reinterpret the first word of the interface value m as a *maptype
+ // (presumably the runtime's type descriptor for the map type) and return
+ // its hasher field.
+ mtyp := *(**maptype)(unsafe.Pointer(&m))
+ return mtyp.hasher
+}
+
+// maptype describes the memory layout that MapKeyHasher assumes for a map's
+// type descriptor. Only hasher is read; the preceding fields exist to place
+// it at the right offset.
+//
+// NOTE(review): presumably this mirrors the Go runtime's own map type
+// structure for the Go versions selected by this file's build tags; verify
+// the layout when updating the Go version.
+type maptype struct {
+ size uintptr
+ ptrdata uintptr
+ hash uint32
+ tflag uint8
+ align uint8
+ fieldAlign uint8
+ kind uint8
+ equal func(unsafe.Pointer, unsafe.Pointer) bool
+ gcdata *byte
+ str int32
+ ptrToThis int32
+ key unsafe.Pointer
+ elem unsafe.Pointer
+ bucket unsafe.Pointer
+ hasher func(unsafe.Pointer, uintptr) uintptr
+ // more fields
+}
+
+// These functions are only used within the sync package.
+
+//go:linkname semacquire sync.runtime_Semacquire
+func semacquire(s *uint32)
+
+//go:linkname semrelease sync.runtime_Semrelease
+func semrelease(s *uint32, handoff bool, skipframes int)
+
+//go:linkname canSpin sync.runtime_canSpin
+func canSpin(i int) bool
+
+//go:linkname doSpin sync.runtime_doSpin
+func doSpin()
diff --git a/pkg/sync/rwmutex_test.go b/pkg/sync/rwmutex_test.go
index ce667e825..5ca96d12b 100644
--- a/pkg/sync/rwmutex_test.go
+++ b/pkg/sync/rwmutex_test.go
@@ -102,7 +102,7 @@ func downgradingWriter(rwm *RWMutex, numIterations int, activity *int32, cdone c
}
for i := 0; i < 100; i++ {
}
- n = atomic.AddInt32(activity, -1)
+ atomic.AddInt32(activity, -1)
rwm.RUnlock()
}
cdone <- true
diff --git a/pkg/sync/rwmutex_unsafe.go b/pkg/sync/rwmutex_unsafe.go
index fa023f5bb..4cf3fcd6e 100644
--- a/pkg/sync/rwmutex_unsafe.go
+++ b/pkg/sync/rwmutex_unsafe.go
@@ -3,11 +3,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build go1.13
-// +build !go1.17
-
-// Check go:linkname function signatures when updating Go version.
-
// This is mostly copied from the standard library's sync/rwmutex.go.
//
// Happens-before relationships indicated to the race detector:
@@ -23,15 +18,9 @@ import (
"unsafe"
)
-//go:linkname runtimeSemacquire sync.runtime_Semacquire
-func runtimeSemacquire(s *uint32)
-
-//go:linkname runtimeSemrelease sync.runtime_Semrelease
-func runtimeSemrelease(s *uint32, handoff bool, skipframes int)
-
-// RWMutex is identical to sync.RWMutex, but adds the DowngradeLock,
-// TryLock and TryRLock methods.
-type RWMutex struct {
+// CrossGoroutineRWMutex is equivalent to RWMutex, but it need not be unlocked
+// by the same goroutine that locked the mutex.
+type CrossGoroutineRWMutex struct {
// w is held if there are pending writers
//
// We use CrossGoroutineMutex rather than Mutex because the lock
@@ -48,7 +37,7 @@ const rwmutexMaxReaders = 1 << 30
// TryRLock locks rw for reading. It returns true if it succeeds and false
// otherwise. It does not block.
-func (rw *RWMutex) TryRLock() bool {
+func (rw *CrossGoroutineRWMutex) TryRLock() bool {
if RaceEnabled {
RaceDisable()
}
@@ -72,13 +61,17 @@ func (rw *RWMutex) TryRLock() bool {
}
// RLock locks rw for reading.
-func (rw *RWMutex) RLock() {
+//
+// It should not be used for recursive read locking; a blocked Lock call
+// excludes new readers from acquiring the lock. See the documentation on the
+// RWMutex type.
+func (rw *CrossGoroutineRWMutex) RLock() {
if RaceEnabled {
RaceDisable()
}
if atomic.AddInt32(&rw.readerCount, 1) < 0 {
// A writer is pending, wait for it.
- runtimeSemacquire(&rw.readerSem)
+ semacquire(&rw.readerSem)
}
if RaceEnabled {
RaceEnable()
@@ -87,7 +80,10 @@ func (rw *RWMutex) RLock() {
}
// RUnlock undoes a single RLock call.
-func (rw *RWMutex) RUnlock() {
+//
+// Preconditions:
+// * rw is locked for reading.
+func (rw *CrossGoroutineRWMutex) RUnlock() {
if RaceEnabled {
RaceReleaseMerge(unsafe.Pointer(&rw.writerSem))
RaceDisable()
@@ -99,7 +95,7 @@ func (rw *RWMutex) RUnlock() {
// A writer is pending.
if atomic.AddInt32(&rw.readerWait, -1) == 0 {
// The last reader unblocks the writer.
- runtimeSemrelease(&rw.writerSem, false, 0)
+ semrelease(&rw.writerSem, false, 0)
}
}
if RaceEnabled {
@@ -109,7 +105,7 @@ func (rw *RWMutex) RUnlock() {
// TryLock locks rw for writing. It returns true if it succeeds and false
// otherwise. It does not block.
-func (rw *RWMutex) TryLock() bool {
+func (rw *CrossGoroutineRWMutex) TryLock() bool {
if RaceEnabled {
RaceDisable()
}
@@ -135,8 +131,9 @@ func (rw *RWMutex) TryLock() bool {
return true
}
-// Lock locks rw for writing.
-func (rw *RWMutex) Lock() {
+// Lock locks rw for writing. If the lock is already locked for reading or
+// writing, Lock blocks until the lock is available.
+func (rw *CrossGoroutineRWMutex) Lock() {
if RaceEnabled {
RaceDisable()
}
@@ -146,7 +143,7 @@ func (rw *RWMutex) Lock() {
r := atomic.AddInt32(&rw.readerCount, -rwmutexMaxReaders) + rwmutexMaxReaders
// Wait for active readers.
if r != 0 && atomic.AddInt32(&rw.readerWait, r) != 0 {
- runtimeSemacquire(&rw.writerSem)
+ semacquire(&rw.writerSem)
}
if RaceEnabled {
RaceEnable()
@@ -155,7 +152,10 @@ func (rw *RWMutex) Lock() {
}
// Unlock unlocks rw for writing.
-func (rw *RWMutex) Unlock() {
+//
+// Preconditions:
+// * rw is locked for writing.
+func (rw *CrossGoroutineRWMutex) Unlock() {
if RaceEnabled {
RaceRelease(unsafe.Pointer(&rw.writerSem))
RaceRelease(unsafe.Pointer(&rw.readerSem))
@@ -168,7 +168,7 @@ func (rw *RWMutex) Unlock() {
}
// Unblock blocked readers, if any.
for i := 0; i < int(r); i++ {
- runtimeSemrelease(&rw.readerSem, false, 0)
+ semrelease(&rw.readerSem, false, 0)
}
// Allow other writers to proceed.
rw.w.Unlock()
@@ -178,7 +178,10 @@ func (rw *RWMutex) Unlock() {
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
-func (rw *RWMutex) DowngradeLock() {
+//
+// Preconditions:
+// * rw is locked for writing.
+func (rw *CrossGoroutineRWMutex) DowngradeLock() {
if RaceEnabled {
RaceRelease(unsafe.Pointer(&rw.readerSem))
RaceDisable()
@@ -191,7 +194,7 @@ func (rw *RWMutex) DowngradeLock() {
// Unblock blocked readers, if any. Note that this loop starts as 1 since r
// includes this goroutine.
for i := 1; i < int(r); i++ {
- runtimeSemrelease(&rw.readerSem, false, 0)
+ semrelease(&rw.readerSem, false, 0)
}
// Allow other writers to proceed to rw.w.Lock(). Note that they will still
// block on rw.writerSem since at least this reader exists, such that
@@ -201,3 +204,91 @@ func (rw *RWMutex) DowngradeLock() {
RaceEnable()
}
}
+
+// A RWMutex is a reader/writer mutual exclusion lock. The lock can be held by
+// an arbitrary number of readers or a single writer. The zero value for a
+// RWMutex is an unlocked mutex.
+//
+// A RWMutex must not be copied after first use.
+//
+// If a goroutine holds a RWMutex for reading and another goroutine might call
+// Lock, no goroutine should expect to be able to acquire a read lock until the
+// initial read lock is released. In particular, this prohibits recursive read
+// locking. This is to ensure that the lock eventually becomes available; a
+// blocked Lock call excludes new readers from acquiring the lock.
+//
+// A RWMutex must be unlocked by the same goroutine that locked it. This
+// invariant is enforced with the 'checklocks' build tag.
+type RWMutex struct {
+ m CrossGoroutineRWMutex
+}
+
+// TryRLock locks rw for reading. It returns true if it succeeds and false
+// otherwise. It does not block.
+func (rw *RWMutex) TryRLock() bool {
+ // Note lock first to enforce proper locking even if unsuccessful.
+ noteLock(unsafe.Pointer(rw))
+ locked := rw.m.TryRLock()
+ if !locked {
+ noteUnlock(unsafe.Pointer(rw))
+ }
+ return locked
+}
+
+// RLock locks rw for reading.
+//
+// It should not be used for recursive read locking; a blocked Lock call
+// excludes new readers from acquiring the lock. See the documentation on the
+// RWMutex type.
+func (rw *RWMutex) RLock() {
+ noteLock(unsafe.Pointer(rw))
+ rw.m.RLock()
+}
+
+// RUnlock undoes a single RLock call.
+//
+// Preconditions:
+// * rw is locked for reading.
+// * rw was locked by this goroutine.
+func (rw *RWMutex) RUnlock() {
+ rw.m.RUnlock()
+ noteUnlock(unsafe.Pointer(rw))
+}
+
+// TryLock locks rw for writing. It returns true if it succeeds and false
+// otherwise. It does not block.
+func (rw *RWMutex) TryLock() bool {
+ // Note lock first to enforce proper locking even if unsuccessful.
+ noteLock(unsafe.Pointer(rw))
+ locked := rw.m.TryLock()
+ if !locked {
+ noteUnlock(unsafe.Pointer(rw))
+ }
+ return locked
+}
+
+// Lock locks rw for writing. If the lock is already locked for reading or
+// writing, Lock blocks until the lock is available.
+func (rw *RWMutex) Lock() {
+ noteLock(unsafe.Pointer(rw))
+ rw.m.Lock()
+}
+
+// Unlock unlocks rw for writing.
+//
+// Preconditions:
+// * rw is locked for writing.
+// * rw was locked by this goroutine.
+func (rw *RWMutex) Unlock() {
+ rw.m.Unlock()
+ noteUnlock(unsafe.Pointer(rw))
+}
+
+// DowngradeLock atomically unlocks rw for writing and locks it for reading.
+//
+// Preconditions:
+// * rw is locked for writing.
+func (rw *RWMutex) DowngradeLock() {
+ // No note change for DowngradeLock.
+ rw.m.DowngradeLock()
+}
diff --git a/pkg/sync/seqcount.go b/pkg/sync/seqcount.go
index 2c5d3df99..1f025f33c 100644
--- a/pkg/sync/seqcount.go
+++ b/pkg/sync/seqcount.go
@@ -6,8 +6,6 @@
package sync
import (
- "fmt"
- "reflect"
"sync/atomic"
)
@@ -27,9 +25,6 @@ import (
// - SeqCount may be more flexible: correct use of SeqCount.ReadOk allows other
// operations to be made atomic with reads of SeqCount-protected data.
//
-// - SeqCount may be less flexible: as of this writing, SeqCount-protected data
-// cannot include pointers.
-//
// - SeqCount is more cumbersome to use; atomic reads of SeqCount-protected
// data require instantiating function templates using go_generics (see
// seqatomic.go).
@@ -128,32 +123,3 @@ func (s *SeqCount) EndWrite() {
panic("SeqCount.EndWrite outside writer critical section")
}
}
-
-// PointersInType returns a list of pointers reachable from values named
-// valName of the given type.
-//
-// PointersInType is not exhaustive, but it is guaranteed that if typ contains
-// at least one pointer, then PointersInTypeOf returns a non-empty list.
-func PointersInType(typ reflect.Type, valName string) []string {
- switch kind := typ.Kind(); kind {
- case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
- return nil
-
- case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.String, reflect.UnsafePointer:
- return []string{valName}
-
- case reflect.Array:
- return PointersInType(typ.Elem(), valName+"[]")
-
- case reflect.Struct:
- var ptrs []string
- for i, n := 0, typ.NumField(); i < n; i++ {
- field := typ.Field(i)
- ptrs = append(ptrs, PointersInType(field.Type, fmt.Sprintf("%s.%s", valName, field.Name))...)
- }
- return ptrs
-
- default:
- return []string{fmt.Sprintf("%s (of type %s with unknown kind %s)", valName, typ, kind)}
- }
-}
diff --git a/pkg/sync/seqcount_test.go b/pkg/sync/seqcount_test.go
index 6eb7b4b59..3f5592e3e 100644
--- a/pkg/sync/seqcount_test.go
+++ b/pkg/sync/seqcount_test.go
@@ -6,7 +6,6 @@
package sync
import (
- "reflect"
"testing"
"time"
)
@@ -99,55 +98,3 @@ func BenchmarkSeqCountReadUncontended(b *testing.B) {
}
})
}
-
-func TestPointersInType(t *testing.T) {
- for _, test := range []struct {
- name string // used for both test and value name
- val interface{}
- ptrs []string
- }{
- {
- name: "EmptyStruct",
- val: struct{}{},
- },
- {
- name: "Int",
- val: int(0),
- },
- {
- name: "MixedStruct",
- val: struct {
- b bool
- I int
- ExportedPtr *struct{}
- unexportedPtr *struct{}
- arr [2]int
- ptrArr [2]*int
- nestedStruct struct {
- nestedNonptr int
- nestedPtr *int
- }
- structArr [1]struct {
- nonptr int
- ptr *int
- }
- }{},
- ptrs: []string{
- "MixedStruct.ExportedPtr",
- "MixedStruct.unexportedPtr",
- "MixedStruct.ptrArr[]",
- "MixedStruct.nestedStruct.nestedPtr",
- "MixedStruct.structArr[].ptr",
- },
- },
- } {
- t.Run(test.name, func(t *testing.T) {
- typ := reflect.TypeOf(test.val)
- ptrs := PointersInType(typ, test.name)
- t.Logf("Found pointers: %v", ptrs)
- if (len(ptrs) != 0 || len(test.ptrs) != 0) && !reflect.DeepEqual(ptrs, test.ptrs) {
- t.Errorf("Got %v, wanted %v", ptrs, test.ptrs)
- }
- })
- }
-}
diff --git a/pkg/syncevent/BUILD b/pkg/syncevent/BUILD
index 0500a22cf..42c553308 100644
--- a/pkg/syncevent/BUILD
+++ b/pkg/syncevent/BUILD
@@ -9,10 +9,6 @@ go_library(
"receiver.go",
"source.go",
"syncevent.go",
- "waiter_amd64.s",
- "waiter_arm64.s",
- "waiter_asm_unsafe.go",
- "waiter_noasm_unsafe.go",
"waiter_unsafe.go",
],
visibility = ["//:sandbox"],
diff --git a/pkg/syncevent/waiter_noasm_unsafe.go b/pkg/syncevent/waiter_noasm_unsafe.go
deleted file mode 100644
index 0f74a689c..000000000
--- a/pkg/syncevent/waiter_noasm_unsafe.go
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// waiterUnlock is called from g0, so when the race detector is enabled,
-// waiterUnlock must be implemented in assembly since no race context is
-// available.
-//
-// +build !race
-// +build !amd64,!arm64
-
-package syncevent
-
-import (
- "sync/atomic"
- "unsafe"
-)
-
-// waiterUnlock is the "unlock function" passed to runtime.gopark by
-// Waiter.Wait*. wg is &Waiter.g, and g is a pointer to the calling runtime.g.
-// waiterUnlock returns true if Waiter.Wait should sleep and false if sleeping
-// should be aborted.
-//
-//go:nosplit
-func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool {
- // The only way this CAS can fail is if a call to Waiter.NotifyPending()
- // has replaced *wg with nil, in which case we should not sleep.
- return atomic.CompareAndSwapPointer(wg, (unsafe.Pointer)(&preparingG), ptr)
-}
diff --git a/pkg/syncevent/waiter_unsafe.go b/pkg/syncevent/waiter_unsafe.go
index 518f18479..b6ed2852d 100644
--- a/pkg/syncevent/waiter_unsafe.go
+++ b/pkg/syncevent/waiter_unsafe.go
@@ -12,11 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build go1.11
-// +build !go1.17
-
-// Check go:linkname function signatures when updating Go version.
-
package syncevent
import (
@@ -26,17 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/sync"
)
-//go:linkname gopark runtime.gopark
-func gopark(unlockf func(unsafe.Pointer, *unsafe.Pointer) bool, wg *unsafe.Pointer, reason uint8, traceEv byte, traceskip int)
-
-//go:linkname goready runtime.goready
-func goready(g unsafe.Pointer, traceskip int)
-
-const (
- waitReasonSelect = 9 // Go: src/runtime/runtime2.go
- traceEvGoBlockSelect = 24 // Go: src/runtime/trace.go
-)
-
// Waiter allows a goroutine to block on pending events received by a Receiver.
//
// Waiter.Init() must be called before first use.
@@ -45,20 +29,19 @@ type Waiter struct {
// g is one of:
//
- // - nil: No goroutine is blocking in Wait.
+ // - 0: No goroutine is blocking in Wait.
//
- // - &preparingG: A goroutine is in Wait preparing to sleep, but hasn't yet
+ // - preparingG: A goroutine is in Wait preparing to sleep, but hasn't yet
// completed waiterUnlock(). Thus the wait can only be interrupted by
- // replacing the value of g with nil (the G may not be in state Gwaiting
- // yet, so we can't call goready.)
+ // replacing the value of g with 0 (the G may not be in state Gwaiting yet,
+ // so we can't call goready.)
//
// - Otherwise: g is a pointer to the runtime.g in state Gwaiting for the
// goroutine blocked in Wait, which can only be woken by calling goready.
- g unsafe.Pointer `state:"zerovalue"`
+ g uintptr `state:"zerovalue"`
}
-// Sentinel object for Waiter.g.
-var preparingG struct{}
+const preparingG = 1
// Init must be called before first use of w.
func (w *Waiter) Init() {
@@ -99,21 +82,29 @@ func (w *Waiter) WaitFor(es Set) Set {
}
// Indicate that we're preparing to go to sleep.
- atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG))
+ atomic.StoreUintptr(&w.g, preparingG)
// If an event is pending, abort the sleep.
if p := w.r.Pending(); p&es != NoEvents {
- atomic.StorePointer(&w.g, nil)
+ atomic.StoreUintptr(&w.g, 0)
return p
}
// If w.g is still preparingG (i.e. w.NotifyPending() has not been
- // called or has not reached atomic.SwapPointer()), go to sleep until
+ // called or has not reached atomic.SwapUintptr()), go to sleep until
// w.NotifyPending() => goready().
- gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0)
+ sync.Gopark(waiterCommit, unsafe.Pointer(&w.g), sync.WaitReasonSelect, sync.TraceEvGoBlockSelect, 0)
}
}
+//go:norace
+//go:nosplit
+func waiterCommit(g uintptr, wg unsafe.Pointer) bool {
+ // The only way this CAS can fail is if a call to Waiter.NotifyPending()
+ // has replaced *wg with 0, in which case we should not sleep.
+ return sync.RaceUncheckedAtomicCompareAndSwapUintptr((*uintptr)(wg), preparingG, g)
+}
+
// Ack marks the given events as not pending.
func (w *Waiter) Ack(es Set) {
w.r.Ack(es)
@@ -135,20 +126,20 @@ func (w *Waiter) WaitAndAckAll() Set {
for {
// Indicate that we're preparing to go to sleep.
- atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG))
+ atomic.StoreUintptr(&w.g, preparingG)
// If an event is pending, abort the sleep.
if w.r.Pending() != NoEvents {
if p := w.r.PendingAndAckAll(); p != NoEvents {
- atomic.StorePointer(&w.g, nil)
+ atomic.StoreUintptr(&w.g, 0)
return p
}
}
// If w.g is still preparingG (i.e. w.NotifyPending() has not been
- // called or has not reached atomic.SwapPointer()), go to sleep until
+ // called or has not reached atomic.SwapUintptr()), go to sleep until
// w.NotifyPending() => goready().
- gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0)
+ sync.Gopark(waiterCommit, unsafe.Pointer(&w.g), sync.WaitReasonSelect, sync.TraceEvGoBlockSelect, 0)
// Check for pending events. We call PendingAndAckAll() directly now since
// we only expect to be woken after events become pending.
@@ -171,14 +162,14 @@ func (w *Waiter) NotifyPending() {
// goroutine. NotifyPending is called after w.r.Pending() is updated, so
// concurrent and future calls to w.Wait() will observe pending events and
// abort sleeping.
- if atomic.LoadPointer(&w.g) == nil {
+ if atomic.LoadUintptr(&w.g) == 0 {
return
}
// Wake a sleeping G, or prevent a G that is preparing to sleep from doing
// so. Swap is needed here to ensure that only one call to NotifyPending
// calls goready.
- if g := atomic.SwapPointer(&w.g, nil); g != nil && g != (unsafe.Pointer)(&preparingG) {
- goready(g, 0)
+ if g := atomic.SwapUintptr(&w.g, 0); g > preparingG {
+ sync.Goready(g, 0)
}
}
diff --git a/pkg/syserr/host_linux.go b/pkg/syserr/host_linux.go
index fc6ef60a1..77faa3670 100644
--- a/pkg/syserr/host_linux.go
+++ b/pkg/syserr/host_linux.go
@@ -32,7 +32,7 @@ var linuxHostTranslations [maxErrno]linuxHostTranslation
// FromHost translates a syscall.Errno to a corresponding Error value.
func FromHost(err syscall.Errno) *Error {
- if err < 0 || int(err) >= len(linuxHostTranslations) || !linuxHostTranslations[err].ok {
+ if int(err) >= len(linuxHostTranslations) || !linuxHostTranslations[err].ok {
panic(fmt.Sprintf("unknown host errno %q (%d)", err.Error(), err))
}
return linuxHostTranslations[err].err
diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go
index 5ae10939d..2756d4471 100644
--- a/pkg/syserr/netstack.go
+++ b/pkg/syserr/netstack.go
@@ -15,6 +15,8 @@
package syserr
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -46,47 +48,60 @@ var (
ErrInvalidOptionValue = New(tcpip.ErrInvalidOptionValue.String(), linux.EINVAL)
ErrBroadcastDisabled = New(tcpip.ErrBroadcastDisabled.String(), linux.EACCES)
ErrNotPermittedNet = New(tcpip.ErrNotPermitted.String(), linux.EPERM)
+ ErrBadBuffer = New(tcpip.ErrBadBuffer.String(), linux.EFAULT)
)
-var netstackErrorTranslations = map[*tcpip.Error]*Error{
- tcpip.ErrUnknownProtocol: ErrUnknownProtocol,
- tcpip.ErrUnknownNICID: ErrUnknownNICID,
- tcpip.ErrUnknownDevice: ErrUnknownDevice,
- tcpip.ErrUnknownProtocolOption: ErrUnknownProtocolOption,
- tcpip.ErrDuplicateNICID: ErrDuplicateNICID,
- tcpip.ErrDuplicateAddress: ErrDuplicateAddress,
- tcpip.ErrNoRoute: ErrNoRoute,
- tcpip.ErrBadLinkEndpoint: ErrBadLinkEndpoint,
- tcpip.ErrAlreadyBound: ErrAlreadyBound,
- tcpip.ErrInvalidEndpointState: ErrInvalidEndpointState,
- tcpip.ErrAlreadyConnecting: ErrAlreadyConnecting,
- tcpip.ErrAlreadyConnected: ErrAlreadyConnected,
- tcpip.ErrNoPortAvailable: ErrNoPortAvailable,
- tcpip.ErrPortInUse: ErrPortInUse,
- tcpip.ErrBadLocalAddress: ErrBadLocalAddress,
- tcpip.ErrClosedForSend: ErrClosedForSend,
- tcpip.ErrClosedForReceive: ErrClosedForReceive,
- tcpip.ErrWouldBlock: ErrWouldBlock,
- tcpip.ErrConnectionRefused: ErrConnectionRefused,
- tcpip.ErrTimeout: ErrTimeout,
- tcpip.ErrAborted: ErrAborted,
- tcpip.ErrConnectStarted: ErrConnectStarted,
- tcpip.ErrDestinationRequired: ErrDestinationRequired,
- tcpip.ErrNotSupported: ErrNotSupported,
- tcpip.ErrQueueSizeNotSupported: ErrQueueSizeNotSupported,
- tcpip.ErrNotConnected: ErrNotConnected,
- tcpip.ErrConnectionReset: ErrConnectionReset,
- tcpip.ErrConnectionAborted: ErrConnectionAborted,
- tcpip.ErrNoSuchFile: ErrNoSuchFile,
- tcpip.ErrInvalidOptionValue: ErrInvalidOptionValue,
- tcpip.ErrNoLinkAddress: ErrHostDown,
- tcpip.ErrBadAddress: ErrBadAddress,
- tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable,
- tcpip.ErrMessageTooLong: ErrMessageTooLong,
- tcpip.ErrNoBufferSpace: ErrNoBufferSpace,
- tcpip.ErrBroadcastDisabled: ErrBroadcastDisabled,
- tcpip.ErrNotPermitted: ErrNotPermittedNet,
- tcpip.ErrAddressFamilyNotSupported: ErrAddressFamilyNotSupported,
+var netstackErrorTranslations map[string]*Error
+
+func addErrMapping(tcpipErr *tcpip.Error, netstackErr *Error) {
+ key := tcpipErr.String()
+ if _, ok := netstackErrorTranslations[key]; ok {
+ panic(fmt.Sprintf("duplicate error key: %s", key))
+ }
+ netstackErrorTranslations[key] = netstackErr
+}
+
+func init() {
+ netstackErrorTranslations = make(map[string]*Error)
+ addErrMapping(tcpip.ErrUnknownProtocol, ErrUnknownProtocol)
+ addErrMapping(tcpip.ErrUnknownNICID, ErrUnknownNICID)
+ addErrMapping(tcpip.ErrUnknownDevice, ErrUnknownDevice)
+ addErrMapping(tcpip.ErrUnknownProtocolOption, ErrUnknownProtocolOption)
+ addErrMapping(tcpip.ErrDuplicateNICID, ErrDuplicateNICID)
+ addErrMapping(tcpip.ErrDuplicateAddress, ErrDuplicateAddress)
+ addErrMapping(tcpip.ErrNoRoute, ErrNoRoute)
+ addErrMapping(tcpip.ErrBadLinkEndpoint, ErrBadLinkEndpoint)
+ addErrMapping(tcpip.ErrAlreadyBound, ErrAlreadyBound)
+ addErrMapping(tcpip.ErrInvalidEndpointState, ErrInvalidEndpointState)
+ addErrMapping(tcpip.ErrAlreadyConnecting, ErrAlreadyConnecting)
+ addErrMapping(tcpip.ErrAlreadyConnected, ErrAlreadyConnected)
+ addErrMapping(tcpip.ErrNoPortAvailable, ErrNoPortAvailable)
+ addErrMapping(tcpip.ErrPortInUse, ErrPortInUse)
+ addErrMapping(tcpip.ErrBadLocalAddress, ErrBadLocalAddress)
+ addErrMapping(tcpip.ErrClosedForSend, ErrClosedForSend)
+ addErrMapping(tcpip.ErrClosedForReceive, ErrClosedForReceive)
+ addErrMapping(tcpip.ErrWouldBlock, ErrWouldBlock)
+ addErrMapping(tcpip.ErrConnectionRefused, ErrConnectionRefused)
+ addErrMapping(tcpip.ErrTimeout, ErrTimeout)
+ addErrMapping(tcpip.ErrAborted, ErrAborted)
+ addErrMapping(tcpip.ErrConnectStarted, ErrConnectStarted)
+ addErrMapping(tcpip.ErrDestinationRequired, ErrDestinationRequired)
+ addErrMapping(tcpip.ErrNotSupported, ErrNotSupported)
+ addErrMapping(tcpip.ErrQueueSizeNotSupported, ErrQueueSizeNotSupported)
+ addErrMapping(tcpip.ErrNotConnected, ErrNotConnected)
+ addErrMapping(tcpip.ErrConnectionReset, ErrConnectionReset)
+ addErrMapping(tcpip.ErrConnectionAborted, ErrConnectionAborted)
+ addErrMapping(tcpip.ErrNoSuchFile, ErrNoSuchFile)
+ addErrMapping(tcpip.ErrInvalidOptionValue, ErrInvalidOptionValue)
+ addErrMapping(tcpip.ErrNoLinkAddress, ErrHostDown)
+ addErrMapping(tcpip.ErrBadAddress, ErrBadAddress)
+ addErrMapping(tcpip.ErrNetworkUnreachable, ErrNetworkUnreachable)
+ addErrMapping(tcpip.ErrMessageTooLong, ErrMessageTooLong)
+ addErrMapping(tcpip.ErrNoBufferSpace, ErrNoBufferSpace)
+ addErrMapping(tcpip.ErrBroadcastDisabled, ErrBroadcastDisabled)
+ addErrMapping(tcpip.ErrNotPermitted, ErrNotPermittedNet)
+ addErrMapping(tcpip.ErrAddressFamilyNotSupported, ErrAddressFamilyNotSupported)
+ addErrMapping(tcpip.ErrBadBuffer, ErrBadBuffer)
}
// TranslateNetstackError converts an error from the tcpip package to a sentry
@@ -95,7 +110,7 @@ func TranslateNetstackError(err *tcpip.Error) *Error {
if err == nil {
return nil
}
- se, ok := netstackErrorTranslations[err]
+ se, ok := netstackErrorTranslations[err.String()]
if !ok {
panic("Unknown error: " + err.String())
}
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index 27f96a3ac..89b765f1b 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -1,10 +1,24 @@
load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
+go_template_instance(
+ name = "sock_err_list",
+ out = "sock_err_list.go",
+ package = "tcpip",
+ prefix = "sockError",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*SockError",
+ "Linker": "*SockError",
+ },
+)
+
go_library(
name = "tcpip",
srcs = [
+ "sock_err_list.go",
"socketops.go",
"tcpip.go",
"time_unsafe.go",
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 4f551cd92..7193f56ad 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -286,45 +286,47 @@ type opErrorer interface {
// commonRead implements the common logic between net.Conn.Read and
// net.PacketConn.ReadFrom.
-func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer, dontWait bool) ([]byte, error) {
+func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) (int, error) {
select {
case <-deadline:
- return nil, errorer.newOpError("read", &timeoutError{})
+ return 0, errorer.newOpError("read", &timeoutError{})
default:
}
- read, _, err := ep.Read(addr)
+ w := tcpip.SliceWriter(b)
+ opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil}
+ res, err := ep.Read(&w, len(b), opts)
if err == tcpip.ErrWouldBlock {
- if dontWait {
- return nil, errWouldBlock
- }
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
wq.EventRegister(&waitEntry, waiter.EventIn)
defer wq.EventUnregister(&waitEntry)
for {
- read, _, err = ep.Read(addr)
+ res, err = ep.Read(&w, len(b), opts)
if err != tcpip.ErrWouldBlock {
break
}
select {
case <-deadline:
- return nil, errorer.newOpError("read", &timeoutError{})
+ return 0, errorer.newOpError("read", &timeoutError{})
case <-notifyCh:
}
}
}
if err == tcpip.ErrClosedForReceive {
- return nil, io.EOF
+ return 0, io.EOF
}
if err != nil {
- return nil, errorer.newOpError("read", errors.New(err.String()))
+ return 0, errorer.newOpError("read", errors.New(err.String()))
}
- return read, nil
+ if addr != nil {
+ *addr = res.RemoteAddr
+ }
+ return res.Count, nil
}
// Read implements net.Conn.Read.
@@ -334,31 +336,11 @@ func (c *TCPConn) Read(b []byte) (int, error) {
deadline := c.readCancel()
- numRead := 0
- defer func() {
- if numRead != 0 {
- c.ep.ModerateRecvBuf(numRead)
- }
- }()
- for numRead != len(b) {
- if len(c.read) == 0 {
- var err error
- c.read, err = commonRead(c.ep, c.wq, deadline, nil, c, numRead != 0)
- if err != nil {
- if numRead != 0 {
- return numRead, nil
- }
- return numRead, err
- }
- }
- n := copy(b[numRead:], c.read)
- c.read.TrimFront(n)
- numRead += n
- if len(c.read) == 0 {
- c.read = nil
- }
+ n, err := commonRead(b, c.ep, c.wq, deadline, nil, c)
+ if n != 0 {
+ c.ep.ModerateRecvBuf(n)
}
- return numRead, nil
+ return n, err
}
// Write implements net.Conn.Write.
@@ -652,12 +634,11 @@ func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
deadline := c.readCancel()
var addr tcpip.FullAddress
- read, err := commonRead(c.ep, c.wq, deadline, &addr, c, false)
+ n, err := commonRead(b, c.ep, c.wq, deadline, &addr, c)
if err != nil {
return 0, nil, err
}
-
- return copy(b, read), fullToUDPAddr(addr), nil
+ return n, fullToUDPAddr(addr), nil
}
func (c *UDPConn) Write(b []byte) (int, error) {
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 8db70a700..5dd1b1b6b 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -105,18 +105,18 @@ func (vv *VectorisedView) TrimFront(count int) {
}
// Read implements io.Reader.
-func (vv *VectorisedView) Read(v View) (copied int, err error) {
- count := len(v)
+func (vv *VectorisedView) Read(b []byte) (copied int, err error) {
+ count := len(b)
for count > 0 && len(vv.views) > 0 {
if count < len(vv.views[0]) {
vv.size -= count
- copy(v[copied:], vv.views[0][:count])
+ copy(b[copied:], vv.views[0][:count])
vv.views[0].TrimFront(count)
copied += count
return copied, nil
}
count -= len(vv.views[0])
- copy(v[copied:], vv.views[0])
+ copy(b[copied:], vv.views[0])
copied += len(vv.views[0])
vv.removeFirst()
}
@@ -145,6 +145,35 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int
return copied
}
+// ReadTo reads up to count bytes from vv to dst. It also removes them from vv
+// unless peek is true.
+func (vv *VectorisedView) ReadTo(dst io.Writer, count int, peek bool) (int, error) {
+ var err error
+ done := 0
+ for _, v := range vv.Views() {
+ remaining := count - done
+ if remaining <= 0 {
+ break
+ }
+ if len(v) > remaining {
+ v = v[:remaining]
+ }
+
+ var n int
+ n, err = dst.Write(v)
+ if n > 0 {
+ done += n
+ }
+ if err != nil {
+ break
+ }
+ }
+ if !peek {
+ vv.TrimFront(done)
+ }
+ return done, err
+}
+
// CapLength irreversibly reduces the length of the vectorised view.
func (vv *VectorisedView) CapLength(length int) {
if length < 0 {
diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go
index 726e54de9..e0ef8a94d 100644
--- a/pkg/tcpip/buffer/view_test.go
+++ b/pkg/tcpip/buffer/view_test.go
@@ -235,14 +235,16 @@ func TestToClone(t *testing.T) {
}
}
-func TestVVReadToVV(t *testing.T) {
- testCases := []struct {
- comment string
- vv VectorisedView
- bytesToRead int
- wantBytes string
- leftVV VectorisedView
- }{
+type readToTestCases struct {
+ comment string
+ vv VectorisedView
+ bytesToRead int
+ wantBytes string
+ leftVV VectorisedView
+}
+
+func createReadToTestCases() []readToTestCases {
+ return []readToTestCases{
{
comment: "large VV, short read",
vv: vv(30, "012345678901234567890123456789"),
@@ -279,8 +281,10 @@ func TestVVReadToVV(t *testing.T) {
leftVV: vv(0, ""),
},
}
+}
- for _, tc := range testCases {
+func TestVVReadToVV(t *testing.T) {
+ for _, tc := range createReadToTestCases() {
t.Run(tc.comment, func(t *testing.T) {
var readTo VectorisedView
inSize := tc.vv.Size()
@@ -301,6 +305,52 @@ func TestVVReadToVV(t *testing.T) {
}
}
+func TestVVReadTo(t *testing.T) {
+ for _, tc := range createReadToTestCases() {
+ t.Run(tc.comment, func(t *testing.T) {
+ var dst bytes.Buffer
+ origSize := tc.vv.Size()
+ copied, err := tc.vv.ReadTo(&dst, tc.bytesToRead, false /* peek */)
+ if got, want := copied, len(tc.wantBytes); err != nil || got != want {
+ t.Errorf("got ReadTo(&dst, %d, false) = %d, %v; want %d, nil", tc.bytesToRead, got, err, want)
+ }
+ if got, want := string(dst.Bytes()), tc.wantBytes; got != want {
+ t.Errorf("got dst = %q, want %q", got, want)
+ }
+ if got, want := tc.vv.Size(), origSize-copied; got != want {
+ t.Errorf("got after-read tc.vv.Size() = %d, want %d", got, want)
+ }
+ if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want {
+ t.Errorf("got after-read data in tc.vv = %q, want %q", got, want)
+ }
+ })
+ }
+}
+
+func TestVVReadToPeek(t *testing.T) {
+ for _, tc := range createReadToTestCases() {
+ t.Run(tc.comment, func(t *testing.T) {
+ var dst bytes.Buffer
+ origSize := tc.vv.Size()
+ origData := string(tc.vv.ToView())
+ copied, err := tc.vv.ReadTo(&dst, tc.bytesToRead, true /* peek */)
+ if got, want := copied, len(tc.wantBytes); err != nil || got != want {
+ t.Errorf("got ReadTo(&dst, %d, true) = %d, %v; want %d, nil", tc.bytesToRead, got, err, want)
+ }
+ if got, want := string(dst.Bytes()), tc.wantBytes; got != want {
+ t.Errorf("got dst = %q, want %q", got, want)
+ }
+ // Peeking must not consume data: tc.vv is expected to be unchanged.
+ if got, want := tc.vv.Size(), origSize; got != want {
+ t.Errorf("got after-read tc.vv.Size() = %d, want %d", got, want)
+ }
+ if got, want := string(tc.vv.ToView()), origData; got != want {
+ t.Errorf("got after-read data in tc.vv = %q, want %q", got, want)
+ }
+ })
+ }
+}
+
func TestVVRead(t *testing.T) {
testCases := []struct {
comment string
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 13bb5a723..0ac2000ca 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -117,6 +117,10 @@ func TTL(ttl uint8) NetworkChecker {
v = ip.TTL()
case header.IPv6:
v = ip.HopLimit()
+ case *ipv6HeaderWithExtHdr:
+ v = ip.HopLimit()
+ default:
+ t.Fatalf("unrecognized header type %T for TTL evaluation", ip)
}
if v != ttl {
t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl)
@@ -217,6 +221,42 @@ func IPv4Options(want header.IPv4Options) NetworkChecker {
}
}
+// IPv4RouterAlert returns a checker that verifies an IPv4 packet carries a
+// Router Alert option whose bytes match the expected type, length and value.
+func IPv4RouterAlert() NetworkChecker {
+	return func(t *testing.T, h []header.Network) {
+		t.Helper()
+		ip, ok := h[0].(header.IPv4)
+		if !ok {
+			t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
+		}
+		optIter := ip.Options().MakeIterator()
+		for {
+			opt, done, err := optIter.Next()
+			if err != nil {
+				t.Fatalf("error acquiring next IPv4 option %s", err)
+			}
+			if done {
+				// Ran out of options without seeing a Router Alert.
+				t.Errorf("failed to find router alert option in %v", ip.Options())
+				return
+			}
+			if opt.Type() == header.IPv4OptionRouterAlertType {
+				// Expected wire form: type, length, then the two value bytes.
+				want := []byte{
+					byte(header.IPv4OptionRouterAlertType),
+					header.IPv4OptionRouterAlertLength,
+					header.IPv4OptionRouterAlertValue,
+					header.IPv4OptionRouterAlertValue,
+				}
+				if diff := cmp.Diff(want, opt.Contents()); diff != "" {
+					t.Errorf("router alert option mismatch (-want +got):\n%s", diff)
+				}
+				return
+			}
+		}
+	}
+}
+
// FragmentOffset creates a checker that checks the FragmentOffset field.
func FragmentOffset(offset uint16) NetworkChecker {
return func(t *testing.T, h []header.Network) {
@@ -285,6 +325,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
}
}
+// ReceiveOriginalDstAddr creates a checker that requires ControlMessages to
+// carry an OriginalDstAddress and compares it against want.
+func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker {
+	return func(t *testing.T, cm tcpip.ControlMessages) {
+		t.Helper()
+		// Without the presence flag the address field is not compared.
+		if !cm.HasOriginalDstAddress {
+			t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress)
+			return
+		}
+		mismatch := cmp.Diff(want, cm.OriginalDstAddress)
+		if mismatch != "" {
+			t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", mismatch)
+		}
+	}
+}
+
// TOS creates a checker that checks the TOS field.
func TOS(tos uint8, label uint32) NetworkChecker {
return func(t *testing.T, h []header.Network) {
@@ -1013,6 +1066,74 @@ func ICMPv6Payload(want []byte) TransportChecker {
}
}
+// MLD creates a checker that checks that the packet contains a valid MLD
+// message for type of msgType, with potentially additional checks specified by
+// checkers.
+//
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// MLD message as far as the size of the message (minSize) is concerned. The
+// values within the message are up to checkers to validate.
+func MLD(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
+	return func(t *testing.T, h []header.Network) {
+		t.Helper()
+
+		// Check normal ICMPv6 first.
+		ICMPv6(
+			ICMPv6Type(msgType),
+			ICMPv6Code(0))(t, h)
+
+		// The ICMPv6 message is the payload of the last network header.
+		last := h[len(h)-1]
+
+		icmp := header.ICMPv6(last.Payload())
+		// Abort before running checkers if the body is too short for them to
+		// parse safely.
+		if got := len(icmp.MessageBody()); got < minSize {
+			t.Fatalf("ICMPv6 MLD (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
+		}
+
+		for _, f := range checkers {
+			f(t, icmp)
+		}
+		// Escalate any checker failure into an immediate stop of the test.
+		if t.Failed() {
+			t.FailNow()
+		}
+	}
+}
+
+// MLDMaxRespDelay creates a checker for the Maximum Response Delay field of an
+// MLD message.
+//
+// The returned TransportChecker expects to be handed a header.ICMPv6 whose
+// message body is already known to be large enough to hold an MLD message.
+func MLDMaxRespDelay(want time.Duration) TransportChecker {
+	return func(t *testing.T, h header.Transport) {
+		t.Helper()
+
+		mld := header.MLD(h.(header.ICMPv6).MessageBody())
+
+		got := mld.MaximumResponseDelay()
+		if got != want {
+			t.Errorf("got %T.MaximumResponseDelay() = %s, want = %s", mld, got, want)
+		}
+	}
+}
+
+// MLDMulticastAddress creates a checker for the Multicast Address field of an
+// MLD message.
+//
+// The returned TransportChecker expects to be handed a header.ICMPv6 whose
+// message body is already known to be large enough to hold an MLD message.
+func MLDMulticastAddress(want tcpip.Address) TransportChecker {
+	return func(t *testing.T, h header.Transport) {
+		t.Helper()
+
+		mld := header.MLD(h.(header.ICMPv6).MessageBody())
+
+		got := mld.MulticastAddress()
+		if got != want {
+			t.Errorf("got %T.MulticastAddress() = %s, want = %s", mld, got, want)
+		}
+	}
+}
+
// NDP creates a checker that checks that the packet contains a valid NDP
// message for type of ty, with potentially additional checks specified by
// checkers.
@@ -1032,7 +1153,7 @@ func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) N
last := h[len(h)-1]
icmp := header.ICMPv6(last.Payload())
- if got := len(icmp.NDPPayload()); got < minSize {
+ if got := len(icmp.MessageBody()); got < minSize {
t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
}
@@ -1066,7 +1187,7 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
if got := ns.TargetAddress(); got != want {
t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want)
@@ -1095,7 +1216,7 @@ func NDPNATargetAddress(want tcpip.Address) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
if got := na.TargetAddress(); got != want {
t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want)
@@ -1113,7 +1234,7 @@ func NDPNASolicitedFlag(want bool) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
if got := na.SolicitedFlag(); got != want {
t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want)
@@ -1184,7 +1305,7 @@ func NDPNAOptions(opts []header.NDPOption) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
ndpOptions(t, na.Options(), opts)
}
}
@@ -1199,7 +1320,7 @@ func NDPNSOptions(opts []header.NDPOption) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ndpOptions(t, ns.Options(), opts)
}
}
@@ -1224,7 +1345,7 @@ func NDPRSOptions(opts []header.NDPOption) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- rs := header.NDPRouterSolicit(icmp.NDPPayload())
+ rs := header.NDPRouterSolicit(icmp.MessageBody())
ndpOptions(t, rs.Options(), opts)
}
}
@@ -1296,3 +1417,201 @@ func IGMPGroupAddress(want tcpip.Address) TransportChecker {
}
}
}
+
+// IPv6ExtHdrChecker is a function to check an extension header. It receives
+// the test handle and one header.IPv6PayloadHeader to validate.
+type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader)
+
+// IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers.
+//
+// It validates b as an IPv6 packet, walks its payload headers until the raw
+// payload header is reached, and then runs checkers against a network header
+// whose transport protocol and payload are taken from that raw payload header
+// rather than from the fixed IPv6 header.
+func IPv6WithExtHdr(t *testing.T, b []byte, checkers ...NetworkChecker) {
+	t.Helper()
+
+	ipv6 := header.IPv6(b)
+	if !ipv6.IsValid(len(b)) {
+		t.Error("not a valid IPv6 packet")
+		return
+	}
+
+	payloadIterator := header.MakeIPv6PayloadIterator(
+		header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()),
+		buffer.View(ipv6.Payload()).ToVectorisedView(),
+	)
+
+	// Skip over extension headers until the raw payload header shows up.
+	var rawPayloadHeader header.IPv6RawPayloadHeader
+	for {
+		h, done, err := payloadIterator.Next()
+		if err != nil {
+			t.Errorf("payloadIterator.Next(): %s", err)
+			return
+		}
+		if done {
+			// The iterator finished before yielding a raw payload header, so
+			// the value that was wanted for done is false.
+			t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done)
+			return
+		}
+		r, ok := h.(header.IPv6RawPayloadHeader)
+		if ok {
+			rawPayloadHeader = r
+			break
+		}
+	}
+
+	networkHeader := ipv6HeaderWithExtHdr{
+		IPv6:      ipv6,
+		transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier),
+		payload:   rawPayloadHeader.Buf.ToView(),
+	}
+
+	for _, checker := range checkers {
+		checker(t, []header.Network{&networkHeader})
+	}
+}
+
+// IPv6ExtHdr checks for the presence of extension headers.
+//
+// All the extension headers in headers will be checked exhaustively in the
+// order provided.
+//
+// The returned checker only accepts the *ipv6HeaderWithExtHdr network header
+// produced by IPv6WithExtHdr; any other header type is reported as an error.
+func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker {
+	return func(t *testing.T, h []header.Network) {
+		t.Helper()
+
+		extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr)
+		if !ok {
+			t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0])
+			return
+		}
+
+		// Iterate from the fixed IPv6 header's payload so the extension
+		// headers themselves are visited.
+		payloadIterator := header.MakeIPv6PayloadIterator(
+			header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()),
+			buffer.View(extHdrs.IPv6.Payload()).ToVectorisedView(),
+		)
+
+		// Pair each expected checker with the next header from the packet.
+		for _, check := range headers {
+			h, done, err := payloadIterator.Next()
+			if err != nil {
+				t.Errorf("payloadIterator.Next(): %s", err)
+				return
+			}
+			if done {
+				t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done)
+				return
+			}
+			check(t, h)
+		}
+		// Validate we consumed all headers.
+		//
+		// The next one over should be a raw payload and then iterator should
+		// terminate.
+		wantDone := false
+		for {
+			h, done, err := payloadIterator.Next()
+			if err != nil {
+				t.Errorf("payloadIterator.Next(): %s", err)
+				return
+			}
+			if done != wantDone {
+				t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone)
+				return
+			}
+			if done {
+				break
+			}
+			// An unexpected extra header is reported, and the scan keeps going
+			// until the raw payload header is found.
+			if _, ok := h.(header.IPv6RawPayloadHeader); !ok {
+				t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h)
+				continue
+			}
+			// The raw payload was seen; the following Next must report done.
+			wantDone = true
+		}
+	}
+}
+
+// Compile-time check that *ipv6HeaderWithExtHdr satisfies header.Network.
+var _ header.Network = (*ipv6HeaderWithExtHdr)(nil)
+
+// ipv6HeaderWithExtHdr provides a header.Network implementation that takes
+// extension headers into consideration, which is not the case with vanilla
+// header.IPv6.
+type ipv6HeaderWithExtHdr struct {
+	header.IPv6
+	// transport is the value returned by TransportProtocol, overriding the
+	// embedded header.IPv6 method.
+	transport tcpip.TransportProtocolNumber
+	// payload is the value returned by Payload, overriding the embedded
+	// header.IPv6 method.
+	payload []byte
+}
+
+// TransportProtocol implements header.Network.
+func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber {
+	return h.transport
+}
+
+// Payload implements header.Network.
+func (h *ipv6HeaderWithExtHdr) Payload() []byte {
+	return h.payload
+}
+
+// IPv6ExtHdrOptionChecker is a function to check an extension header option.
+// It receives the test handle and one header.IPv6ExtHdrOption to validate.
+type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption)
+
+// IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop
+// extension header and validates the containing options with checkers.
+//
+// checkers must exhaustively contain all the expected options: each checker is
+// applied to the next option in order, and any option left over after the last
+// checker is reported as an error.
+func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker {
+	return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) {
+		t.Helper()
+
+		hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr)
+		if !ok {
+			t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader)
+			return
+		}
+		optionsIterator := hbh.Iter()
+		for _, f := range checkers {
+			opt, done, err := optionsIterator.Next()
+			if err != nil {
+				t.Errorf("optionsIterator.Next(): %s", err)
+				return
+			}
+			if done {
+				t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done)
+				// There is no option left to hand to the remaining checkers;
+				// stop instead of invoking them on an invalid option.
+				return
+			}
+			f(t, opt)
+		}
+		// Validate all options were consumed.
+		for {
+			opt, done, err := optionsIterator.Next()
+			if err != nil {
+				t.Errorf("optionsIterator.Next(): %s", err)
+				return
+			}
+			if done {
+				break
+			}
+			// Every surplus option is reported individually.
+			t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done)
+		}
+	}
+}
+
+// IPv6RouterAlert validates that an extension header option is the RouterAlert
+// option and matches on its value.
+//
+// The option must be a *header.IPv6RouterAlertOption; any other type is
+// reported as an error and the value comparison is skipped.
+func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker {
+	return func(t *testing.T, opt header.IPv6ExtHdrOption) {
+		// Attribute failures to the caller, as the other checkers do.
+		t.Helper()
+
+		routerAlert, ok := opt.(*header.IPv6RouterAlertOption)
+		if !ok {
+			t.Errorf("unexpected extension header option, got = %T, want = *header.IPv6RouterAlertOption", opt)
+			return
+		}
+		if routerAlert.Value != want {
+			t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want)
+		}
+	}
+}
+
+// IgnoreCmpPath returns a cmp.Option that makes cmp skip every field whose
+// path string is one of paths.
+func IgnoreCmpPath(paths ...string) cmp.Option {
+	// Set of path strings to ignore, for constant-time membership tests.
+	ignored := make(map[string]struct{}, len(paths))
+	for _, p := range paths {
+		ignored[p] = struct{}{}
+	}
+	shouldIgnore := func(p cmp.Path) bool {
+		_, found := ignored[p.String()]
+		return found
+	}
+	return cmp.FilterPath(shouldIgnore, cmp.Ignore())
+}
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index 144093c3a..0bdc12d53 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -42,6 +42,7 @@ go_test(
srcs = [
"checksum_test.go",
"igmp_test.go",
+ "ipv4_test.go",
"ipv6_test.go",
"ipversion_test.go",
"tcp_test.go",
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
index 309403482..5ab20ee86 100644
--- a/pkg/tcpip/header/checksum_test.go
+++ b/pkg/tcpip/header/checksum_test.go
@@ -19,6 +19,7 @@ package header_test
import (
"fmt"
"math/rand"
+ "sync"
"testing"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -169,3 +170,96 @@ func BenchmarkChecksum(b *testing.B) {
}
}
}
+
+// testICMPChecksum calls headerChecksum and icmpChecksum concurrently from
+// several goroutines and reports an error for every result that differs from
+// want. pktStr describes the packet in failure messages.
+func testICMPChecksum(t *testing.T, headerChecksum func() uint16, icmpChecksum func() uint16, want uint16, pktStr string) {
+	// icmpChecksum should not do any modifications of the header to
+	// calculate its checksum. Let's call it from a few go-routines and the
+	// race detector will trigger a warning if there are any concurrent
+	// read/write accesses.
+
+	const concurrency = 5
+	// start is closed to release all goroutines at once; ready collects one
+	// signal per goroutine once it has been scheduled.
+	start := make(chan int)
+	ready := make(chan bool, concurrency)
+	var wg sync.WaitGroup
+	wg.Add(concurrency)
+	// Do not return before every goroutine has finished its checks.
+	defer wg.Wait()
+
+	for i := 0; i < concurrency; i++ {
+		go func() {
+			defer wg.Done()
+
+			// Announce readiness, then block until start is closed.
+			ready <- true
+			<-start
+
+			if got := headerChecksum(); want != got {
+				t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want)
+			}
+			if got := icmpChecksum(); want != got {
+				t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want)
+			}
+		}()
+	}
+	// Wait until all goroutines are parked on start before releasing them.
+	for i := 0; i < concurrency; i++ {
+		<-ready
+	}
+	close(start)
+}
+
+// TestICMPv4Checksum checks that header.ICMPv4Checksum over a header and a
+// multi-view payload matches a checksum computed with header.Checksum, and
+// that it can be called concurrently (see testICMPChecksum).
+func TestICMPv4Checksum(t *testing.T) {
+	// Fixed seed so the header and payload bytes are reproducible.
+	rnd := rand.New(rand.NewSource(42))
+
+	h := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize))
+	if _, err := rnd.Read(h); err != nil {
+		t.Fatalf("rnd.Read failed: %v", err)
+	}
+	// Clear the checksum field before computing the reference value.
+	h.SetChecksum(0)
+
+	buf := make([]byte, 13)
+	if _, err := rnd.Read(buf); err != nil {
+		t.Fatalf("rnd.Read failed: %v", err)
+	}
+	// Split the payload into two views of 5 and 8 bytes.
+	vv := buffer.NewVectorisedView(len(buf), []buffer.View{
+		buffer.NewViewFromBytes(buf[:5]),
+		buffer.NewViewFromBytes(buf[5:]),
+	})
+
+	// Reference checksum: payload first, then the header, then complemented.
+	want := header.Checksum(vv.ToView(), 0)
+	want = ^header.Checksum(h, want)
+	h.SetChecksum(want)
+
+	testICMPChecksum(t, h.Checksum, func() uint16 {
+		return header.ICMPv4Checksum(h, vv)
+	}, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
+}
+
+// TestICMPv6Checksum checks that header.ICMPv6Checksum over a header, the
+// source/destination addresses and a multi-view payload matches a checksum
+// computed with header.PseudoHeaderChecksum and header.Checksum, and that it
+// can be called concurrently (see testICMPChecksum).
+func TestICMPv6Checksum(t *testing.T) {
+	// Fixed seed so the header and payload bytes are reproducible.
+	rnd := rand.New(rand.NewSource(42))
+
+	h := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize))
+	if _, err := rnd.Read(h); err != nil {
+		t.Fatalf("rnd.Read failed: %v", err)
+	}
+	// Clear the checksum field before computing the reference value.
+	h.SetChecksum(0)
+
+	buf := make([]byte, 13)
+	if _, err := rnd.Read(buf); err != nil {
+		t.Fatalf("rnd.Read failed: %v", err)
+	}
+	// Split the payload into three views of 7, 3 and 3 bytes.
+	vv := buffer.NewVectorisedView(len(buf), []buffer.View{
+		buffer.NewViewFromBytes(buf[:7]),
+		buffer.NewViewFromBytes(buf[7:10]),
+		buffer.NewViewFromBytes(buf[10:]),
+	})
+
+	dst := header.IPv6Loopback
+	src := header.IPv6Loopback
+
+	// Reference checksum: pseudo-header, then payload, then the header, then
+	// complemented.
+	want := header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size()))
+	want = header.Checksum(vv.ToView(), want)
+	want = ^header.Checksum(h, want)
+	h.SetChecksum(want)
+
+	testICMPChecksum(t, h.Checksum, func() uint16 {
+		return header.ICMPv6Checksum(h, src, dst, vv)
+	}, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
+}
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index 2f13dea6a..5f9b8e9e2 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -16,6 +16,7 @@ package header
import (
"encoding/binary"
+ "fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -199,17 +200,24 @@ func (b ICMPv4) SetSequence(sequence uint16) {
// ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header,
// and payload.
func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 {
- // Calculate the IPv6 pseudo-header upper-layer checksum.
- xsum := uint16(0)
- for _, v := range vv.Views() {
- xsum = Checksum(v, xsum)
- }
+ xsum := ChecksumVV(vv, 0)
+
+ // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum.
+ xsum = Checksum(h[:2], xsum)
+ xsum = Checksum(h[4:], xsum)
- // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
- h2, h3 := h[2], h[3]
- h[2], h[3] = 0, 0
- xsum = ^Checksum(h, xsum)
- h[2], h[3] = h2, h3
+ return ^xsum
+}
- return xsum
+// ICMPOriginFromNetProto returns the appropriate SockErrOrigin to use when a
+// packet with a `net` network header causes an ICMP error.
+func ICMPOriginFromNetProto(net tcpip.NetworkProtocolNumber) tcpip.SockErrOrigin {
+ switch net {
+ case IPv4ProtocolNumber:
+ return tcpip.SockExtErrorOriginICMP
+ case IPv6ProtocolNumber:
+ return tcpip.SockExtErrorOriginICMP6
+ default:
+ panic(fmt.Sprintf("unsupported net proto to extract ICMP error origin: %d", net))
+ }
}
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index 4303fc5d5..eca9750ab 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -115,6 +115,12 @@ const (
ICMPv6NeighborSolicit ICMPv6Type = 135
ICMPv6NeighborAdvert ICMPv6Type = 136
ICMPv6RedirectMsg ICMPv6Type = 137
+
+ // Multicast Listener Discovery (MLD) messages, see RFC 2710.
+
+ ICMPv6MulticastListenerQuery ICMPv6Type = 130
+ ICMPv6MulticastListenerReport ICMPv6Type = 131
+ ICMPv6MulticastListenerDone ICMPv6Type = 132
)
// IsErrorType returns true if the receiver is an ICMP error type.
@@ -245,10 +251,9 @@ func (b ICMPv6) SetSequence(sequence uint16) {
binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence)
}
-// NDPPayload returns the NDP payload buffer. That is, it returns the ICMPv6
-// packet's message body as defined by RFC 4443 section 2.1; the portion of the
-// ICMPv6 buffer after the first ICMPv6HeaderSize bytes.
-func (b ICMPv6) NDPPayload() []byte {
+// MessageBody returns the message body as defined by RFC 4443 section 2.1; the
+// portion of the ICMPv6 buffer after the first ICMPv6HeaderSize bytes.
+func (b ICMPv6) MessageBody() []byte {
return b[ICMPv6HeaderSize:]
}
@@ -260,22 +265,13 @@ func (b ICMPv6) Payload() []byte {
// ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header,
// IPv6 src/dst addresses and the payload.
func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 {
- // Calculate the IPv6 pseudo-header upper-layer checksum.
- xsum := Checksum([]byte(src), 0)
- xsum = Checksum([]byte(dst), xsum)
- var upperLayerLength [4]byte
- binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size()))
- xsum = Checksum(upperLayerLength[:], xsum)
- xsum = Checksum([]byte{0, 0, 0, uint8(ICMPv6ProtocolNumber)}, xsum)
- for _, v := range vv.Views() {
- xsum = Checksum(v, xsum)
- }
-
- // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
- h2, h3 := h[2], h[3]
- h[2], h[3] = 0, 0
- xsum = ^Checksum(h, xsum)
- h[2], h[3] = h2, h3
-
- return xsum
+ xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size()))
+
+ xsum = ChecksumVV(vv, xsum)
+
+ // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum.
+ xsum = Checksum(h[:2], xsum)
+ xsum = Checksum(h[4:], xsum)
+
+ return ^xsum
}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 5fddd2af6..e6103f4bc 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -100,7 +100,7 @@ type IPv4Fields struct {
//
// That leaves ten 32 bit (4 byte) fields for options. An attempt to encode
// more will fail.
- Options IPv4Options
+ Options IPv4OptionsSerializer
}
// IPv4 is an IPv4 header.
@@ -285,18 +285,17 @@ func (b IPv4) DestinationAddress() tcpip.Address {
return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
}
-// IPv4Options is a buffer that holds all the raw IP options.
-type IPv4Options []byte
-
-// SizeWithPadding implements stack.NetOptions.
-// It reports the size to allocate for the Options. RFC 791 page 23 (end of
-// section 3.1) says of the padding at the end of the options:
+// padIPv4OptionsLength returns the total length for IPv4 options of length l
+// after applying padding according to RFC 791:
// The internet header padding is used to ensure that the internet
// header ends on a 32 bit boundary.
-func (o IPv4Options) SizeWithPadding() int {
- return (len(o) + IPv4IHLStride - 1) & ^(IPv4IHLStride - 1)
+func padIPv4OptionsLength(length uint8) uint8 {
+ return (length + IPv4IHLStride - 1) & ^uint8(IPv4IHLStride-1)
}
+// IPv4Options is a buffer that holds all the raw IP options.
+type IPv4Options []byte
+
// Options returns a buffer holding the options.
func (b IPv4) Options() IPv4Options {
hdrLen := b.HeaderLength()
@@ -375,26 +374,16 @@ func (b IPv4) CalculateChecksum() uint16 {
func (b IPv4) Encode(i *IPv4Fields) {
// The size of the options defines the size of the whole header and thus the
// IHL field. Options are rare and this is a heavily used function so it is
- // worth a bit of optimisation here to keep the copy out of the fast path.
- hdrLen := IPv4MinimumSize
+ // worth a bit of optimisation here to keep the serializer out of the fast
+ // path.
+ hdrLen := uint8(IPv4MinimumSize)
if len(i.Options) != 0 {
- // SizeWithPadding is always >= len(i.Options).
- aLen := i.Options.SizeWithPadding()
- hdrLen += aLen
- if hdrLen > len(b) {
- panic(fmt.Sprintf("encode received %d bytes, wanted >= %d", len(b), hdrLen))
- }
- opts := b[options:]
- // This avoids bounds checks on the next line(s) which would happen even
- // if there's no work to do.
- if n := copy(opts, i.Options); n != aLen {
- padding := opts[n:][:aLen-n]
- for i := range padding {
- padding[i] = 0
- }
- }
+ hdrLen += i.Options.Serialize(b[options:])
}
- b.SetHeaderLength(uint8(hdrLen))
+ if hdrLen > IPv4MaximumHeaderSize {
+ panic(fmt.Sprintf("%d is larger than maximum IPv4 header size of %d", hdrLen, IPv4MaximumHeaderSize))
+ }
+ b.SetHeaderLength(hdrLen)
b[tos] = i.TOS
b.SetTotalLength(i.TotalLength)
binary.BigEndian.PutUint16(b[id:], i.ID)
@@ -474,6 +463,10 @@ const (
// options and may appear multiple times.
IPv4OptionNOPType IPv4OptionType = 1
+ // IPv4OptionRouterAlertType is the option type for the Router Alert option,
+ // defined in RFC 2113 Section 2.1.
+ IPv4OptionRouterAlertType IPv4OptionType = 20 | 0x80
+
// IPv4OptionRecordRouteType is used by each router on the path of the packet
// to record its path. It is carried over to an Echo Reply.
IPv4OptionRecordRouteType IPv4OptionType = 7
@@ -874,3 +867,162 @@ func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) }
// Contents implements IPv4Option.
func (rr *IPv4OptionRecordRoute) Contents() []byte { return []byte(*rr) }
+
+// Constants specific to the Router Alert option.
+//
+// From RFC 2113 section 2.1:
+//
+//	+--------+--------+--------+--------+
+//	|10010100|00000100|  2 octet value  |
+//	+--------+--------+--------+--------+
+//
+// Type:
+//	Copied flag:  1 (all fragments must carry the option)
+//	Option class: 0 (control)
+//	Option number: 20 (decimal)
+//
+// Length: 4
+//
+// Value: A two octet code with the following values:
+//	0 - Router shall examine packet
+//	1-65535 - Reserved
+const (
+	// IPv4OptionRouterAlertLength is the length of a Router Alert option.
+	IPv4OptionRouterAlertLength = 4
+
+	// IPv4OptionRouterAlertValue is the only permissible value of the 16 bit
+	// payload of the router alert option.
+	IPv4OptionRouterAlertValue = 0
+
+	// iPv4OptionRouterAlertValueOffset is the offset for the value of a
+	// RouterAlert option, i.e. the number of bytes (type and length) that
+	// precede the value.
+	iPv4OptionRouterAlertValueOffset = 2
+)
+
+// IPv4SerializableOption is an interface to represent serializable IPv4 option
+// types.
+//
+// Its only method is unexported, so implementations must live in this package.
+type IPv4SerializableOption interface {
+	// optionType returns the type identifier of the option.
+	optionType() IPv4OptionType
+}
+
+// IPv4SerializableOptionPayload is an interface providing serialization of the
+// payload of an IPv4 option.
+type IPv4SerializableOptionPayload interface {
+	// length returns the size of the payload.
+	length() uint8
+
+	// serializeInto serializes the payload into the provided byte buffer.
+	//
+	// Note, the caller MUST provide a byte buffer with size of at least
+	// length(). Implementers of this function may assume that the byte buffer
+	// is of sufficient size. serializeInto MUST panic if the provided byte
+	// buffer is not of sufficient size.
+	//
+	// serializeInto will return the number of bytes that was used to
+	// serialize the receiver. Implementers must only use the number of
+	// bytes required to serialize the receiver. Callers MAY provide a
+	// larger buffer than required to serialize into.
+	serializeInto(buffer []byte) uint8
+}
+
+// IPv4OptionsSerializer is a serializer for IPv4 options. It is an ordered
+// list of the options to serialize.
+type IPv4OptionsSerializer []IPv4SerializableOption
+
+// Length returns the number of bytes needed to serialize every option in s,
+// rounded up by padIPv4OptionsLength.
+func (s IPv4OptionsSerializer) Length() uint8 {
+	var unpadded uint8
+	for i := range s {
+		// Every option starts with its one-byte type.
+		unpadded++
+		payload, hasPayload := s[i].(IPv4SerializableOptionPayload)
+		if !hasPayload {
+			continue
+		}
+		// One more byte for the length field, then the payload itself.
+		unpadded += 1 + payload.length()
+	}
+	return padIPv4OptionsLength(unpadded)
+}
+
+// Serialize serializes the provided list of IPv4 options into b.
+//
+// Note, b must be of sufficient size to hold all the options in s. See
+// IPv4OptionsSerializer.Length for details on getting the total size of a
+// serialized IPv4OptionsSerializer.
+//
+// Serialize panics if b is not of sufficient size to hold all the options in s.
+//
+// It returns the number of bytes written, including the trailing zero padding.
+func (s IPv4OptionsSerializer) Serialize(b []byte) uint8 {
+	// total counts the option bytes written so far, excluding padding.
+	var total uint8
+	for _, opt := range s {
+		ty := opt.optionType()
+		if withPayload, ok := opt.(IPv4SerializableOptionPayload); ok {
+			// Serialize first to reduce bounds checks.
+			// l covers the type byte, the length byte and the payload.
+			l := 2 + withPayload.serializeInto(b[2:])
+			b[0] = byte(ty)
+			b[1] = l
+			// Advance b past the option that was just written.
+			b = b[l:]
+			total += l
+			continue
+		}
+		// Options without payload consist only of the type field.
+		//
+		// NB: Repeating code from the branch above is intentional to minimize
+		// bounds checks.
+		b[0] = byte(ty)
+		b = b[1:]
+		total++
+	}
+
+	// According to RFC 791:
+	//
+	//  The internet header padding is used to ensure that the internet
+	//  header ends on a 32 bit boundary. The padding is zero.
+	padded := padIPv4OptionsLength(total)
+	// b now starts right after the last option; zero only the padding bytes.
+	b = b[:padded-total]
+	for i := range b {
+		b[i] = 0
+	}
+	return padded
+}
+
+// Compile-time checks that *IPv4SerializableRouterAlertOption implements both
+// serialization interfaces.
+var _ IPv4SerializableOptionPayload = (*IPv4SerializableRouterAlertOption)(nil)
+var _ IPv4SerializableOption = (*IPv4SerializableRouterAlertOption)(nil)
+
+// IPv4SerializableRouterAlertOption provides serialization of the Router Alert
+// IPv4 option according to RFC 2113.
+type IPv4SerializableRouterAlertOption struct{}
+
+// optionType implements IPv4SerializableOption.
+func (*IPv4SerializableRouterAlertOption) optionType() IPv4OptionType {
+	return IPv4OptionRouterAlertType
+}
+
+// length implements IPv4SerializableOptionPayload. It returns the size of the
+// option's value, i.e. the option length minus the bytes preceding the value.
+func (*IPv4SerializableRouterAlertOption) length() uint8 {
+	return IPv4OptionRouterAlertLength - iPv4OptionRouterAlertValueOffset
+}
+
+// serializeInto implements IPv4SerializableOptionPayload. It writes
+// IPv4OptionRouterAlertValue as a big-endian 16-bit value at the start of
+// buffer and returns the number of bytes used.
+func (o *IPv4SerializableRouterAlertOption) serializeInto(buffer []byte) uint8 {
+	binary.BigEndian.PutUint16(buffer, IPv4OptionRouterAlertValue)
+	return o.length()
+}
+
+// Compile-time check that *IPv4SerializableNOPOption implements
+// IPv4SerializableOption.
+var _ IPv4SerializableOption = (*IPv4SerializableNOPOption)(nil)
+
+// IPv4SerializableNOPOption provides serialization for the IPv4 no-op option.
+type IPv4SerializableNOPOption struct{}
+
+// optionType implements IPv4SerializableOption.
+func (*IPv4SerializableNOPOption) optionType() IPv4OptionType {
+	return IPv4OptionNOPType
+}
+
+// Compile-time check that *IPv4SerializableListEndOption implements
+// IPv4SerializableOption.
+var _ IPv4SerializableOption = (*IPv4SerializableListEndOption)(nil)
+
+// IPv4SerializableListEndOption provides serialization for the IPv4 List End
+// option.
+type IPv4SerializableListEndOption struct{}
+
+// optionType implements IPv4SerializableOption.
+func (*IPv4SerializableListEndOption) optionType() IPv4OptionType {
+	return IPv4OptionListEndType
+}
diff --git a/pkg/tcpip/header/ipv4_test.go b/pkg/tcpip/header/ipv4_test.go
new file mode 100644
index 000000000..6475cd694
--- /dev/null
+++ b/pkg/tcpip/header/ipv4_test.go
@@ -0,0 +1,179 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header_test
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// TestIPv4OptionsSerializer checks that IPv4OptionsSerializer reports the
+// padded length of its options and serializes them, including zero padding,
+// to the expected bytes.
+func TestIPv4OptionsSerializer(t *testing.T) {
+	optCases := []struct {
+		name   string
+		option []header.IPv4SerializableOption
+		expect []byte
+	}{
+		{
+			name: "NOP",
+			option: []header.IPv4SerializableOption{
+				&header.IPv4SerializableNOPOption{},
+			},
+			expect: []byte{1, 0, 0, 0},
+		},
+		{
+			name: "ListEnd",
+			option: []header.IPv4SerializableOption{
+				&header.IPv4SerializableListEndOption{},
+			},
+			expect: []byte{0, 0, 0, 0},
+		},
+		{
+			name: "RouterAlert",
+			option: []header.IPv4SerializableOption{
+				&header.IPv4SerializableRouterAlertOption{},
+			},
+			expect: []byte{148, 4, 0, 0},
+		}, {
+			name: "NOP and RouterAlert",
+			option: []header.IPv4SerializableOption{
+				&header.IPv4SerializableNOPOption{},
+				&header.IPv4SerializableRouterAlertOption{},
+			},
+			expect: []byte{1, 148, 4, 0, 0, 0, 0, 0},
+		},
+	}
+
+	for _, opt := range optCases {
+		t.Run(opt.name, func(t *testing.T) {
+			s := header.IPv4OptionsSerializer(opt.option)
+			l := s.Length()
+			// got is what Length reported; want is the expected encoded size.
+			if got, want := int(l), len(opt.expect); got != want {
+				t.Fatalf("s.Length() = %d, want = %d", got, want)
+			}
+			b := make([]byte, l)
+			for i := range b {
+				// Fill the buffer with full bytes to ensure padding is being set
+				// correctly.
+				b[i] = 0xFF
+			}
+			if serializedLength := s.Serialize(b); serializedLength != l {
+				t.Fatalf("s.Serialize(_) = %d, want %d", serializedLength, l)
+			}
+			if diff := cmp.Diff(opt.expect, b); diff != "" {
+				t.Errorf("mismatched serialized option (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+// TestIPv4EncodeOptions checks that ipv4.Encode correctly fills out the
+// requested fields when options are supplied: the header length must account
+// for the padded options and the padding bytes must be zeroed.
+func TestIPv4EncodeOptions(t *testing.T) {
+	tests := []struct {
+		name           string
+		numberOfNops   int
+		encodedOptions header.IPv4Options // reply should look like this
+		wantIHL        int
+	}{
+		{
+			name:    "valid no options",
+			wantIHL: header.IPv4MinimumSize,
+		},
+		{
+			name:           "one byte options",
+			numberOfNops:   1,
+			encodedOptions: header.IPv4Options{1, 0, 0, 0},
+			wantIHL:        header.IPv4MinimumSize + 4,
+		},
+		{
+			name:           "two byte options",
+			numberOfNops:   2,
+			encodedOptions: header.IPv4Options{1, 1, 0, 0},
+			wantIHL:        header.IPv4MinimumSize + 4,
+		},
+		{
+			name:           "three byte options",
+			numberOfNops:   3,
+			encodedOptions: header.IPv4Options{1, 1, 1, 0},
+			wantIHL:        header.IPv4MinimumSize + 4,
+		},
+		{
+			name:           "four byte options",
+			numberOfNops:   4,
+			encodedOptions: header.IPv4Options{1, 1, 1, 1},
+			wantIHL:        header.IPv4MinimumSize + 4,
+		},
+		{
+			name:           "five byte options",
+			numberOfNops:   5,
+			encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0},
+			wantIHL:        header.IPv4MinimumSize + 8,
+		},
+		{
+			name:         "thirty nine byte options",
+			numberOfNops: 39,
+			encodedOptions: header.IPv4Options{
+				1, 1, 1, 1, 1, 1, 1, 1,
+				1, 1, 1, 1, 1, 1, 1, 1,
+				1, 1, 1, 1, 1, 1, 1, 1,
+				1, 1, 1, 1, 1, 1, 1, 1,
+				1, 1, 1, 1, 1, 1, 1, 0,
+			},
+			wantIHL: header.IPv4MinimumSize + 40,
+		},
+	}
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			serializeOpts := header.IPv4OptionsSerializer(make([]header.IPv4SerializableOption, test.numberOfNops))
+			for i := range serializeOpts {
+				serializeOpts[i] = &header.IPv4SerializableNOPOption{}
+			}
+			paddedOptionLength := serializeOpts.Length()
+			ipHeaderLength := int(header.IPv4MinimumSize + paddedOptionLength)
+			if ipHeaderLength > header.IPv4MaximumHeaderSize {
+				t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
+			}
+			totalLen := uint16(ipHeaderLength)
+			hdr := buffer.NewPrependable(int(totalLen))
+			ip := header.IPv4(hdr.Prepend(ipHeaderLength))
+			// To check the padding works, poison the last byte of the options
+			// space whenever padding is required. Each NOP serializes to one
+			// byte, so padding exists exactly when the padded length differs
+			// from the number of NOPs.
+			if int(paddedOptionLength) != test.numberOfNops {
+				// Options() slices by header length, so set it temporarily to
+				// reach the options area, then clear it again before Encode.
+				ip.SetHeaderLength(uint8(ipHeaderLength))
+				ip.Options()[paddedOptionLength-1] = 0xff
+				ip.SetHeaderLength(0)
+			}
+			ip.Encode(&header.IPv4Fields{
+				Options: serializeOpts,
+			})
+			options := ip.Options()
+			wantOptions := test.encodedOptions
+			if got, want := int(ip.HeaderLength()), test.wantIHL; got != want {
+				t.Errorf("got IHL of %d, want %d", got, want)
+			}
+
+			// cmp.Diff does not consider nil slices equal to empty slices, but we do.
+			if len(wantOptions) == 0 && len(options) == 0 {
+				return
+			}
+
+			if diff := cmp.Diff(wantOptions, options); diff != "" {
+				t.Errorf("options mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 55d09355a..5580d6a78 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -18,7 +18,6 @@ import (
"crypto/sha256"
"encoding/binary"
"fmt"
- "strings"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -48,11 +47,13 @@ type IPv6Fields struct {
// FlowLabel is the "flow label" field of an IPv6 packet.
FlowLabel uint32
- // PayloadLength is the "payload length" field of an IPv6 packet.
+ // PayloadLength is the "payload length" field of an IPv6 packet, including
+ // the length of all extension headers.
PayloadLength uint16
- // NextHeader is the "next header" field of an IPv6 packet.
- NextHeader uint8
+ // TransportProtocol is the transport layer protocol number. Serialized in the
+ // last "next header" field of the IPv6 header + extension headers.
+ TransportProtocol tcpip.TransportProtocolNumber
// HopLimit is the "Hop Limit" field of an IPv6 packet.
HopLimit uint8
@@ -62,6 +63,9 @@ type IPv6Fields struct {
// DstAddr is the "destination ip address" of an IPv6 packet.
DstAddr tcpip.Address
+
+ // ExtensionHeaders are the extension headers following the IPv6 header.
+ ExtensionHeaders IPv6ExtHdrSerializer
}
// IPv6 represents an ipv6 header stored in a byte array.
@@ -148,13 +152,17 @@ const (
// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the
// catch-all or wildcard subnet. That is, all IPv6 addresses are considered to
// be contained within this subnet.
-var IPv6EmptySubnet = func() tcpip.Subnet {
- subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any))
- if err != nil {
- panic(err)
- }
- return subnet
-}()
+var IPv6EmptySubnet = tcpip.AddressWithPrefix{
+ Address: IPv6Any,
+ PrefixLen: 0,
+}.Subnet()
+
+// IPv4MappedIPv6Subnet is the prefix for an IPv4 mapped IPv6 address as defined
+// by RFC 4291 section 2.5.5.
+var IPv4MappedIPv6Subnet = tcpip.AddressWithPrefix{
+ Address: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00",
+ PrefixLen: 96,
+}.Subnet()
// IPv6LinkLocalPrefix is the prefix for IPv6 link-local addresses, as defined
// by RFC 4291 section 2.5.6.
@@ -253,12 +261,14 @@ func (IPv6) SetChecksum(uint16) {
// Encode encodes all the fields of the ipv6 header.
func (b IPv6) Encode(i *IPv6Fields) {
+ // extHdr is the part of b past the first IPv6MinimumSize bytes; the
+ // extension headers are serialized there.
+ extHdr := b[IPv6MinimumSize:]
b.SetTOS(i.TrafficClass, i.FlowLabel)
b.SetPayloadLength(i.PayloadLength)
- b[IPv6NextHeaderOffset] = i.NextHeader
b[hopLimit] = i.HopLimit
b.SetSourceAddress(i.SrcAddr)
b.SetDestinationAddress(i.DstAddr)
+ // Serialize returns the value for the fixed header's next header field:
+ // the first extension header's identifier, or the transport protocol when
+ // there are no extension headers. The serialized length is not needed here.
+ nextHeader, _ := i.ExtensionHeaders.Serialize(i.TransportProtocol, extHdr)
+ b[IPv6NextHeaderOffset] = nextHeader
}
// IsValid performs basic validation on the packet.
@@ -286,7 +296,7 @@ func IsV4MappedAddress(addr tcpip.Address) bool {
return false
}
- return strings.HasPrefix(string(addr), "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff")
+ return IPv4MappedIPv6Subnet.Contains(addr)
}
// IsV6MulticastAddress determines if the provided address is an IPv6
@@ -392,17 +402,6 @@ func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool {
return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope
}
-// IsV6UniqueLocalAddress determines if the provided address is an IPv6
-// unique-local address (within the prefix FC00::/7).
-func IsV6UniqueLocalAddress(addr tcpip.Address) bool {
- if len(addr) != IPv6AddressSize {
- return false
- }
- // According to RFC 4193 section 3.1, a unique local address has the prefix
- // FC00::/7.
- return (addr[0] & 0xfe) == 0xfc
-}
-
// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier
// (IID) to buf as outlined by RFC 7217 and returns the extended buffer.
//
@@ -449,9 +448,6 @@ const (
// LinkLocalScope indicates a link-local address.
LinkLocalScope IPv6AddressScope = iota
- // UniqueLocalScope indicates a unique-local address.
- UniqueLocalScope
-
// GlobalScope indicates a global address.
GlobalScope
)
@@ -469,9 +465,6 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) {
case IsV6LinkLocalAddress(addr):
return LinkLocalScope, nil
- case IsV6UniqueLocalAddress(addr):
- return UniqueLocalScope, nil
-
default:
return GlobalScope, nil
}
diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go
index 571eae233..f18981332 100644
--- a/pkg/tcpip/header/ipv6_extension_headers.go
+++ b/pkg/tcpip/header/ipv6_extension_headers.go
@@ -18,9 +18,12 @@ import (
"bufio"
"bytes"
"encoding/binary"
+ "errors"
"fmt"
"io"
+ "math"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -75,8 +78,8 @@ const (
// Fragment Offset field within an IPv6FragmentExtHdr.
ipv6FragmentExtHdrFragmentOffsetOffset = 0
- // ipv6FragmentExtHdrFragmentOffsetShift is the least significant bits to
- // discard from the Fragment Offset.
+ // ipv6FragmentExtHdrFragmentOffsetShift is the bit offset of the Fragment
+ // Offset field within an IPv6FragmentExtHdr.
ipv6FragmentExtHdrFragmentOffsetShift = 3
// ipv6FragmentExtHdrFlagsIdx is the index to the flags field within an
@@ -114,6 +117,37 @@ const (
IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8
)
+// padIPv6OptionsLength rounds length up to the 8-octet alignment that RFC 8200
+// section 4.2 requires of IPv6 options and returns the padded total.
+func padIPv6OptionsLength(length int) int {
+	const alignMask = ipv6ExtHdrLenBytesPerUnit - 1
+	return (length + alignMask) &^ alignMask
+}
+
+// padIPv6Option turns all of b into padding: nothing is written for an empty
+// b, a single Pad1 option is written for one byte, and a PadN option with a
+// zeroed payload is written for anything longer.
+func padIPv6Option(b []byte) {
+	if len(b) == 0 {
+		// No padding needed.
+		return
+	}
+	if len(b) == 1 {
+		// Pad with Pad1.
+		b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6Pad1ExtHdrOptionIdentifier)
+		return
+	}
+	// Pad with PadN. Its length field counts only the bytes after the type and
+	// length fields.
+	padding := b[ipv6ExtHdrOptionPayloadOffset:]
+	for i := range padding {
+		padding[i] = 0
+	}
+	b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6PadNExtHdrOptionIdentifier)
+	b[ipv6ExtHdrOptionLengthOffset] = uint8(len(padding))
+}
+
+// ipv6OptionsAlignmentPadding returns how many padding bytes must precede an
+// option placed at headerOffset so that it satisfies the alignment requirement
+// [align]n + alignOffset. align must be a power of 2.
+func ipv6OptionsAlignmentPadding(headerOffset int, align int, alignOffset int) int {
+	relative := headerOffset - alignOffset
+	// Round relative up to the next multiple of align.
+	aligned := (relative + align - 1) &^ (align - 1)
+	return aligned - relative
+}
+
// IPv6PayloadHeader is implemented by the various headers that can be found
// in an IPv6 payload.
//
@@ -206,29 +240,55 @@ type IPv6ExtHdrOption interface {
isIPv6ExtHdrOption()
}
-// IPv6ExtHdrOptionIndentifier is an IPv6 extension header option identifier.
-type IPv6ExtHdrOptionIndentifier uint8
+// IPv6ExtHdrOptionIdentifier is an IPv6 extension header option identifier.
+type IPv6ExtHdrOptionIdentifier uint8
const (
// ipv6Pad1ExtHdrOptionIdentifier is the identifier for a padding option that
// provides 1 byte padding, as outlined in RFC 8200 section 4.2.
- ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 0
+ ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 0
// ipv6PadNExtHdrOptionIdentifier is the identifier for a padding option that
// provides variable length byte padding, as outlined in RFC 8200 section 4.2.
- ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 1
+ ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 1
+
+ // ipv6RouterAlertHopByHopOptionIdentifier is the identifier for the Router
+ // Alert Hop by Hop option as defined in RFC 2711 section 2.1.
+ ipv6RouterAlertHopByHopOptionIdentifier IPv6ExtHdrOptionIdentifier = 5
+
+ // ipv6ExtHdrOptionTypeOffset is the option type offset in an extension header
+ // option as defined in RFC 8200 section 4.2.
+ ipv6ExtHdrOptionTypeOffset = 0
+
+ // ipv6ExtHdrOptionLengthOffset is the option length offset in an extension
+ // header option as defined in RFC 8200 section 4.2.
+ ipv6ExtHdrOptionLengthOffset = 1
+
+ // ipv6ExtHdrOptionPayloadOffset is the option payload offset in an extension
+ // header option as defined in RFC 8200 section 4.2.
+ ipv6ExtHdrOptionPayloadOffset = 2
)
+// ipv6UnknownActionFromIdentifier extracts the action bits from an extension
+// header option identifier; they tell a receiver what to do when it does not
+// recognize the option.
+func ipv6UnknownActionFromIdentifier(id IPv6ExtHdrOptionIdentifier) IPv6OptionUnknownAction {
+	actionBits := (id & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift
+	return IPv6OptionUnknownAction(actionBits)
+}
+
+// ErrMalformedIPv6ExtHdrOption indicates that an IPv6 extension header option
+// is malformed.
+var ErrMalformedIPv6ExtHdrOption = errors.New("malformed IPv6 extension header option")
+
// IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension
// header option that is unknown by the parsing utilities.
type IPv6UnknownExtHdrOption struct {
- Identifier IPv6ExtHdrOptionIndentifier
+ Identifier IPv6ExtHdrOptionIdentifier
Data []byte
}
// UnknownAction implements IPv6ExtHdrOption.UnknownAction.
func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction {
- return IPv6OptionUnknownAction((o.Identifier & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift)
+ return ipv6UnknownActionFromIdentifier(o.Identifier)
}
// isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption.
@@ -251,7 +311,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
// options buffer has been exhausted and we are done iterating.
return nil, true, nil
}
- id := IPv6ExtHdrOptionIndentifier(temp)
+ id := IPv6ExtHdrOptionIdentifier(temp)
// If the option identifier indicates the option is a Pad1 option, then we
// know the option does not have Length and Data fields. End processing of
@@ -294,6 +354,19 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err))
}
continue
+ case ipv6RouterAlertHopByHopOptionIdentifier:
+ var routerAlertValue [ipv6RouterAlertPayloadLength]byte
+ if n, err := io.ReadFull(&i.reader, routerAlertValue[:]); err != nil {
+ switch err {
+ case io.EOF, io.ErrUnexpectedEOF:
+ return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
+ default:
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for router alert option: %w", n, ipv6RouterAlertPayloadLength, err)
+ }
+ } else if n != int(length) {
+ return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
+ }
+ return &IPv6RouterAlertOption{Value: IPv6RouterAlertValue(binary.BigEndian.Uint16(routerAlertValue[:]))}, false, nil
default:
bytes := make([]byte, length)
if n, err := io.ReadFull(&i.reader, bytes); err != nil {
@@ -609,3 +682,248 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP
return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), bytes, nil
}
+
+// IPv6SerializableExtHdr provides serialization for IPv6 extension
+// headers.
+//
+// Implementations are chained together by IPv6ExtHdrSerializer, which passes
+// each header the identifier of the header that follows it.
+type IPv6SerializableExtHdr interface {
+ // identifier returns the assigned IPv6 header identifier for this extension
+ // header.
+ identifier() IPv6ExtensionHeaderIdentifier
+
+ // length returns the total serialized length in bytes of this extension
+ // header, including the common next header and length fields.
+ length() int
+
+ // serializeInto serializes the receiver into the provided byte
+ // buffer, using the provided nextHeader value.
+ //
+ // Note, the caller MUST provide a byte buffer of at least length() bytes.
+ // Implementers of this function may assume that the byte buffer
+ // is of sufficient size. serializeInto MAY panic if the provided byte
+ // buffer is not of sufficient size.
+ //
+ // serializeInto returns the number of bytes that were used to serialize the
+ // receiver. Implementers must only use the number of bytes required to
+ // serialize the receiver. Callers MAY provide a larger buffer than required
+ // to serialize into.
+ serializeInto(nextHeader uint8, b []byte) int
+}
+
+var _ IPv6SerializableExtHdr = (*IPv6SerializableHopByHopExtHdr)(nil)
+
+// IPv6SerializableHopByHopExtHdr implements serialization of the Hop by Hop
+// options extension header. It is the ordered list of options to serialize.
+type IPv6SerializableHopByHopExtHdr []IPv6SerializableHopByHopOption
+
+const (
+ // ipv6HopByHopExtHdrNextHeaderOffset is the offset of the next header field
+ // in a hop by hop extension header as defined in RFC 8200 section 4.3.
+ ipv6HopByHopExtHdrNextHeaderOffset = 0
+
+ // ipv6HopByHopExtHdrLengthOffset is the offset of the length field in a hop
+ // by hop extension header as defined in RFC 8200 section 4.3.
+ ipv6HopByHopExtHdrLengthOffset = 1
+
+ // ipv6HopByHopExtHdrOptionsOffset is the offset of the options in a hop by
+ // hop extension header as defined in RFC 8200 section 4.3.
+ ipv6HopByHopExtHdrOptionsOffset = 2
+
+ // ipv6HopByHopExtHdrUnaccountedLenWords is the implicit number of 8-octet
+ // words in a hop by hop extension header's length field, as stated in RFC
+ // 8200 section 4.3:
+ // Length of the Hop-by-Hop Options header in 8-octet units,
+ // not including the first 8 octets.
+ ipv6HopByHopExtHdrUnaccountedLenWords = 1
+)
+
+// identifier implements IPv6SerializableExtHdr.
+//
+// It always reports the Hop by Hop options extension header identifier.
+func (IPv6SerializableHopByHopExtHdr) identifier() IPv6ExtensionHeaderIdentifier {
+ return IPv6HopByHopOptionsExtHdrIdentifier
+}
+
+// length implements IPv6SerializableExtHdr.
+//
+// It returns the number of bytes serializeInto writes for h: the next header
+// and length fields, every option with its type and length bytes, the padding
+// each option's alignment requirement calls for, and the trailing padding up
+// to an 8-octet boundary.
+func (h IPv6SerializableHopByHopExtHdr) length() int {
+	// Alignment requirements are relative to the start of the extension header,
+	// so the running offset must already include the next header and length
+	// fields, exactly as serializeInto does. Counting from the start of the
+	// options instead under-reports the size for alignments larger than 2.
+	total := ipv6HopByHopExtHdrOptionsOffset
+	for _, opt := range h {
+		align, alignOffset := opt.alignment()
+		total += ipv6OptionsAlignmentPadding(total, align, alignOffset)
+		total += ipv6ExtHdrOptionPayloadOffset + int(opt.length())
+	}
+	// Pad the whole header out to the 8-octet boundary.
+	return padIPv6OptionsLength(total)
+}
+
+// serializeInto implements IPv6SerializableExtHdr.
+//
+// It writes the next header and length fields followed by every option in h,
+// inserting Pad1/PadN options wherever an option's alignment requirement or
+// the final 8-octet boundary calls for them, and returns the number of bytes
+// written. b must be large enough for the serialized header (see length). It
+// panics if the header is too large to be described by the 8-bit length field.
+func (h IPv6SerializableHopByHopExtHdr) serializeInto(nextHeader uint8, b []byte) int {
+	optBuffer := b[ipv6HopByHopExtHdrOptionsOffset:]
+	totalLength := ipv6HopByHopExtHdrOptionsOffset
+	for _, opt := range h {
+		// Calculate alignment requirements and pad buffer if necessary.
+		align, alignOffset := opt.alignment()
+		padLen := ipv6OptionsAlignmentPadding(totalLength, align, alignOffset)
+		if padLen != 0 {
+			padIPv6Option(optBuffer[:padLen])
+			totalLength += padLen
+			optBuffer = optBuffer[padLen:]
+		}
+
+		payloadLen := opt.serializeInto(optBuffer[ipv6ExtHdrOptionPayloadOffset:])
+		optBuffer[ipv6ExtHdrOptionTypeOffset] = uint8(opt.identifier())
+		optBuffer[ipv6ExtHdrOptionLengthOffset] = payloadLen
+		// Widen to int before adding the type and length bytes: a payload of
+		// 254 or 255 bytes would otherwise wrap around in uint8 arithmetic and
+		// corrupt both the running length and the buffer cursor.
+		optLen := ipv6ExtHdrOptionPayloadOffset + int(payloadLen)
+		totalLength += optLen
+		optBuffer = optBuffer[optLen:]
+	}
+	// Pad the end of the header to the 8-octet boundary.
+	padded := padIPv6OptionsLength(totalLength)
+	if padded != totalLength {
+		padIPv6Option(optBuffer[:padded-totalLength])
+		totalLength = padded
+	}
+	// The length field counts 8-octet units and excludes the first unit.
+	wordsLen := totalLength/ipv6ExtHdrLenBytesPerUnit - ipv6HopByHopExtHdrUnaccountedLenWords
+	if wordsLen > math.MaxUint8 {
+		panic(fmt.Sprintf("IPv6 hop by hop options too large: %d+1 64-bit words", wordsLen))
+	}
+	b[ipv6HopByHopExtHdrNextHeaderOffset] = nextHeader
+	b[ipv6HopByHopExtHdrLengthOffset] = uint8(wordsLen)
+	return totalLength
+}
+
+// IPv6SerializableHopByHopOption provides serialization for hop by hop options.
+//
+// Implementations only produce the option payload; the option type and length
+// bytes and any alignment padding are written by
+// IPv6SerializableHopByHopExtHdr.
+type IPv6SerializableHopByHopOption interface {
+ // identifier returns the option identifier of this Hop by Hop option.
+ identifier() IPv6ExtHdrOptionIdentifier
+
+ // length returns the *payload* size of the option (not considering the type
+ // and length fields).
+ length() uint8
+
+ // alignment returns the alignment requirements of this option.
+ //
+ // Alignment requirements take the form [align]n + offset as specified in
+ // RFC 8200 section 4.2. The alignment requirement is on the offset between
+ // the option type byte and the start of the hop by hop header.
+ //
+ // align must be a power of 2.
+ alignment() (align int, offset int)
+
+ // serializeInto serializes the receiver into the provided byte
+ // buffer.
+ //
+ // Note, the caller MUST provide a byte buffer of at least length() bytes.
+ // Implementers of this function may assume that the byte buffer
+ // is of sufficient size. serializeInto MAY panic if the provided byte
+ // buffer is not of sufficient size.
+ //
+ // serializeInto will return the number of bytes that were used to
+ // serialize the receiver. Implementers must only use the number of
+ // bytes required to serialize the receiver. Callers MAY provide a
+ // larger buffer than required to serialize into.
+ serializeInto([]byte) uint8
+}
+
+var _ IPv6SerializableHopByHopOption = (*IPv6RouterAlertOption)(nil)
+
+// IPv6RouterAlertOption is the IPv6 Router alert Hop by Hop option defined in
+// RFC 2711 section 2.1.
+//
+// It can be serialized as a hop by hop option and is also the type produced
+// when a router alert option is parsed.
+type IPv6RouterAlertOption struct {
+ Value IPv6RouterAlertValue
+}
+
+// IPv6RouterAlertValue is the payload of an IPv6 Router Alert option.
+type IPv6RouterAlertValue uint16
+
+const (
+ // IPv6RouterAlertMLD indicates a datagram containing a Multicast Listener
+ // Discovery message as defined in RFC 2711 section 2.1.
+ IPv6RouterAlertMLD IPv6RouterAlertValue = 0
+ // IPv6RouterAlertRSVP indicates a datagram containing an RSVP message as
+ // defined in RFC 2711 section 2.1.
+ IPv6RouterAlertRSVP IPv6RouterAlertValue = 1
+ // IPv6RouterAlertActiveNetworks indicates a datagram containing an Active
+ // Networks message as defined in RFC 2711 section 2.1.
+ IPv6RouterAlertActiveNetworks IPv6RouterAlertValue = 2
+
+ // ipv6RouterAlertPayloadLength is the length in bytes of the Router Alert
+ // payload as defined in RFC 2711.
+ ipv6RouterAlertPayloadLength = 2
+
+ // ipv6RouterAlertAlignmentRequirement is the alignment requirement for the
+ // Router Alert option defined as 2n+0 in RFC 2711 section 2.1.
+ ipv6RouterAlertAlignmentRequirement = 2
+
+ // ipv6RouterAlertAlignmentOffsetRequirement is the alignment offset
+ // requirement for the Router Alert option defined as 2n+0 in RFC 2711 section
+ // 2.1.
+ ipv6RouterAlertAlignmentOffsetRequirement = 0
+)
+
+// UnknownAction implements IPv6ExtHdrOption.
+//
+// The action is derived from the router alert option identifier itself.
+func (*IPv6RouterAlertOption) UnknownAction() IPv6OptionUnknownAction {
+ return ipv6UnknownActionFromIdentifier(ipv6RouterAlertHopByHopOptionIdentifier)
+}
+
+// isIPv6ExtHdrOption implements IPv6ExtHdrOption.
+func (*IPv6RouterAlertOption) isIPv6ExtHdrOption() {}
+
+// identifier implements IPv6SerializableHopByHopOption.
+func (*IPv6RouterAlertOption) identifier() IPv6ExtHdrOptionIdentifier {
+ return ipv6RouterAlertHopByHopOptionIdentifier
+}
+
+// length implements IPv6SerializableHopByHopOption.
+//
+// The router alert payload has a fixed size.
+func (*IPv6RouterAlertOption) length() uint8 {
+ return ipv6RouterAlertPayloadLength
+}
+
+// alignment implements IPv6SerializableHopByHopOption.
+func (*IPv6RouterAlertOption) alignment() (int, int) {
+ // From RFC 2711 section 2.1:
+ // Alignment requirement: 2n+0.
+ return ipv6RouterAlertAlignmentRequirement, ipv6RouterAlertAlignmentOffsetRequirement
+}
+
+// serializeInto implements IPv6SerializableHopByHopOption.
+//
+// The value is written in network (big endian) byte order.
+func (o *IPv6RouterAlertOption) serializeInto(b []byte) uint8 {
+ binary.BigEndian.PutUint16(b, uint16(o.Value))
+ return ipv6RouterAlertPayloadLength
+}
+
+// IPv6ExtHdrSerializer provides serialization of IPv6 extension headers.
+type IPv6ExtHdrSerializer []IPv6SerializableExtHdr
+
+// Serialize writes the extension headers in s into b, one after the other.
+//
+// Each header's Next Header field carries the identifier of the header that
+// follows it; the last header carries transportProtocol.
+//
+// Note, b must be of sufficient size to hold all the headers in s. See
+// IPv6ExtHdrSerializer.Length for details on getting the total size of a
+// serialized IPv6ExtHdrSerializer. Serialize may panic if b is too small.
+//
+// Serialize returns the header identifier of the first serialized extension
+// header, or transportProtocol itself when s is empty, together with the total
+// serialized length.
+func (s IPv6ExtHdrSerializer) Serialize(transportProtocol tcpip.TransportProtocolNumber, b []byte) (uint8, int) {
+	nextHeader := uint8(transportProtocol)
+	if len(s) == 0 {
+		return nextHeader, 0
+	}
+	final := len(s) - 1
+	written := 0
+	// Every header but the last points at its successor.
+	for idx := 0; idx < final; idx++ {
+		n := s[idx].serializeInto(uint8(s[idx+1].identifier()), b)
+		written += n
+		b = b[n:]
+	}
+	// The last header points at the transport protocol.
+	written += s[final].serializeInto(nextHeader, b)
+	return uint8(s[0].identifier()), written
+}
+
+// Length returns the number of bytes needed to serialize every extension
+// header in s, i.e. the sum of the individual header lengths.
+func (s IPv6ExtHdrSerializer) Length() int {
+	sum := 0
+	for idx := range s {
+		sum += s[idx].length()
+	}
+	return sum
+}
diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go
index ab20c5f37..65adc6250 100644
--- a/pkg/tcpip/header/ipv6_extension_headers_test.go
+++ b/pkg/tcpip/header/ipv6_extension_headers_test.go
@@ -21,6 +21,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -59,7 +60,7 @@ func (a IPv6DestinationOptionsExtHdr) Equal(b IPv6DestinationOptionsExtHdr) bool
func TestIPv6UnknownExtHdrOption(t *testing.T) {
tests := []struct {
name string
- identifier IPv6ExtHdrOptionIndentifier
+ identifier IPv6ExtHdrOptionIdentifier
expectedUnknownAction IPv6OptionUnknownAction
}{
{
@@ -211,6 +212,31 @@ func TestIPv6OptionsExtHdrIterErr(t *testing.T) {
bytes: []byte{1, 3},
err: io.ErrUnexpectedEOF,
},
+ {
+ name: "Router alert without data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 0},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with partial data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with partial data and Pad1",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1, 0},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with extra data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 3, 1, 2, 3},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with missing data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1},
+ err: io.ErrUnexpectedEOF,
+ },
}
check := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expectedErr error) {
@@ -990,3 +1016,331 @@ func TestIPv6ExtHdrIter(t *testing.T) {
})
}
}
+
+var _ IPv6SerializableHopByHopOption = (*dummyHbHOptionSerializer)(nil)
+
+// dummyHbHOptionSerializer is a test-only IPv6SerializableHopByHopOption whose
+// identifier, payload and alignment are all configurable.
+type dummyHbHOptionSerializer struct {
+	id          IPv6ExtHdrOptionIdentifier
+	payload     []byte
+	align       int
+	alignOffset int
+}
+
+// identifier implements IPv6SerializableHopByHopOption.
+func (s *dummyHbHOptionSerializer) identifier() IPv6ExtHdrOptionIdentifier {
+	return s.id
+}
+
+// length implements IPv6SerializableHopByHopOption.
+func (s *dummyHbHOptionSerializer) length() uint8 {
+	payloadLen := len(s.payload)
+	return uint8(payloadLen)
+}
+
+// alignment implements IPv6SerializableHopByHopOption.
+//
+// A zero align is treated as "no alignment requirement" and reported as 1.
+func (s *dummyHbHOptionSerializer) alignment() (int, int) {
+	if s.align == 0 {
+		return 1, s.alignOffset
+	}
+	return s.align, s.alignOffset
+}
+
+// serializeInto implements IPv6SerializableHopByHopOption.
+func (s *dummyHbHOptionSerializer) serializeInto(b []byte) uint8 {
+	copied := copy(b, s.payload)
+	return uint8(copied)
+}
+
+// TestIPv6HopByHopSerializer serializes hop by hop headers built from the
+// options of each test case and checks the reported length, the exact bytes
+// produced (including padding) and that parsing the result yields the same
+// options back.
+func TestIPv6HopByHopSerializer(t *testing.T) {
+ // validateDummies checks that a parsed unknown option carries the identifier
+ // and payload of the dummy option it was serialized from.
+ validateDummies := func(t *testing.T, serializable IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) {
+ t.Helper()
+ dummy, ok := serializable.(*dummyHbHOptionSerializer)
+ if !ok {
+ t.Fatalf("got serializable = %T, want = *dummyHbHOptionSerializer", serializable)
+ }
+ unknown, ok := deserialized.(*IPv6UnknownExtHdrOption)
+ if !ok {
+ t.Fatalf("got deserialized = %T, want = %T", deserialized, &IPv6UnknownExtHdrOption{})
+ }
+ if dummy.id != unknown.Identifier {
+ t.Errorf("got deserialized identifier = %d, want = %d", unknown.Identifier, dummy.id)
+ }
+ if diff := cmp.Diff(dummy.payload, unknown.Data); diff != "" {
+ t.Errorf("option payload deserialization mismatch (-want +got):\n%s", diff)
+ }
+ }
+ tests := []struct {
+ name string
+ nextHeader uint8
+ options []IPv6SerializableHopByHopOption
+ expect []byte
+ validate func(*testing.T, IPv6SerializableHopByHopOption, IPv6ExtHdrOption)
+ }{
+ {
+ name: "single option",
+ nextHeader: 13,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 15,
+ payload: []byte{9, 8, 7, 6},
+ },
+ },
+ expect: []byte{13, 0, 15, 4, 9, 8, 7, 6},
+ validate: validateDummies,
+ },
+ {
+ name: "short option padN zero",
+ nextHeader: 88,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5},
+ },
+ },
+ expect: []byte{88, 0, 22, 2, 4, 5, 1, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "short option pad1",
+ nextHeader: 11,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 33,
+ payload: []byte{1, 2, 3},
+ },
+ },
+ expect: []byte{11, 0, 33, 3, 1, 2, 3, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "long option padN",
+ nextHeader: 55,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 77,
+ payload: []byte{1, 2, 3, 4, 5, 6, 7, 8},
+ },
+ },
+ expect: []byte{55, 1, 77, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 0, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "two options",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 11,
+ payload: []byte{1, 2, 3},
+ },
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5, 6},
+ },
+ },
+ expect: []byte{33, 1, 11, 3, 1, 2, 3, 22, 3, 4, 5, 6, 1, 2, 0, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "two options align 2n",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 11,
+ payload: []byte{1, 2, 3},
+ },
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5, 6},
+ align: 2,
+ },
+ },
+ expect: []byte{33, 1, 11, 3, 1, 2, 3, 0, 22, 3, 4, 5, 6, 1, 1, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "two options align 8n+1",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 11,
+ payload: []byte{1, 2},
+ },
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5, 6},
+ align: 8,
+ alignOffset: 1,
+ },
+ },
+ expect: []byte{33, 1, 11, 2, 1, 2, 1, 1, 0, 22, 3, 4, 5, 6, 1, 0},
+ validate: validateDummies,
+ },
+ {
+ // With no options the header is only padding, so validate is never
+ // called and is left nil.
+ name: "no options",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{},
+ expect: []byte{33, 0, 1, 4, 0, 0, 0, 0},
+ },
+ {
+ name: "Router Alert",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{&IPv6RouterAlertOption{Value: IPv6RouterAlertMLD}},
+ expect: []byte{33, 0, 5, 2, 0, 0, 1, 0},
+ validate: func(t *testing.T, _ IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) {
+ t.Helper()
+ routerAlert, ok := deserialized.(*IPv6RouterAlertOption)
+ if !ok {
+ t.Fatalf("got deserialized = %T, want = *IPv6RouterAlertOption", deserialized)
+ }
+ if routerAlert.Value != IPv6RouterAlertMLD {
+ t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, IPv6RouterAlertMLD)
+ }
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := IPv6SerializableHopByHopExtHdr(test.options)
+ length := s.length()
+ if length != len(test.expect) {
+ t.Fatalf("got s.length() = %d, want = %d", length, len(test.expect))
+ }
+ b := make([]byte, length)
+ for i := range b {
+ // Fill the buffer with ones to ensure all padding is correctly set.
+ b[i] = 0xFF
+ }
+ if got := s.serializeInto(test.nextHeader, b); got != length {
+ t.Fatalf("got s.serializeInto(..) = %d, want = %d", got, length)
+ }
+ if diff := cmp.Diff(test.expect, b); diff != "" {
+ t.Fatalf("serialization mismatch (-want +got):\n%s", diff)
+ }
+
+ // Deserialize the options and verify them.
+ // The header's length field excludes the first 8 octets, so add them back
+ // to find where the options end.
+ optLen := (b[ipv6HopByHopExtHdrLengthOffset] + ipv6HopByHopExtHdrUnaccountedLenWords) * ipv6ExtHdrLenBytesPerUnit
+ iter := ipv6OptionsExtHdr(b[ipv6HopByHopExtHdrOptionsOffset:optLen]).Iter()
+ for _, testOpt := range test.options {
+ opt, done, err := iter.Next()
+ if err != nil {
+ t.Fatalf("iter.Next(): %s", err)
+ }
+ if done {
+ t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, false, _)", opt, done)
+ }
+ test.validate(t, testOpt, opt)
+ }
+ // Once every expected option has been seen the iterator must be done.
+ opt, done, err := iter.Next()
+ if err != nil {
+ t.Fatalf("iter.Next(): %s", err)
+ }
+ if !done {
+ t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, true, _)", opt, done)
+ }
+ })
+ }
+}
+
+var _ IPv6SerializableExtHdr = (*dummyIPv6ExtHdrSerializer)(nil)
+
+// dummyIPv6ExtHdrSerializer is a test-only IPv6SerializableExtHdr with a
+// configurable identifier and body.
+//
+// Its serialized form is the nextHeader value in the first byte followed by
+// headerContents.
+type dummyIPv6ExtHdrSerializer struct {
+	id             IPv6ExtensionHeaderIdentifier
+	headerContents []byte
+}
+
+// identifier implements IPv6SerializableExtHdr.
+func (s *dummyIPv6ExtHdrSerializer) identifier() IPv6ExtensionHeaderIdentifier {
+	return s.id
+}
+
+// length implements IPv6SerializableExtHdr.
+func (s *dummyIPv6ExtHdrSerializer) length() int {
+	// One extra byte for the leading nextHeader value.
+	return 1 + len(s.headerContents)
+}
+
+// serializeInto implements IPv6SerializableExtHdr.
+func (s *dummyIPv6ExtHdrSerializer) serializeInto(nextHeader uint8, b []byte) int {
+	b[0] = nextHeader
+	copied := copy(b[1:], s.headerContents)
+	return copied + 1
+}
+
+// TestIPv6ExtHdrSerializer checks that IPv6ExtHdrSerializer reports the right
+// length, chains the next header values from one header to the next, and
+// returns the first header's identifier (or the transport protocol when there
+// are no headers).
+func TestIPv6ExtHdrSerializer(t *testing.T) {
+ tests := []struct {
+ name string
+ headers []IPv6SerializableExtHdr
+ nextHeader tcpip.TransportProtocolNumber
+ expectSerialized []byte
+ expectNextHeader uint8
+ }{
+ {
+ name: "one header",
+ headers: []IPv6SerializableExtHdr{
+ &dummyIPv6ExtHdrSerializer{
+ id: 15,
+ headerContents: []byte{1, 2, 3, 4},
+ },
+ },
+ nextHeader: TCPProtocolNumber,
+ expectSerialized: []byte{byte(TCPProtocolNumber), 1, 2, 3, 4},
+ expectNextHeader: 15,
+ },
+ {
+ name: "two headers",
+ headers: []IPv6SerializableExtHdr{
+ &dummyIPv6ExtHdrSerializer{
+ id: 22,
+ headerContents: []byte{1, 2, 3},
+ },
+ &dummyIPv6ExtHdrSerializer{
+ id: 23,
+ headerContents: []byte{4, 5, 6},
+ },
+ },
+ nextHeader: ICMPv6ProtocolNumber,
+ // The first header points at the second (23); the second points at the
+ // transport protocol.
+ expectSerialized: []byte{
+ 23, 1, 2, 3,
+ byte(ICMPv6ProtocolNumber), 4, 5, 6,
+ },
+ expectNextHeader: 22,
+ },
+ {
+ name: "no headers",
+ headers: []IPv6SerializableExtHdr{},
+ nextHeader: UDPProtocolNumber,
+ expectSerialized: []byte{},
+ expectNextHeader: byte(UDPProtocolNumber),
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := IPv6ExtHdrSerializer(test.headers)
+ l := s.Length()
+ if got, want := l, len(test.expectSerialized); got != want {
+ t.Fatalf("got serialized length = %d, want = %d", got, want)
+ }
+ b := make([]byte, l)
+ for i := range b {
+ // Fill the buffer with garbage to make sure we're writing to all bytes.
+ b[i] = 0xFF
+ }
+ nextHeader, serializedLen := s.Serialize(test.nextHeader, b)
+ if serializedLen != len(test.expectSerialized) || nextHeader != test.expectNextHeader {
+ t.Errorf(
+ "got s.Serialize(..) = (%d, %d), want = (%d, %d)",
+ nextHeader,
+ serializedLen,
+ test.expectNextHeader,
+ len(test.expectSerialized),
+ )
+ }
+ if diff := cmp.Diff(test.expectSerialized, b); diff != "" {
+ t.Errorf("serialization mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/ipv6_fragment.go b/pkg/tcpip/header/ipv6_fragment.go
index 018555a26..9d09f32eb 100644
--- a/pkg/tcpip/header/ipv6_fragment.go
+++ b/pkg/tcpip/header/ipv6_fragment.go
@@ -27,12 +27,11 @@ const (
idV6 = 4
)
-// IPv6FragmentFields contains the fields of an IPv6 fragment. It is used to describe the
-// fields of a packet that needs to be encoded.
-type IPv6FragmentFields struct {
- // NextHeader is the "next header" field of an IPv6 fragment.
- NextHeader uint8
+var _ IPv6SerializableExtHdr = (*IPv6SerializableFragmentExtHdr)(nil)
+// IPv6SerializableFragmentExtHdr is used to serialize an IPv6 fragment
+// extension header as defined in RFC 8200 section 4.5.
+type IPv6SerializableFragmentExtHdr struct {
// FragmentOffset is the "fragment offset" field of an IPv6 fragment.
FragmentOffset uint16
@@ -43,6 +42,29 @@ type IPv6FragmentFields struct {
Identification uint32
}
+// identifier implements IPv6SerializableExtHdr.
+func (h *IPv6SerializableFragmentExtHdr) identifier() IPv6ExtensionHeaderIdentifier {
+ return IPv6FragmentHeader
+}
+
+// length implements IPv6SerializableExtHdr.
+func (h *IPv6SerializableFragmentExtHdr) length() int {
+ return IPv6FragmentHeaderSize
+}
+
+// serializeInto implements IPv6SerializableExtHdr.
+func (h *IPv6SerializableFragmentExtHdr) serializeInto(nextHeader uint8, b []byte) int {
+ // Prevent too many bounds checks.
+ _ = b[IPv6FragmentHeaderSize:]
+ binary.BigEndian.PutUint32(b[idV6:], h.Identification)
+ binary.BigEndian.PutUint16(b[fragOff:], h.FragmentOffset<<ipv6FragmentExtHdrFragmentOffsetShift)
+ b[nextHdrFrag] = nextHeader
+ if h.M {
+ b[more] |= ipv6FragmentExtHdrMFlagMask
+ }
+ return IPv6FragmentHeaderSize
+}
+
// IPv6Fragment represents an ipv6 fragment header stored in a byte array.
// Most of the methods of IPv6Fragment access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
@@ -58,16 +80,6 @@ const (
IPv6FragmentHeaderSize = 8
)
-// Encode encodes all the fields of the ipv6 fragment.
-func (b IPv6Fragment) Encode(i *IPv6FragmentFields) {
- b[nextHdrFrag] = i.NextHeader
- binary.BigEndian.PutUint16(b[fragOff:], i.FragmentOffset<<3)
- if i.M {
- b[more] |= 1
- }
- binary.BigEndian.PutUint32(b[idV6:], i.Identification)
-}
-
// IsValid performs basic validation on the fragment header.
func (b IPv6Fragment) IsValid() bool {
return len(b) >= IPv6FragmentHeaderSize
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
index 426a873b1..e3fbd64f3 100644
--- a/pkg/tcpip/header/ipv6_test.go
+++ b/pkg/tcpip/header/ipv6_test.go
@@ -215,48 +215,6 @@ func TestLinkLocalAddrWithOpaqueIID(t *testing.T) {
}
}
-func TestIsV6UniqueLocalAddress(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.Address
- expected bool
- }{
- {
- name: "Valid Unique 1",
- addr: uniqueLocalAddr1,
- expected: true,
- },
- {
- name: "Valid Unique 2",
- addr: uniqueLocalAddr1,
- expected: true,
- },
- {
- name: "Link Local",
- addr: linkLocalAddr,
- expected: false,
- },
- {
- name: "Global",
- addr: globalAddr,
- expected: false,
- },
- {
- name: "IPv4",
- addr: "\x01\x02\x03\x04",
- expected: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- if got := header.IsV6UniqueLocalAddress(test.addr); got != test.expected {
- t.Errorf("got header.IsV6UniqueLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected)
- }
- })
- }
-}
-
func TestIsV6LinkLocalMulticastAddress(t *testing.T) {
tests := []struct {
name string
@@ -346,7 +304,7 @@ func TestScopeForIPv6Address(t *testing.T) {
{
name: "Unique Local",
addr: uniqueLocalAddr1,
- scope: header.UniqueLocalScope,
+ scope: header.GlobalScope,
err: nil,
},
{
diff --git a/pkg/tcpip/header/mld.go b/pkg/tcpip/header/mld.go
index f70623092..ffe03c76a 100644
--- a/pkg/tcpip/header/mld.go
+++ b/pkg/tcpip/header/mld.go
@@ -23,6 +23,13 @@ import (
)
const (
+ // MLDMinimumSize is the minimum size for an MLD message.
+ MLDMinimumSize = 20
+
+ // MLDHopLimit is the Hop Limit for all IPv6 packets with an MLD message, as
+ // per RFC 2710 section 3.
+ MLDHopLimit = 1
+
// mldMaximumResponseDelayOffset is the offset to the Maximum Response Delay
// field within MLD.
mldMaximumResponseDelayOffset = 0
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index 5d3975c56..554242f0c 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -298,7 +298,7 @@ func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) {
return it, nil
}
-// Serialize serializes the provided list of NDP options into o.
+// Serialize serializes the provided list of NDP options into b.
//
// Note, b must be of sufficient size to hold all the options in s. See
// NDPOptionsSerializer.Length for details on the getting the total size
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index a7f5f4979..d9f8e3b35 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -31,7 +31,7 @@ type PacketInfo struct {
Pkt *stack.PacketBuffer
Proto tcpip.NetworkProtocolNumber
GSO *stack.GSO
- Route stack.Route
+ Route stack.RouteInfo
}
// Notification is the interface for receiving notification from the packet
@@ -230,15 +230,11 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
// WritePacket stores outbound packets into the channel.
func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- // Clone r then release its resource so we only get the relevant fields from
- // stack.Route without holding a reference to a NIC's endpoint.
- route := r.Clone()
- route.Release()
p := PacketInfo{
Pkt: pkt,
Proto: protocol,
GSO: gso,
- Route: route,
+ Route: r.GetFields(),
}
e.q.Write(p)
@@ -248,17 +244,13 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
// WritePackets stores outbound packets into the channel.
func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- // Clone r then release its resource so we only get the relevant fields from
- // stack.Route without holding a reference to a NIC's endpoint.
- route := r.Clone()
- route.Release()
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
p := PacketInfo{
Pkt: pkt,
Proto: protocol,
GSO: gso,
- Route: route,
+ Route: r.GetFields(),
}
if !e.q.Write(p) {
diff --git a/pkg/tcpip/link/ethernet/BUILD b/pkg/tcpip/link/ethernet/BUILD
index ec92ed623..0ae0d201a 100644
--- a/pkg/tcpip/link/ethernet/BUILD
+++ b/pkg/tcpip/link/ethernet/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -13,3 +13,17 @@ go_library(
"//pkg/tcpip/stack",
],
)
+
+go_test(
+ name = "ethernet_test",
+ size = "small",
+ srcs = ["ethernet_test.go"],
+ deps = [
+ ":ethernet",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go
index 3eef7cd56..89e3e6164 100644
--- a/pkg/tcpip/link/ethernet/ethernet.go
+++ b/pkg/tcpip/link/ethernet/ethernet.go
@@ -49,10 +49,10 @@ func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkP
return
}
+ // Note, there is no need to check the destination link address here since
+ // the ethernet hardware filters frames based on their destination addresses.
eth := header.Ethernet(hdr)
- if dst := eth.DestinationAddress(); dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) {
- e.Endpoint.DeliverNetworkPacket(eth.SourceAddress() /* remote */, dst /* local */, eth.Type() /* protocol */, pkt)
- }
+ e.Endpoint.DeliverNetworkPacket(eth.SourceAddress() /* remote */, eth.DestinationAddress() /* local */, eth.Type() /* protocol */, pkt)
}
// Capabilities implements stack.LinkEndpoint.
@@ -62,7 +62,7 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
// WritePacket implements stack.LinkEndpoint.
func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt)
+ e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress(), proto, pkt)
return e.Endpoint.WritePacket(r, gso, proto, pkt)
}
@@ -71,7 +71,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
linkAddr := e.Endpoint.LinkAddress()
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt)
+ e.AddHeader(linkAddr, r.RemoteLinkAddress(), proto, pkt)
}
return e.Endpoint.WritePackets(r, gso, pkts, proto)
diff --git a/pkg/tcpip/link/ethernet/ethernet_test.go b/pkg/tcpip/link/ethernet/ethernet_test.go
new file mode 100644
index 000000000..08a7f1ce1
--- /dev/null
+++ b/pkg/tcpip/link/ethernet/ethernet_test.go
@@ -0,0 +1,71 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ethernet_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/ethernet"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.NetworkDispatcher = (*testNetworkDispatcher)(nil)
+
+type testNetworkDispatcher struct {
+ networkPackets int
+}
+
+func (t *testNetworkDispatcher) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+ t.networkPackets++
+}
+
+func (*testNetworkDispatcher) DeliverOutboundPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
+
+func TestDeliverNetworkPacket(t *testing.T) {
+ const (
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ otherLinkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
+ otherLinkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
+ )
+
+ e := ethernet.New(channel.New(0, 0, linkAddr))
+ var networkDispatcher testNetworkDispatcher
+ e.Attach(&networkDispatcher)
+
+ if networkDispatcher.networkPackets != 0 {
+ t.Fatalf("got networkDispatcher.networkPackets = %d, want = 0", networkDispatcher.networkPackets)
+ }
+
+ // An ethernet frame with a destination link address that is not assigned to
+ // our ethernet link endpoint should still be delivered to the network
+ // dispatcher since the ethernet endpoint is not expected to filter frames.
+ eth := buffer.NewView(header.EthernetMinimumSize)
+ header.Ethernet(eth).Encode(&header.EthernetFields{
+ SrcAddr: otherLinkAddr1,
+ DstAddr: otherLinkAddr2,
+ Type: header.IPv4ProtocolNumber,
+ })
+ e.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: eth.ToVectorisedView(),
+ }))
+ if networkDispatcher.networkPackets != 1 {
+ t.Fatalf("got networkDispatcher.networkPackets = %d, want = 1", networkDispatcher.networkPackets)
+ }
+}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index fc620c7d5..cb94cbea6 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -284,9 +284,12 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher
}
switch sa.(type) {
case *unix.SockaddrLinklayer:
- // enable PACKET_FANOUT mode is the underlying socket is
- // of type AF_PACKET.
- const fanoutType = 0x8000 // PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG
+ // Enable PACKET_FANOUT mode if the underlying socket is of type
+ // AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG because it makes
+ // the host reassemble fragmented packets on our behalf before delivering
+ // them, so gVisor would never receive the individual fragments. That would
+ // make it hard to test the fragment reassembly code in Netstack.
+ const fanoutType = unix.PACKET_FANOUT_HASH
fanoutArg := fanoutID | fanoutType<<16
if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil {
return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err)
@@ -410,7 +413,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
// currently writable, the packet is dropped.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
if e.hdrSize > 0 {
- e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress(), protocol, pkt)
}
var builder iovec.Builder
@@ -453,7 +456,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc
mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch))
for _, pkt := range batch {
if e.hdrSize > 0 {
- e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress(), pkt.NetworkProtocolNumber, pkt)
}
var vnetHdrBuf []byte
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 709f829c8..a87abc6d6 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -183,9 +183,8 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u
c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize})
defer c.cleanup()
- r := &stack.Route{
- RemoteLinkAddress: raddr,
- }
+ var r stack.Route
+ r.ResolveWith(raddr)
// Build payload.
payload := buffer.NewView(plen)
@@ -220,7 +219,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u
L3HdrLen: header.IPv4MaximumHeaderSize,
}
}
- if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil {
+ if err := c.ep.WritePacket(&r, gso, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -324,10 +323,9 @@ func TestPreserveSrcAddress(t *testing.T) {
defer c.cleanup()
// Set LocalLinkAddress in route to the value of the bridged address.
- r := &stack.Route{
- RemoteLinkAddress: raddr,
- LocalLinkAddress: baddr,
- }
+ var r stack.Route
+ r.LocalLinkAddress = baddr
+ r.ResolveWith(raddr)
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
// WritePacket panics given a prependable with anything less than
@@ -336,7 +334,7 @@ func TestPreserveSrcAddress(t *testing.T) {
ReserveHeaderBytes: header.EthernetMinimumSize,
Data: buffer.VectorisedView{},
})
- if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil {
+ if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
index 3e4afcdad..b511d3a31 100644
--- a/pkg/tcpip/link/muxed/injectable_test.go
+++ b/pkg/tcpip/link/muxed/injectable_test.go
@@ -51,7 +51,8 @@ func TestInjectableEndpointDispatch(t *testing.T) {
Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(),
})
pkt.TransportHeader().Push(1)[0] = 0xFA
- packetRoute := stack.Route{RemoteAddress: dstIP}
+ var packetRoute stack.Route
+ packetRoute.RemoteAddress = dstIP
endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt)
@@ -73,7 +74,8 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) {
Data: buffer.NewView(0).ToVectorisedView(),
})
pkt.TransportHeader().Push(1)[0] = 0xFA
- packetRoute := stack.Route{RemoteAddress: dstIP}
+ var packetRoute stack.Route
+ packetRoute.RemoteAddress = dstIP
endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt)
buf := make([]byte, 6500)
bytesRead, err := sock.Read(buf)
diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go
index 3922c2a04..9a1b0c0c2 100644
--- a/pkg/tcpip/link/packetsocket/endpoint.go
+++ b/pkg/tcpip/link/packetsocket/endpoint.go
@@ -36,14 +36,14 @@ func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
// WritePacket implements stack.LinkEndpoint.WritePacket.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
+ e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress(), r.LocalLinkAddress, protocol, pkt)
return e.Endpoint.WritePacket(r, gso, protocol, pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress(), pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
}
return e.Endpoint.WritePackets(r, gso, pkts, proto)
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
index 71fcb73e1..25c364391 100644
--- a/pkg/tcpip/link/pipe/pipe.go
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -55,7 +55,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.Network
// remote address from the perspective of the other end of the pipe
// (e.linked). Similarly, the remote address from the perspective of this
// endpoint is the local address on the other end.
- e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress() /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
}))
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index 9b41d60d5..b7458b620 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -154,8 +154,7 @@ func (e *endpoint) GSOMaxSize() uint32 {
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
// WritePacket caller's do not set the following fields in PacketBuffer
// so we populate them here.
- newRoute := r.Clone()
- pkt.EgressRoute = &newRoute
+ pkt.EgressRoute = r
pkt.GSOOptions = gso
pkt.NetworkProtocolNumber = protocol
d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
@@ -178,11 +177,6 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB
for pkt := pkts.Front(); pkt != nil; {
d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
nxt := pkt.Next()
- // Since qdisc can hold onto a packet for long we should Clone
- // the route here to ensure it doesn't get released while the
- // packet is still in our queue.
- newRoute := pkt.EgressRoute.Clone()
- pkt.EgressRoute = &newRoute
if !d.q.enqueue(pkt) {
if enqueued > 0 {
d.newPacketWaker.Assert()
diff --git a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go
index eb5abb906..45adcbccb 100644
--- a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go
+++ b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go
@@ -61,6 +61,7 @@ func (q *packetBufferQueue) enqueue(s *stack.PacketBuffer) bool {
q.mu.Lock()
r := q.used < q.limit
if r {
+ s.EgressRoute.Acquire()
q.list.PushBack(s)
q.used++
}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index a1e7018c8..5660418fa 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -204,7 +204,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress(), protocol, pkt)
views := pkt.Views()
// Transmit the packet.
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 22d5c97f1..dd2e1a125 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -260,9 +260,8 @@ func TestSimpleSend(t *testing.T) {
defer c.cleanup()
// Prepare route.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
for iters := 1000; iters > 0; iters-- {
func() {
@@ -341,10 +340,9 @@ func TestPreserveSrcAddressInSend(t *testing.T) {
newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("0xFE", 6))
// Set both remote and local link address in route.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- LocalLinkAddress: newLocalLinkAddress,
- }
+ var r stack.Route
+ r.LocalLinkAddress = newLocalLinkAddress
+ r.ResolveWith(remoteLinkAddr)
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
// WritePacket panics given a prependable with anything less than
@@ -395,9 +393,8 @@ func TestFillTxQueue(t *testing.T) {
defer c.cleanup()
// Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
buf := buffer.NewView(100)
@@ -444,9 +441,8 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
c.txq.rx.Flush()
// Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
buf := buffer.NewView(100)
@@ -509,9 +505,8 @@ func TestFillTxMemory(t *testing.T) {
defer c.cleanup()
// Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
buf := buffer.NewView(100)
@@ -557,9 +552,8 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
defer c.cleanup()
// Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
buf := buffer.NewView(100)
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 8d9a91020..1a2cc39eb 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -263,7 +263,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe
fragmentOffset = fragOffset
case header.ARPProtocolNumber:
- if parse.ARP(pkt) {
+ if !parse.ARP(pkt) {
return
}
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 9a76bdba7..bfac358f4 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -264,7 +264,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
// If the packet does not already have link layer header, and the route
// does not exist, we can't compute it. This is possibly a raw packet, tun
// device doesn't support this at the moment.
- if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress == "" {
+ if info.Pkt.LinkHeader().View().IsEmpty() && len(info.Route.RemoteLinkAddress) == 0 {
return nil, false
}
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index b38aff0b8..9ebf31b78 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -7,12 +7,14 @@ go_test(
size = "small",
srcs = [
"ip_test.go",
+ "multicast_group_test.go",
],
deps = [
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/header/parse",
"//pkg/tcpip/link/channel",
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index a738e9e1c..a25cba513 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -441,10 +441,9 @@ func (*testInterface) Promiscuous() bool {
}
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- r := stack.Route{
- NetProto: protocol,
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.NetProto = protocol
+ r.ResolveWith(remoteLinkAddr)
return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index d8e4a3b54..429af69ee 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -18,7 +18,6 @@ go_template_instance(
go_library(
name = "fragmentation",
srcs = [
- "frag_heap.go",
"fragmentation.go",
"reassembler.go",
"reassembler_list.go",
@@ -38,7 +37,6 @@ go_test(
name = "fragmentation_test",
size = "small",
srcs = [
- "frag_heap_test.go",
"fragmentation_test.go",
"reassembler_test.go",
],
diff --git a/pkg/tcpip/network/fragmentation/frag_heap.go b/pkg/tcpip/network/fragmentation/frag_heap.go
deleted file mode 100644
index 0b570d25a..000000000
--- a/pkg/tcpip/network/fragmentation/frag_heap.go
+++ /dev/null
@@ -1,77 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fragmentation
-
-import (
- "container/heap"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
-)
-
-type fragment struct {
- offset uint16
- vv buffer.VectorisedView
-}
-
-type fragHeap []fragment
-
-func (h *fragHeap) Len() int {
- return len(*h)
-}
-
-func (h *fragHeap) Less(i, j int) bool {
- return (*h)[i].offset < (*h)[j].offset
-}
-
-func (h *fragHeap) Swap(i, j int) {
- (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
-}
-
-func (h *fragHeap) Push(x interface{}) {
- *h = append(*h, x.(fragment))
-}
-
-func (h *fragHeap) Pop() interface{} {
- old := *h
- n := len(old)
- x := old[n-1]
- *h = old[:n-1]
- return x
-}
-
-// reassamble empties the heap and returns a VectorisedView
-// containing a reassambled version of the fragments inside the heap.
-func (h *fragHeap) reassemble() (buffer.VectorisedView, error) {
- curr := heap.Pop(h).(fragment)
- views := curr.vv.Views()
- size := curr.vv.Size()
-
- if curr.offset != 0 {
- return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset)
- }
-
- for h.Len() > 0 {
- curr := heap.Pop(h).(fragment)
- if int(curr.offset) < size {
- curr.vv.TrimFront(size - int(curr.offset))
- } else if int(curr.offset) > size {
- return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset)
- }
- size += curr.vv.Size()
- views = append(views, curr.vv.Views()...)
- }
- return buffer.NewVectorisedView(size, views), nil
-}
diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go
deleted file mode 100644
index 9ececcb9f..000000000
--- a/pkg/tcpip/network/fragmentation/frag_heap_test.go
+++ /dev/null
@@ -1,126 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fragmentation
-
-import (
- "container/heap"
- "reflect"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
-)
-
-var reassambleTestCases = []struct {
- comment string
- in []fragment
- want buffer.VectorisedView
-}{
- {
- comment: "Non-overlapping in-order",
- in: []fragment{
- {offset: 0, vv: vv(1, "0")},
- {offset: 1, vv: vv(1, "1")},
- },
- want: vv(2, "0", "1"),
- },
- {
- comment: "Non-overlapping out-of-order",
- in: []fragment{
- {offset: 1, vv: vv(1, "1")},
- {offset: 0, vv: vv(1, "0")},
- },
- want: vv(2, "0", "1"),
- },
- {
- comment: "Duplicated packets",
- in: []fragment{
- {offset: 0, vv: vv(1, "0")},
- {offset: 0, vv: vv(1, "0")},
- },
- want: vv(1, "0"),
- },
- {
- comment: "Overlapping in-order",
- in: []fragment{
- {offset: 0, vv: vv(2, "01")},
- {offset: 1, vv: vv(2, "12")},
- },
- want: vv(3, "01", "2"),
- },
- {
- comment: "Overlapping out-of-order",
- in: []fragment{
- {offset: 1, vv: vv(2, "12")},
- {offset: 0, vv: vv(2, "01")},
- },
- want: vv(3, "01", "2"),
- },
- {
- comment: "Overlapping subset in-order",
- in: []fragment{
- {offset: 0, vv: vv(3, "012")},
- {offset: 1, vv: vv(1, "1")},
- },
- want: vv(3, "012"),
- },
- {
- comment: "Overlapping subset out-of-order",
- in: []fragment{
- {offset: 1, vv: vv(1, "1")},
- {offset: 0, vv: vv(3, "012")},
- },
- want: vv(3, "012"),
- },
-}
-
-func TestReassamble(t *testing.T) {
- for _, c := range reassambleTestCases {
- t.Run(c.comment, func(t *testing.T) {
- h := make(fragHeap, 0, 8)
- heap.Init(&h)
- for _, f := range c.in {
- heap.Push(&h, f)
- }
- got, err := h.reassemble()
- if err != nil {
- t.Fatal(err)
- }
- if !reflect.DeepEqual(got, c.want) {
- t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want)
- }
- })
- }
-}
-
-func TestReassambleFailsForNonZeroOffset(t *testing.T) {
- h := make(fragHeap, 0, 8)
- heap.Init(&h)
- heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")})
- _, err := h.reassemble()
- if err == nil {
- t.Errorf("reassemble() did not fail when the first packet had offset != 0")
- }
-}
-
-func TestReassambleFailsForHoles(t *testing.T) {
- h := make(fragHeap, 0, 8)
- heap.Init(&h)
- heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")})
- heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")})
- _, err := h.reassemble()
- if err == nil {
- t.Errorf("reassemble() did not fail when there was a hole in the packet")
- }
-}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index c75ca7d71..1af87d713 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -46,9 +46,17 @@ const (
)
var (
- // ErrInvalidArgs indicates to the caller that that an invalid argument was
+ // ErrInvalidArgs indicates to the caller that an invalid argument was
// provided.
ErrInvalidArgs = errors.New("invalid args")
+
+ // ErrFragmentOverlap indicates that, during reassembly, a fragment overlaps
+ // with another one.
+ ErrFragmentOverlap = errors.New("overlapping fragments")
+
+ // ErrFragmentConflict indicates that, during reassembly, some fragments are
+ // in conflict with one another.
+ ErrFragmentConflict = errors.New("conflicting fragments")
)
// FragmentID is the identifier for a fragment.
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 19f4920b3..9b20bb1d8 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -15,9 +15,8 @@
package fragmentation
import (
- "container/heap"
- "fmt"
"math"
+ "sort"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -26,9 +25,11 @@ import (
)
type hole struct {
- first uint16
- last uint16
- deleted bool
+ first uint16
+ last uint16
+ filled bool
+ final bool
+ data buffer.View
}
type reassembler struct {
@@ -38,8 +39,7 @@ type reassembler struct {
proto uint8
mu sync.Mutex
holes []hole
- deleted int
- heap fragHeap
+ filled int
done bool
creationTime int64
pkt *stack.PacketBuffer
@@ -48,49 +48,94 @@ type reassembler struct {
func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
r := &reassembler{
id: id,
- holes: make([]hole, 0, 16),
- heap: make(fragHeap, 0, 8),
creationTime: clock.NowMonotonic(),
}
r.holes = append(r.holes, hole{
- first: 0,
- last: math.MaxUint16,
- deleted: false})
+ first: 0,
+ last: math.MaxUint16,
+ filled: false,
+ final: true,
+ })
return r
}
-// updateHoles updates the list of holes for an incoming fragment and
-// returns true iff the fragment filled at least part of an existing hole.
-func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
- used := false
- for i := range r.holes {
- if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first {
- continue
- }
- used = true
- r.deleted++
- r.holes[i].deleted = true
- if first > r.holes[i].first {
- r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false})
- }
- if last < r.holes[i].last && more {
- r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false})
- }
- }
- return used
-}
-
func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) {
r.mu.Lock()
defer r.mu.Unlock()
- consumed := 0
if r.done {
// A concurrent goroutine might have already reassembled
// the packet and emptied the heap while this goroutine
// was waiting on the mutex. We don't have to do anything in this case.
- return buffer.VectorisedView{}, 0, false, consumed, nil
+ return buffer.VectorisedView{}, 0, false, 0, nil
}
- if r.updateHoles(first, last, more) {
+
+ var holeFound bool
+ var consumed int
+ for i := range r.holes {
+ currentHole := &r.holes[i]
+
+ if last < currentHole.first || currentHole.last < first {
+ continue
+ }
+ // For IPv6, overlaps with an existing fragment are explicitly forbidden by
+ // RFC 8200 section 4.5:
+ // If any of the fragments being reassembled overlap with any other
+ // fragments being reassembled for the same packet, reassembly of that
+ // packet must be abandoned and all the fragments that have been received
+ // for that packet must be discarded, and no ICMP error messages should be
+ // sent.
+ //
+ // It is not explicitly forbidden for IPv4, but to keep parity with Linux we
+ // disallow it as well:
+ // https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349
+ if first < currentHole.first || currentHole.last < last {
+ // Incoming fragment only partially fits in the free hole.
+ return buffer.VectorisedView{}, 0, false, 0, ErrFragmentOverlap
+ }
+ if !more {
+ if !currentHole.final || currentHole.filled && currentHole.last != last {
+ // We have another final fragment, which does not perfectly overlap.
+ return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict
+ }
+ }
+
+ holeFound = true
+ if currentHole.filled {
+ // Incoming fragment is a duplicate.
+ continue
+ }
+
+ // We are populating the current hole with the payload and creating a new
+ // hole for any unfilled ranges on either end.
+ if first > currentHole.first {
+ r.holes = append(r.holes, hole{
+ first: currentHole.first,
+ last: first - 1,
+ filled: false,
+ final: false,
+ })
+ }
+ if last < currentHole.last && more {
+ r.holes = append(r.holes, hole{
+ first: last + 1,
+ last: currentHole.last,
+ filled: false,
+ final: currentHole.final,
+ })
+ currentHole.final = false
+ }
+ v := pkt.Data.ToOwnedView()
+ consumed = v.Size()
+ r.size += consumed
+ // Update the current hole to precisely match the incoming fragment.
+ r.holes[i] = hole{
+ first: first,
+ last: last,
+ filled: true,
+ final: currentHole.final,
+ data: v,
+ }
+ r.filled++
// For IPv6, it is possible to have different Protocol values between
// fragments of a packet (because, unlike IPv4, the Protocol is not used to
// identify a fragment). In this case, only the Protocol of the first
@@ -103,21 +148,30 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
r.pkt = pkt
r.proto = proto
}
- vv := pkt.Data
- // We store the incoming packet only if it filled some holes.
- heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)})
- consumed = vv.Size()
- r.size += consumed
+
+ break
+ }
+ if !holeFound {
+ // Incoming fragment is beyond end.
+ return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict
}
- // Check if all the holes have been deleted and we are ready to reassamble.
- if r.deleted < len(r.holes) {
+
+ // Check if all the holes have been filled and we are ready to reassemble.
+ if r.filled < len(r.holes) {
return buffer.VectorisedView{}, 0, false, consumed, nil
}
- res, err := r.heap.reassemble()
- if err != nil {
- return buffer.VectorisedView{}, 0, false, consumed, fmt.Errorf("fragment reassembly failed: %w", err)
+
+ sort.Slice(r.holes, func(i, j int) bool {
+ return r.holes[i].first < r.holes[j].first
+ })
+
+ var size int
+ views := make([]buffer.View, 0, len(r.holes))
+ for _, hole := range r.holes {
+ views = append(views, hole.data)
+ size += hole.data.Size()
}
- return res, r.proto, true, consumed, nil
+ return buffer.NewVectorisedView(size, views), r.proto, true, consumed, nil
}
func (r *reassembler) checkDoneOrMark() bool {
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
index a0a04a027..2ff03eeeb 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -16,92 +16,175 @@ package fragmentation
import (
"math"
- "reflect"
"testing"
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
-type updateHolesInput struct {
- first uint16
- last uint16
- more bool
+type processParams struct {
+ first uint16
+ last uint16
+ more bool
+ pkt *stack.PacketBuffer
+ wantDone bool
+ wantError error
}
-var holesTestCases = []struct {
- comment string
- in []updateHolesInput
- want []hole
-}{
- {
- comment: "No fragments. Expected holes: {[0 -> inf]}.",
- in: []updateHolesInput{},
- want: []hole{{first: 0, last: math.MaxUint16, deleted: false}},
- },
- {
- comment: "One fragment at beginning. Expected holes: {[2, inf]}.",
- in: []updateHolesInput{{first: 0, last: 1, more: true}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 2, last: math.MaxUint16, deleted: false},
+func TestReassemblerProcess(t *testing.T) {
+ const proto = 99
+
+ v := func(size int) buffer.View {
+ payload := buffer.NewView(size)
+ for i := 1; i < size; i++ {
+ payload[i] = uint8(i) * 3
+ }
+ return payload
+ }
+
+ pkt := func(size int) *stack.PacketBuffer {
+ return stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: v(size).ToVectorisedView(),
+ })
+ }
+
+ var tests = []struct {
+ name string
+ params []processParams
+ want []hole
+ }{
+ {
+ name: "No fragments",
+ params: nil,
+ want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}},
},
- },
- {
- comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.",
- in: []updateHolesInput{{first: 1, last: 2, more: true}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 0, last: 0, deleted: false},
- {first: 3, last: math.MaxUint16, deleted: false},
+ {
+ name: "One fragment at beginning",
+ params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: false, data: v(2)},
+ {first: 2, last: math.MaxUint16, filled: false, final: true},
+ },
},
- },
- {
- comment: "One fragment at the end. Expected holes: {[0, 0]}.",
- in: []updateHolesInput{{first: 1, last: 2, more: false}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 0, last: 0, deleted: false},
+ {
+ name: "One fragment in the middle",
+ params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
+ want: []hole{
+ {first: 1, last: 2, filled: true, final: false, data: v(2)},
+ {first: 0, last: 0, filled: false, final: false},
+ {first: 3, last: math.MaxUint16, filled: false, final: true},
+ },
},
- },
- {
- comment: "One fragment completing a packet. Expected holes: {}.",
- in: []updateHolesInput{{first: 0, last: 1, more: false}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
+ {
+ name: "One fragment at the end",
+ params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}},
+ want: []hole{
+ {first: 1, last: 2, filled: true, final: true, data: v(2)},
+ {first: 0, last: 0, filled: false},
+ },
},
- },
- {
- comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.",
- in: []updateHolesInput{
- {first: 0, last: 1, more: true},
- {first: 2, last: 3, more: false},
+ {
+ name: "One fragment completing a packet",
+ params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}},
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: true, data: v(2)},
+ },
},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 2, last: math.MaxUint16, deleted: true},
+ {
+ name: "Two fragments completing a packet",
+ params: []processParams{
+ {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: false, data: v(2)},
+ {first: 2, last: 3, filled: true, final: true, data: v(2)},
+ },
},
- },
- {
- comment: "Two overlapping fragments completing a packet. Expected holes: {}.",
- in: []updateHolesInput{
- {first: 0, last: 2, more: true},
- {first: 2, last: 3, more: false},
+ {
+ name: "Two fragments completing a packet with a duplicate",
+ params: []processParams{
+ {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: false, data: v(2)},
+ {first: 2, last: 3, filled: true, final: true, data: v(2)},
+ },
},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 3, last: math.MaxUint16, deleted: true},
+ {
+ name: "Two fragments completing a packet with a partial duplicate",
+ params: []processParams{
+ {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil},
+ {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 3, filled: true, final: false, data: v(4)},
+ {first: 4, last: 5, filled: true, final: true, data: v(2)},
+ },
},
- },
-}
+ {
+ name: "Two overlapping fragments",
+ params: []processParams{
+ {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil},
+ {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap},
+ },
+ want: []hole{
+ {first: 0, last: 10, filled: true, final: false, data: v(11)},
+ {first: 11, last: math.MaxUint16, filled: false, final: true},
+ },
+ },
+ {
+ name: "Two final fragments with different ends",
+ params: []processParams{
+ {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
+ {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict},
+ },
+ want: []hole{
+ {first: 10, last: 14, filled: true, final: true, data: v(5)},
+ {first: 0, last: 9, filled: false, final: false},
+ },
+ },
+ {
+ name: "Two final fragments - duplicate",
+ params: []processParams{
+ {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil},
+ {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
+ },
+ want: []hole{
+ {first: 5, last: 14, filled: true, final: true, data: v(10)},
+ {first: 0, last: 4, filled: false, final: false},
+ },
+ },
+ {
+ name: "Two final fragments - duplicate, with different ends",
+ params: []processParams{
+ {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil},
+ {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict},
+ },
+ want: []hole{
+ {first: 5, last: 14, filled: true, final: true, data: v(10)},
+ {first: 0, last: 4, filled: false, final: false},
+ },
+ },
+ }
-func TestUpdateHoles(t *testing.T) {
- for _, c := range holesTestCases {
- r := newReassembler(FragmentID{}, &faketime.NullClock{})
- for _, i := range c.in {
- r.updateHoles(i.first, i.last, i.more)
- }
- if !reflect.DeepEqual(r.holes, c.want) {
- t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ r := newReassembler(FragmentID{}, &faketime.NullClock{})
+ for _, param := range test.params {
+ _, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt)
+ if done != param.wantDone || err != param.wantError {
+ t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError)
+ }
+ }
+ if diff := cmp.Diff(test.want, r.holes, cmp.AllowUnexported(hole{})); diff != "" {
+ t.Errorf("r.holes mismatch (-want +got):\n%s", diff)
+ }
+ })
}
}
diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD
index 6ca200b48..ca1247c1e 100644
--- a/pkg/tcpip/network/ip/BUILD
+++ b/pkg/tcpip/network/ip/BUILD
@@ -18,6 +18,7 @@ go_test(
srcs = ["generic_multicast_protocol_test.go"],
deps = [
":ip",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/faketime",
"@com_github_google_go_cmp//cmp:go_default_library",
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go
index 3113f4bbe..f2f0e069c 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go
@@ -30,13 +30,42 @@ type hostState int
// The states below are generic across IGMPv2 (RFC 2236 section 6) and MLDv1
// (RFC 2710 section 5). Even though the states are generic across both IGMPv2
// and MLDv1, IGMPv2 terminology will be used.
+//
+// ______________receive query______________
+// | |
+// | _____send or receive report_____ |
+// | | | |
+// V | V |
+// +-------+ +-----------+ +------------+ +-------------------+ +--------+ |
+// | Non-M | | Pending-M | | Delaying-M | | Queued Delaying-M | | Idle-M | -
+// +-------+ +-----------+ +------------+ +-------------------+ +--------+
+// | ^ | ^ | ^ | ^
+// | | | | | | | |
+// ---------- ------- ---------- -------------
+// initialize new send initial fail to send send or receive
+// group membership report delayed report report
+//
+// Not shown in the diagram above, but any state may transition into the non
+// member state when a group is left.
const (
- // "'Non-Member' state, when the host does not belong to the group on
- // the interface. This is the initial state for all memberships on
+ // nonMember is the "'Non-Member' state, when the host does not belong to the
+ // group on the interface. This is the initial state for all memberships on
// all network interfaces; it requires no storage in the host."
//
// 'Non-Listener' is the MLDv1 term used to describe this state.
- _ hostState = iota
+ //
+ // This state is used to keep track of groups that have been joined locally,
+ // but without advertising the membership to the network.
+ nonMember hostState = iota
+
+ // pendingMember is a newly joined member that is waiting to successfully send
+ // the initial set of reports.
+ //
+ // This is not an RFC defined state; it is an implementation specific state to
+ // track that the initial report needs to be sent.
+ //
+ // MAY NOT transition to the idle member state from this state.
+ pendingMember
// delayingMember is the "'Delaying Member' state, when the host belongs to
// the group on the interface and has a report delay timer running for that
@@ -45,6 +74,16 @@ const (
// 'Delaying Listener' is the MLDv1 term used to describe this state.
delayingMember
+ // queuedDelayingMember is a delayingMember that failed to send a report after
+ // its delayed report timer fired. Hosts in this state are waiting to attempt
+ // retransmission of the delayed report.
+ //
+ // This is not an RFC defined state; it is an implementation specific state to
+ // track that the delayed report needs to be sent.
+ //
+ // May transition to idle member if a report is received for a group.
+ queuedDelayingMember
+
// idleMember is the "Idle Member" state, when the host belongs to the group
// on the interface and does not have a report delay timer running for that
// membership.
@@ -53,10 +92,24 @@ const (
idleMember
)
+func (s hostState) isDelayingMember() bool {
+ switch s {
+ case nonMember, pendingMember, idleMember:
+ return false
+ case delayingMember, queuedDelayingMember:
+ return true
+ default:
+ panic(fmt.Sprintf("unrecognized host state = %d", s))
+ }
+}
+
// multicastGroupState holds the Generic Multicast Protocol state for a
// multicast group.
type multicastGroupState struct {
- // state contains the host's state for the group.
+ // joins is the number of times the group has been joined.
+ joins uint64
+
+ // state holds the host's state for the group.
state hostState
// lastToSendReport is true if we sent the last report for the group. It is
@@ -75,11 +128,53 @@ type multicastGroupState struct {
delayedReportJob *tcpip.Job
}
+// GenericMulticastProtocolOptions holds options for the generic multicast
+// protocol.
+type GenericMulticastProtocolOptions struct {
+ // Rand is the source of random numbers.
+ Rand *rand.Rand
+
+ // Clock is the clock used to create timers.
+ Clock tcpip.Clock
+
+ // Protocol is the implementation of the variant of multicast group protocol
+ // in use.
+ Protocol MulticastGroupProtocol
+
+ // MaxUnsolicitedReportDelay is the maximum amount of time to wait between
+ // transmitting unsolicited reports.
+ //
+ // Unsolicited reports are transmitted when a group is newly joined.
+ MaxUnsolicitedReportDelay time.Duration
+
+ // AllNodesAddress is a multicast address that all nodes on a network should
+ // be a member of.
+ //
+ // This address will not have the generic multicast protocol performed on it;
+ // it will be left in the non member/listener state, and packets will never
+ // be sent for it.
+ AllNodesAddress tcpip.Address
+}
+
// MulticastGroupProtocol is a multicast group protocol whose core state machine
// can be represented by GenericMulticastProtocolState.
type MulticastGroupProtocol interface {
+ // Enabled indicates whether the generic multicast protocol will be
+ // performed.
+ //
+ // When enabled, the protocol may transmit report and leave messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // packets.
+ //
+ // When disabled, the protocol will still keep track of locally joined groups,
+ // it just won't transmit and handle packets, or update groups' state.
+ Enabled() bool
+
// SendReport sends a multicast report for the specified group address.
- SendReport(groupAddress tcpip.Address) *tcpip.Error
+ //
+ // Returns false if the caller should queue the report to be sent later. Note,
+ // returning false does not mean that the receiver hit an error.
+ SendReport(groupAddress tcpip.Address) (sent bool, err *tcpip.Error)
// SendLeave sends a multicast leave for the specified group address.
SendLeave(groupAddress tcpip.Address) *tcpip.Error
@@ -93,162 +188,198 @@ type MulticastGroupProtocol interface {
// IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core state
// machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC 2710.
//
+// Callers must synchronize accesses to the generic multicast protocol state;
+// GenericMulticastProtocolState obtains no locks in any of its methods. The
+// only exception to this is GenericMulticastProtocolState's timer/job callbacks
+// which will obtain the lock provided to the GenericMulticastProtocolState when
+// it is initialized.
+//
// GenericMulticastProtocolState.Init MUST be called before calling any of
// the methods on GenericMulticastProtocolState.
+//
+// GenericMulticastProtocolState.MakeAllNonMemberLocked MUST be called when the
+// multicast group protocol is disabled so that leave messages may be sent.
type GenericMulticastProtocolState struct {
- rand *rand.Rand
- clock tcpip.Clock
- protocol MulticastGroupProtocol
- maxUnsolicitedReportDelay time.Duration
+ // Do not allow overwriting this state.
+ _ sync.NoCopy
- mu struct {
- sync.Mutex
+ opts GenericMulticastProtocolOptions
- // memberships holds group addresses and their associated state.
- memberships map[tcpip.Address]multicastGroupState
- }
+ // memberships holds group addresses and their associated state.
+ memberships map[tcpip.Address]multicastGroupState
+
+ // protocolMU is the mutex used to protect the protocol.
+ protocolMU *sync.RWMutex
}
// Init initializes the Generic Multicast Protocol state.
//
-// maxUnsolicitedReportDelay is the maximum time between sending unsolicited
-// reports after joining a group.
-func (g *GenericMulticastProtocolState) Init(rand *rand.Rand, clock tcpip.Clock, protocol MulticastGroupProtocol, maxUnsolicitedReportDelay time.Duration) {
- g.mu.Lock()
- defer g.mu.Unlock()
- g.rand = rand
- g.clock = clock
- g.protocol = protocol
- g.maxUnsolicitedReportDelay = maxUnsolicitedReportDelay
- g.mu.memberships = make(map[tcpip.Address]multicastGroupState)
+// Must only be called once for the lifetime of g; Init will panic if it is
+// called twice.
+//
+// The GenericMulticastProtocolState will only grab the lock when timers/jobs
+// fire.
+//
+// Note: the methods on opts.Protocol will always be called while protocolMU is
+// held.
+func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) {
+ if g.memberships != nil {
+ panic("attempted to initialize generic membership protocol state twice")
+ }
+
+ *g = GenericMulticastProtocolState{
+ opts: opts,
+ memberships: make(map[tcpip.Address]multicastGroupState),
+ protocolMU: protocolMU,
+ }
+}
+
+// MakeAllNonMemberLocked transitions all groups to the non-member state.
+//
+// The groups will still be considered joined locally.
+//
+// MUST be called when the multicast group protocol is disabled.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
+
+ for groupAddress, info := range g.memberships {
+ g.transitionToNonMemberLocked(groupAddress, &info)
+ g.memberships[groupAddress] = info
+ }
}
-// JoinGroup handles joining a new group.
+// InitializeGroupsLocked initializes each group, as if they were newly joined
+// but without affecting the groups' join count.
+//
+// Must only be called after calling MakeAllNonMemberLocked as a group should
+// not be initialized while it is not in the non-member state.
//
-// Returns false if the group has already been joined.
-func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address) bool {
- g.mu.Lock()
- defer g.mu.Unlock()
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) InitializeGroupsLocked() {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
- if _, ok := g.mu.memberships[groupAddress]; ok {
+ for groupAddress, info := range g.memberships {
+ g.initializeNewMemberLocked(groupAddress, &info)
+ g.memberships[groupAddress] = info
+ }
+}
+
+// SendQueuedReportsLocked attempts to send reports for groups that failed to
+// send reports during their last attempt.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() {
+ for groupAddress, info := range g.memberships {
+ switch info.state {
+ case nonMember, delayingMember, idleMember:
+ case pendingMember:
+ // pendingMembers failed to send their initial unsolicited report so try
+ // to send the report and queue the extra unsolicited reports.
+ g.maybeSendInitialReportLocked(groupAddress, &info)
+ case queuedDelayingMember:
+ // queuedDelayingMembers failed to send their delayed reports so try to
+ // send the report and transition them to the idle state.
+ g.maybeSendDelayedReportLocked(groupAddress, &info)
+ default:
+ panic(fmt.Sprintf("unrecognized host state = %d", info.state))
+ }
+ g.memberships[groupAddress] = info
+ }
+}
+
+// JoinGroupLocked handles joining a new group.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address) {
+ if info, ok := g.memberships[groupAddress]; ok {
// The group has already been joined.
- return false
+ info.joins++
+ g.memberships[groupAddress] = info
+ return
}
info := multicastGroupState{
- // There isn't a job scheduled currently, so it's just idle.
- state: idleMember,
- // Joining a group immediately sends a report.
- lastToSendReport: true,
- delayedReportJob: tcpip.NewJob(g.clock, &g.mu, func() {
- info, ok := g.mu.memberships[groupAddress]
+ // Since we just joined the group, its count is 1.
+ joins: 1,
+ // The state will be updated below, if required.
+ state: nonMember,
+ lastToSendReport: false,
+ delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() {
+ if !g.opts.Protocol.Enabled() {
+ panic(fmt.Sprintf("delayed report job fired for group %s while the multicast group protocol is disabled", groupAddress))
+ }
+
+ info, ok := g.memberships[groupAddress]
if !ok {
panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress))
}
- info.lastToSendReport = g.protocol.SendReport(groupAddress) == nil
- info.state = idleMember
- g.mu.memberships[groupAddress] = info
+ g.maybeSendDelayedReportLocked(groupAddress, &info)
+ g.memberships[groupAddress] = info
}),
}
- // As per RFC 2236 section 3 page 5 (for IGMPv2),
- //
- // When a host joins a multicast group, it should immediately transmit an
- // unsolicited Version 2 Membership Report for that group" ... "it is
- // recommended that it be repeated".
- //
- // As per RFC 2710 section 4 page 6 (for MLDv1),
- //
- // When a node starts listening to a multicast address on an interface,
- // it should immediately transmit an unsolicited Report for that address
- // on that interface, in case it is the first listener on the link. To
- // cover the possibility of the initial Report being lost or damaged, it
- // is recommended that it be repeated once or twice after short delays
- // [Unsolicited Report Interval].
- //
- // TODO(gvisor.dev/issue/4901): Support a configurable number of initial
- // unsolicited reports.
- info.lastToSendReport = g.protocol.SendReport(groupAddress) == nil
- g.setDelayTimerForAddressRLocked(groupAddress, &info, g.maxUnsolicitedReportDelay)
- g.mu.memberships[groupAddress] = info
- return true
+ if g.opts.Protocol.Enabled() {
+ g.initializeNewMemberLocked(groupAddress, &info)
+ }
+
+ g.memberships[groupAddress] = info
}
-// LeaveGroup handles leaving the group.
-func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) {
- g.mu.Lock()
- defer g.mu.Unlock()
+// IsLocallyJoinedRLocked returns true if the group is locally joined.
+//
+// Precondition: g.protocolMU must be read locked.
+func (g *GenericMulticastProtocolState) IsLocallyJoinedRLocked(groupAddress tcpip.Address) bool {
+ _, ok := g.memberships[groupAddress]
+ return ok
+}
- info, ok := g.mu.memberships[groupAddress]
+// LeaveGroupLocked handles leaving the group.
+//
+// Returns false if the group is not currently joined.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Address) bool {
+ info, ok := g.memberships[groupAddress]
if !ok {
- return
+ return false
}
- info.delayedReportJob.Cancel()
- delete(g.mu.memberships, groupAddress)
- if info.lastToSendReport {
- // Okay to ignore the error here as if packet write failed, the multicast
- // routers will eventually drop our membership anyways. If the interface is
- // being disabled or removed, the generic multicast protocol's should be
- // cleared eventually.
- //
- // As per RFC 2236 section 3 page 5 (for IGMPv2),
- //
- // When a router receives a Report, it adds the group being reported to
- // the list of multicast group memberships on the network on which it
- // received the Report and sets the timer for the membership to the
- // [Group Membership Interval]. Repeated Reports refresh the timer. If
- // no Reports are received for a particular group before this timer has
- // expired, the router assumes that the group has no local members and
- // that it need not forward remotely-originated multicasts for that
- // group onto the attached network.
- //
- // As per RFC 2710 section 4 page 5 (for MLDv1),
- //
- // When a router receives a Report from a link, if the reported address
- // is not already present in the router's list of multicast address
- // having listeners on that link, the reported address is added to the
- // list, its timer is set to [Multicast Listener Interval], and its
- // appearance is made known to the router's multicast routing component.
- // If a Report is received for a multicast address that is already
- // present in the router's list, the timer for that address is reset to
- // [Multicast Listener Interval]. If an address's timer expires, it is
- // assumed that there are no longer any listeners for that address
- // present on the link, so it is deleted from the list and its
- // disappearance is made known to the multicast routing component.
- //
- // The requirement to send a leave message is also optional (it MAY be
- // skipped):
- //
- // As per RFC 2236 section 6 page 8 (for IGMPv2),
- //
- // "send leave" for the group on the interface. If the interface
- // state says the Querier is running IGMPv1, this action SHOULD be
- // skipped. If the flag saying we were the last host to report is
- // cleared, this action MAY be skipped. The Leave Message is sent to
- // the ALL-ROUTERS group (224.0.0.2).
- //
- // As per RFC 2710 section 5 page 8 (for MLDv1),
- //
- // "send done" for the address on the interface. If the flag saying
- // we were the last node to report is cleared, this action MAY be
- // skipped. The Done message is sent to the link-scope all-routers
- // address (FF02::2).
- _ = g.protocol.SendLeave(groupAddress)
+ if info.joins == 0 {
+ panic(fmt.Sprintf("tried to leave group %s with a join count of 0", groupAddress))
+ }
+ info.joins--
+ if info.joins != 0 {
+ // If we still have outstanding joins, then do nothing further.
+ g.memberships[groupAddress] = info
+ return true
}
+
+ g.transitionToNonMemberLocked(groupAddress, &info)
+ delete(g.memberships, groupAddress)
+ return true
}
-// HandleQuery handles a query message with the specified maximum response time.
+// HandleQueryLocked handles a query message with the specified maximum response
+// time.
//
// If the group address is unspecified, then reports will be scheduled for all
// joined groups.
//
// Report(s) will be scheduled to be sent after a random duration between 0 and
// the maximum response time.
-func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, maxResponseTime time.Duration) {
- g.mu.Lock()
- defer g.mu.Unlock()
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
// As per RFC 2236 section 2.4 (for IGMPv2),
//
@@ -263,23 +394,26 @@ func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address,
// when sending a Multicast-Address-Specific Query.
if groupAddress.Unspecified() {
// This is a general query as the group address is unspecified.
- for groupAddress, info := range g.mu.memberships {
+ for groupAddress, info := range g.memberships {
g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime)
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
}
- } else if info, ok := g.mu.memberships[groupAddress]; ok {
+ } else if info, ok := g.memberships[groupAddress]; ok {
g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime)
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
}
}
-// HandleReport handles a report message.
+// HandleReportLocked handles a report message.
//
// If the report is for a joined group, any active delayed report will be
// cancelled and the host state for the group transitions to idle.
-func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) {
- g.mu.Lock()
- defer g.mu.Unlock()
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
// As per RFC 2236 section 3 pages 3-4 (for IGMPv2),
//
@@ -293,18 +427,213 @@ func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address)
// multicast address while it has a timer running for that same address
// on that interface, it stops its timer and does not send a Report for
// that address, thus suppressing duplicate reports on the link.
- if info, ok := g.mu.memberships[groupAddress]; ok {
+ if info, ok := g.memberships[groupAddress]; ok && info.state.isDelayingMember() {
info.delayedReportJob.Cancel()
info.lastToSendReport = false
info.state = idleMember
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
+ }
+}
+
+// initializeNewMemberLocked initializes a new group membership.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state != nonMember {
+ panic(fmt.Sprintf("host must be in non-member state to be initialized; group = %s, state = %d", groupAddress, info.state))
+ }
+
+ info.lastToSendReport = false
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ info.state = idleMember
+ return
+ }
+
+ info.state = pendingMember
+ g.maybeSendInitialReportLocked(groupAddress, info)
+}
+
+// maybeSendInitialReportLocked attempts to start transmission of the initial
+// set of reports after newly joining a group.
+//
+// Host must be in pending member state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) maybeSendInitialReportLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state != pendingMember {
+ panic(fmt.Sprintf("host must be in pending member state to send initial reports; group = %s, state = %d", groupAddress, info.state))
+ }
+
+ // As per RFC 2236 section 3 page 5 (for IGMPv2),
+ //
+ // When a host joins a multicast group, it should immediately transmit an
+ // unsolicited Version 2 Membership Report for that group [...] it is
+ // recommended that it be repeated [...].
+ //
+ // As per RFC 2710 section 4 page 6 (for MLDv1),
+ //
+ // When a node starts listening to a multicast address on an interface,
+ // it should immediately transmit an unsolicited Report for that address
+ // on that interface, in case it is the first listener on the link. To
+ // cover the possibility of the initial Report being lost or damaged, it
+ // is recommended that it be repeated once or twice after short delays
+ // [Unsolicited Report Interval].
+ //
+ // TODO(gvisor.dev/issue/4901): Support a configurable number of initial
+ // unsolicited reports.
+ sent, err := g.opts.Protocol.SendReport(groupAddress)
+ if err == nil && sent {
+ info.lastToSendReport = true
+ g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay)
+ }
+}
+
+// maybeSendDelayedReportLocked attempts to send the delayed report.
+//
+// Host must be in pending, delaying or queued delaying member state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) maybeSendDelayedReportLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if !info.state.isDelayingMember() {
+ panic(fmt.Sprintf("host must be in delaying or queued delaying member state to send delayed reports; group = %s, state = %d", groupAddress, info.state))
+ }
+
+ sent, err := g.opts.Protocol.SendReport(groupAddress)
+ if err == nil && sent {
+ info.lastToSendReport = true
+ info.state = idleMember
+ } else {
+ info.state = queuedDelayingMember
}
}
+// maybeSendLeave attempts to send a leave message.
+func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) {
+ if !g.opts.Protocol.Enabled() || !lastToSendReport {
+ return
+ }
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ return
+ }
+
+ // Okay to ignore the error here as if packet write failed, the multicast
+ // routers will eventually drop our membership anyways. If the interface is
+ // being disabled or removed, the generic multicast protocol's state should
+ // be cleared eventually.
+ //
+ // As per RFC 2236 section 3 page 5 (for IGMPv2),
+ //
+ // When a router receives a Report, it adds the group being reported to
+ // the list of multicast group memberships on the network on which it
+ // received the Report and sets the timer for the membership to the
+ // [Group Membership Interval]. Repeated Reports refresh the timer. If
+ // no Reports are received for a particular group before this timer has
+ // expired, the router assumes that the group has no local members and
+ // that it need not forward remotely-originated multicasts for that
+ // group onto the attached network.
+ //
+ // As per RFC 2710 section 4 page 5 (for MLDv1),
+ //
+ // When a router receives a Report from a link, if the reported address
+ // is not already present in the router's list of multicast address
+ // having listeners on that link, the reported address is added to the
+ // list, its timer is set to [Multicast Listener Interval], and its
+ // appearance is made known to the router's multicast routing component.
+ // If a Report is received for a multicast address that is already
+ // present in the router's list, the timer for that address is reset to
+ // [Multicast Listener Interval]. If an address's timer expires, it is
+ // assumed that there are no longer any listeners for that address
+ // present on the link, so it is deleted from the list and its
+ // disappearance is made known to the multicast routing component.
+ //
+ // The requirement to send a leave message is also optional (it MAY be
+ // skipped):
+ //
+ // As per RFC 2236 section 6 page 8 (for IGMPv2),
+ //
+ // "send leave" for the group on the interface. If the interface
+ // state says the Querier is running IGMPv1, this action SHOULD be
+ // skipped. If the flag saying we were the last host to report is
+ // cleared, this action MAY be skipped. The Leave Message is sent to
+ // the ALL-ROUTERS group (224.0.0.2).
+ //
+ // As per RFC 2710 section 5 page 8 (for MLDv1),
+ //
+ // "send done" for the address on the interface. If the flag saying
+ // we were the last node to report is cleared, this action MAY be
+ // skipped. The Done message is sent to the link-scope all-routers
+ // address (FF02::2).
+ _ = g.opts.Protocol.SendLeave(groupAddress)
+}
+
+// transitionToNonMemberLocked transitions the given multicast group to the
+// non-member/listener state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state == nonMember {
+ return
+ }
+
+ info.delayedReportJob.Cancel()
+ g.maybeSendLeave(groupAddress, info.lastToSendReport)
+ info.lastToSendReport = false
+ info.state = nonMember
+}
+
// setDelayTimerForAddressRLocked sets timer to send a delay report.
//
-// Precondition: g.mu MUST be read locked.
+// Precondition: g.protocolMU MUST be read locked.
func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) {
+ if info.state == nonMember {
+ return
+ }
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ return
+ }
+
// As per RFC 2236 section 3 page 3 (for IGMPv2),
//
// If a timer for the group is already running, it is reset to the random
@@ -320,6 +649,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr
// TODO: Reset the timer if time remaining is greater than maxResponseTime.
return
}
+
info.state = delayingMember
info.delayedReportJob.Cancel()
info.delayedReportJob.Schedule(g.calculateDelayTimerDuration(maxResponseTime))
@@ -342,5 +672,5 @@ func (g *GenericMulticastProtocolState) calculateDelayTimerDuration(maxRespTime
if maxRespTime == 0 {
return 0
}
- return time.Duration(g.rand.Int63n(int64(maxRespTime)))
+ return time.Duration(g.opts.Rand.Int63n(int64(maxRespTime)))
}
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
index eb48c0d51..85593f211 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
@@ -20,6 +20,7 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/network/ip"
@@ -29,107 +30,305 @@ const (
addr1 = tcpip.Address("\x01")
addr2 = tcpip.Address("\x02")
addr3 = tcpip.Address("\x03")
+ addr4 = tcpip.Address("\x04")
+
+ maxUnsolicitedReportDelay = time.Second
)
var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil)
-type mockMulticastGroupProtocol struct {
+type mockMulticastGroupProtocolProtectedFields struct {
+ sync.RWMutex
+
+ genericMulticastGroup ip.GenericMulticastProtocolState
sendReportGroupAddrCount map[tcpip.Address]int
- sendLeaveGroupAddr tcpip.Address
+ sendLeaveGroupAddrCount map[tcpip.Address]int
+ makeQueuePackets bool
+ disabled bool
}
-func (m *mockMulticastGroupProtocol) init() {
- m.sendReportGroupAddrCount = make(map[tcpip.Address]int)
- m.sendLeaveGroupAddr = ""
+type mockMulticastGroupProtocol struct {
+ t *testing.T
+
+ mu mockMulticastGroupProtocolProtectedFields
}
-func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) *tcpip.Error {
- m.sendReportGroupAddrCount[groupAddress]++
- return nil
+func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOptions) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.initLocked()
+ opts.Protocol = m
+ m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts)
+}
+
+func (m *mockMulticastGroupProtocol) initLocked() {
+ m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int)
+ m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int)
+}
+
+func (m *mockMulticastGroupProtocol) setEnabled(v bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.disabled = !v
+}
+
+func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.makeQueuePackets = v
+}
+
+func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.JoinGroupLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ return m.mu.genericMulticastGroup.LeaveGroupLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.HandleReportLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime)
+}
+
+func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) makeAllNonMember() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.MakeAllNonMemberLocked()
+}
+
+func (m *mockMulticastGroupProtocol) initializeGroups() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.InitializeGroupsLocked()
+}
+
+func (m *mockMulticastGroupProtocol) sendQueuedReports() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.SendQueuedReportsLocked()
+}
+
+// Enabled implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be read locked.
+func (m *mockMulticastGroupProtocol) Enabled() bool {
+ if m.mu.TryLock() {
+ m.mu.Unlock()
+ m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled")
+ }
+
+ return !m.mu.disabled
}
+// SendReport implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be locked.
+func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+ if m.mu.TryLock() {
+ m.mu.Unlock()
+ m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
+ }
+ if m.mu.TryRLock() {
+ m.mu.RUnlock()
+ m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
+ }
+
+ m.mu.sendReportGroupAddrCount[groupAddress]++
+ return !m.mu.makeQueuePackets, nil
+}
+
+// SendLeave implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be locked.
func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
- m.sendLeaveGroupAddr = groupAddress
+ if m.mu.TryLock() {
+ m.mu.Unlock()
+ m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
+ }
+ if m.mu.TryRLock() {
+ m.mu.RUnlock()
+ m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
+ }
+
+ m.mu.sendLeaveGroupAddrCount[groupAddress]++
return nil
}
-func checkProtocol(mgp *mockMulticastGroupProtocol, sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddr tcpip.Address) string {
- sendReportGroupAddressesMap := make(map[tcpip.Address]int)
+func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ sendReportGroupAddrCount := make(map[tcpip.Address]int)
for _, a := range sendReportGroupAddresses {
- sendReportGroupAddressesMap[a] = 1
+ sendReportGroupAddrCount[a] = 1
+ }
+
+ sendLeaveGroupAddrCount := make(map[tcpip.Address]int)
+ for _, a := range sendLeaveGroupAddresses {
+ sendLeaveGroupAddrCount[a] = 1
}
- diff := cmp.Diff(mockMulticastGroupProtocol{
- sendReportGroupAddrCount: sendReportGroupAddressesMap,
- sendLeaveGroupAddr: sendLeaveGroupAddr,
- }, *mgp, cmp.AllowUnexported(mockMulticastGroupProtocol{}))
- mgp.init()
+ diff := cmp.Diff(
+ &mockMulticastGroupProtocol{
+ mu: mockMulticastGroupProtocolProtectedFields{
+ sendReportGroupAddrCount: sendReportGroupAddrCount,
+ sendLeaveGroupAddrCount: sendLeaveGroupAddrCount,
+ },
+ },
+ m,
+ cmp.AllowUnexported(mockMulticastGroupProtocol{}),
+ cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}),
+ // Ignore every field except the report/leave send counts (the mutex, t,
+ // makeQueuePackets, disabled and genericMulticastGroup are skipped).
+ cmp.FilterPath(
+ func(p cmp.Path) bool {
+ switch p.Last().String() {
+ case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup":
+ return true
+ }
+ return false
+ },
+ cmp.Ignore(),
+ ),
+ )
+ m.initLocked()
return diff
}
func TestJoinGroup(t *testing.T) {
- const maxUnsolicitedReportDelay = time.Second
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ shouldSendReports bool
+ }{
+ {
+ name: "Normal group",
+ addr: addr1,
+ shouldSendReports: true,
+ },
+ {
+ name: "All-nodes group",
+ addr: addr2,
+ shouldSendReports: false,
+ },
+ }
- var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
- clock := faketime.NewManualClock()
- g.Init(rand.New(rand.NewSource(0)), clock, &mgp, maxUnsolicitedReportDelay)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
- // Joining a group should send a report immediately and another after
- // a random interval between 0 and the maximum unsolicited report delay.
- if !g.JoinGroup(addr1) {
- t.Errorf("got g.JoinGroup(%s) = false, want = true", addr1)
- }
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(0)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr2,
+ })
+
+ // Joining a group should send a report immediately and another after
+ // a random interval between 0 and the maximum unsolicited report delay.
+ mgp.joinGroup(test.addr)
+ if test.shouldSendReports {
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
- clock.Advance(maxUnsolicitedReportDelay)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
- // Should have no more messages to send.
- clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
}
}
func TestLeaveGroup(t *testing.T) {
- const maxUnsolicitedReportDelay = time.Second
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ shouldSendMessages bool
+ }{
+ {
+ name: "Normal group",
+ addr: addr1,
+ shouldSendMessages: true,
+ },
+ {
+ name: "All-nodes group",
+ addr: addr2,
+ shouldSendMessages: false,
+ },
+ }
- var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
- clock := faketime.NewManualClock()
- g.Init(rand.New(rand.NewSource(1)), clock, &mgp, maxUnsolicitedReportDelay)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
- if !g.JoinGroup(addr1) {
- t.Fatalf("got g.JoinGroup(%s) = false, want = true", addr1)
- }
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(1)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr2,
+ })
+
+ mgp.joinGroup(test.addr)
+ if test.shouldSendMessages {
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
- // Leaving a group should send a leave report immediately and cancel any
- // delayed reports.
- g.LeaveGroup(addr1)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, addr1 /* sendLeaveGroupAddr */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
+ // Leaving a group should send a leave report immediately and cancel any
+ // delayed reports.
+ {
- // Should have no more messages to send.
- clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ if !mgp.leaveGroup(test.addr) {
+ t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr)
+ }
+ }
+ if test.shouldSendMessages {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ //
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
}
}
func TestHandleReport(t *testing.T) {
- const maxUnsolicitedReportDelay = time.Second
-
tests := []struct {
name string
reportAddr tcpip.Address
@@ -151,46 +350,56 @@ func TestHandleReport(t *testing.T) {
expectReportsFor: []tcpip.Address{addr2},
},
{
- name: "Specified other",
+ name: "Specified all-nodes",
reportAddr: addr3,
expectReportsFor: []tcpip.Address{addr1, addr2},
},
+ {
+ name: "Specified other",
+ reportAddr: addr4,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
+ mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- g.Init(rand.New(rand.NewSource(2)), clock, &mgp, maxUnsolicitedReportDelay)
- if !g.JoinGroup(addr1) {
- t.Fatalf("got g.JoinGroup(%s) = false, want = true", addr1)
- }
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(2)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ mgp.joinGroup(addr1)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- if !g.JoinGroup(addr2) {
- t.Fatalf("got g.JoinGroup(%s) = false, want = true", addr2)
+ mgp.joinGroup(addr2)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
+ mgp.joinGroup(addr3)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Receiving a report for a group we have a timer scheduled for should
// cancel our delayed report timer for the group.
- g.HandleReport(test.reportAddr)
+ mgp.handleReport(test.reportAddr)
if len(test.expectReportsFor) != 0 {
+ // Generic multicast protocol timers are expected to take the job mutex.
clock.Advance(maxUnsolicitedReportDelay)
- if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
+ if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
// Should have no more messages to send.
clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
})
@@ -198,8 +407,6 @@ func TestHandleReport(t *testing.T) {
}
func TestHandleQuery(t *testing.T) {
- const maxUnsolicitedReportDelay = time.Second
-
tests := []struct {
name string
queryAddr tcpip.Address
@@ -225,70 +432,375 @@ func TestHandleQuery(t *testing.T) {
expectReportsFor: []tcpip.Address{addr1},
},
{
- name: "Specified other",
+ name: "Specified all-nodes",
queryAddr: addr3,
maxDelay: 3,
expectReportsFor: nil,
},
+ {
+ name: "Specified other",
+ queryAddr: addr4,
+ maxDelay: 4,
+ expectReportsFor: nil,
+ },
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
+ mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- g.Init(rand.New(rand.NewSource(3)), clock, &mgp, maxUnsolicitedReportDelay)
- if !g.JoinGroup(addr1) {
- t.Fatalf("got g.JoinGroup(%s) = false, want = true", addr1)
- }
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ mgp.joinGroup(addr1)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- if !g.JoinGroup(addr2) {
- t.Fatalf("got g.JoinGroup(%s) = false, want = true", addr2)
+ mgp.joinGroup(addr2)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
+ mgp.joinGroup(addr3)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
+ // Generic multicast protocol timers are expected to take the job mutex.
clock.Advance(maxUnsolicitedReportDelay)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Receiving a query should make us schedule a new delayed report if it
// is a query directed at us or a general query.
- g.HandleQuery(test.queryAddr, test.maxDelay)
+ mgp.handleQuery(test.queryAddr, test.maxDelay)
if len(test.expectReportsFor) != 0 {
clock.Advance(test.maxDelay)
- if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
+ if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
// Should have no more messages to send.
clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
})
}
}
-func TestDoubleJoinGroup(t *testing.T) {
- var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
+func TestJoinCount(t *testing.T) {
+ mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- g.Init(rand.New(rand.NewSource(4)), clock, &mgp, time.Second)
- if !g.JoinGroup(addr1) {
- t.Fatalf("got g.JoinGroup(%s) = false, want = true", addr1)
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(4)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: time.Second,
+ })
+
+ // Set the join count to 2 for a group.
+ mgp.joinGroup(addr1)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
+ }
+ // Only the first join should trigger a report to be sent.
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr1)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Group should still be considered joined after leaving once.
+ if !mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1)
}
+ if !mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
+ }
+ // A leave report should only be sent once the join count reaches 0.
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leaving once more should actually remove us from the group.
+ if !mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1)
+ }
+ if mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Group should no longer be joined so we should not have anything to
+ // leave.
+ if mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1)
+ }
+ if mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should have no more messages to send.
+ //
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
+
+func TestMakeAllNonMemberAndInitialize(t *testing.T) {
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
- // Joining the same group twice should fail.
- if g.JoinGroup(addr1) {
- t.Errorf("got g.JoinGroup(%s) = true, want = false", addr1)
+ mgp.joinGroup(addr1)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr2)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr3)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should send the leave reports for each but still consider them locally
+ // joined.
+ mgp.makeAllNonMember()
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ for _, group := range []tcpip.Address{addr1, addr2, addr3} {
+ if !mgp.isLocallyJoined(group) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group)
+ }
+ }
+
+ // Should send the initial set of unsolicited reports.
+ mgp.initializeGroups()
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
+
+// TestGroupStateNonMember tests that groups do not send packets when in the
+// non-member state, but are still considered locally joined.
+func TestGroupStateNonMember(t *testing.T) {
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ })
+ mgp.setEnabled(false)
+
+ // Joining groups should not send any reports.
+ mgp.joinGroup(addr1)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr2)
+ if !mgp.isLocallyJoined(addr2) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a query should not send any reports.
+ mgp.handleQuery(addr1, time.Nanosecond)
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Leaving groups should not send any leave messages.
+ if !mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1)
+ }
+ if mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
+
+func TestQueuedPackets(t *testing.T) {
+ clock := faketime.NewManualClock()
+ mgp := mockMulticastGroupProtocol{t: t}
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(4)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ })
+
+ // Joining should trigger a SendReport, but mgp should report that we did not
+ // send the packet.
+ mgp.setQueuePackets(true)
+ mgp.joinGroup(addr1)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // The delayed report timer should have been cancelled since we did not send
+ // the initial report earlier.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Mock being able to successfully send the report.
+ mgp.setQueuePackets(false)
+ mgp.sendQueuedReports()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // The delayed report (sent after the initial report) should now be sent.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send (we should be idle).
+ mgp.sendQueuedReports()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receive a query but mock being unable to send reports again.
+ mgp.setQueuePackets(true)
+ mgp.handleQuery(addr1, time.Nanosecond)
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Mock being able to send reports again - we should have a packet queued to
+ // send.
+ mgp.setQueuePackets(false)
+ mgp.sendQueuedReports()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send.
+ mgp.sendQueuedReports()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receive a query again, but mock being unable to send reports.
+ mgp.setQueuePackets(true)
+ mgp.handleQuery(addr1, time.Nanosecond)
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a report should transition us into the idle member state,
+ // even if we had a packet queued. We should no longer have any packets to
+ // send.
+ mgp.handleReport(addr1)
+ mgp.sendQueuedReports()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // When we fail to send the initial set of reports, incoming reports should
+ // not prevent a newly joined group's reports from being sent.
+ mgp.setQueuePackets(true)
+ mgp.joinGroup(addr2)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.handleReport(addr2)
+ // Attempting to send queued reports while still unable to send reports should
+ // not change the host state.
+ mgp.sendQueuedReports()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // Mock being able to successfully send the report.
+ mgp.setQueuePackets(false)
+ mgp.sendQueuedReports()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // The delayed report (sent after the initial report) should now be sent.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send.
+ mgp.sendQueuedReports()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 787399e08..3005973d7 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -203,7 +203,7 @@ func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
panic("not implemented")
}
-func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
+func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
@@ -219,7 +219,7 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
}
-func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
+func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
@@ -344,11 +344,11 @@ func TestSourceAddressValidation(t *testing.T) {
pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: localIPv6Addr,
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: localIPv6Addr,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -550,7 +550,7 @@ func TestIPv4Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{
Protocol: 123,
TTL: 123,
TOS: stack.DefaultTOS,
@@ -619,11 +619,11 @@ func TestReceive(t *testing.T) {
view := buffer.NewView(header.IPv6MinimumSize + payloadLen)
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
- PayloadLength: payloadLen,
- NextHeader: 10,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: remoteIPv6Addr,
- DstAddr: localIPv6Addr,
+ PayloadLength: payloadLen,
+ TransportProtocol: 10,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: localIPv6Addr,
})
// Make payload be non-zero.
@@ -933,7 +933,7 @@ func TestIPv6Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{
Protocol: 123,
TTL: 123,
TOS: stack.DefaultTOS,
@@ -993,11 +993,11 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Create the outer IPv6 header.
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 20,
- SrcAddr: outerSrcAddr,
- DstAddr: localIPv6Addr,
+ PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 20,
+ SrcAddr: outerSrcAddr,
+ DstAddr: localIPv6Addr,
})
// Create the ICMP header.
@@ -1007,28 +1007,27 @@ func TestIPv6ReceiveControl(t *testing.T) {
icmp.SetIdent(0xdead)
icmp.SetSequence(0xbeef)
- // Create the inner IPv6 header.
- ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
- ip.Encode(&header.IPv6Fields{
- PayloadLength: 100,
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: localIPv6Addr,
- DstAddr: remoteIPv6Addr,
- })
-
+ var extHdrs header.IPv6ExtHdrSerializer
// Build the fragmentation header if needed.
if c.fragmentOffset != nil {
- ip.SetNextHeader(header.IPv6FragmentHeader)
- frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:])
- frag.Encode(&header.IPv6FragmentFields{
- NextHeader: 10,
+ extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{
FragmentOffset: *c.fragmentOffset,
M: true,
Identification: 0x12345678,
})
}
+ // Create the inner IPv6 header.
+ ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: 100,
+ TransportProtocol: 10,
+ HopLimit: 20,
+ SrcAddr: localIPv6Addr,
+ DstAddr: remoteIPv6Addr,
+ ExtensionHeaders: extHdrs,
+ })
+
// Make payload be non-zero.
for i := dataOffset; i < len(view); i++ {
view[i] = uint8(i)
@@ -1089,7 +1088,19 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
dataBuf := [dataLen]byte{1, 2, 3, 4}
data := dataBuf[:]
- ipv4Options := header.IPv4Options{0, 1, 0, 1}
+ ipv4Options := header.IPv4OptionsSerializer{
+ &header.IPv4SerializableListEndOption{},
+ &header.IPv4SerializableNOPOption{},
+ &header.IPv4SerializableListEndOption{},
+ &header.IPv4SerializableNOPOption{},
+ }
+
+ expectOptions := header.IPv4Options{
+ byte(header.IPv4OptionListEndType),
+ byte(header.IPv4OptionNOPType),
+ byte(header.IPv4OptionListEndType),
+ byte(header.IPv4OptionNOPType),
+ }
ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4}
ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:]
@@ -1239,7 +1250,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ipHdrLen := header.IPv4MinimumSize + ipv4Options.SizeWithPadding()
+ ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
totalLen := ipHdrLen + len(data)
hdr := buffer.NewPrependable(totalLen)
if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
@@ -1262,7 +1273,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
netHdr := pkt.NetworkHeader()
- hdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
if len(netHdr.View()) != hdrLen {
t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
}
@@ -1272,7 +1283,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
checker.DstAddr(remoteIPv4Addr),
checker.IPv4HeaderLength(hdrLen),
checker.IPFullLength(uint16(hdrLen+len(data))),
- checker.IPv4Options(ipv4Options),
+ checker.IPv4Options(expectOptions),
checker.IPPayload(data),
)
},
@@ -1284,7 +1295,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.SizeWithPadding()))
+ ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length()))
ip.Encode(&header.IPv4Fields{
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
@@ -1303,7 +1314,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
netHdr := pkt.NetworkHeader()
- hdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
if len(netHdr.View()) != hdrLen {
t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
}
@@ -1313,7 +1324,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
checker.DstAddr(remoteIPv4Addr),
checker.IPv4HeaderLength(hdrLen),
checker.IPFullLength(uint16(hdrLen+len(data))),
- checker.IPv4Options(ipv4Options),
+ checker.IPv4Options(expectOptions),
checker.IPPayload(data),
)
},
@@ -1332,10 +1343,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ TransportProtocol: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return hdr.View().ToVectorisedView()
},
@@ -1375,10 +1386,12 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier),
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ // NB: we're lying about transport protocol here to verify the raw
+ // fragment header bytes.
+ TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier),
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return hdr.View().ToVectorisedView()
},
@@ -1410,10 +1423,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ TransportProtocol: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return buffer.View(ip).ToVectorisedView()
},
@@ -1445,10 +1458,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ TransportProtocol: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 488945226..8e392f86c 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -63,7 +63,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
stats := e.protocol.stack.Stats()
- received := stats.ICMP.V4PacketsReceived
+ received := stats.ICMP.V4.PacketsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
// full explanation.
@@ -130,7 +130,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4Echo:
received.Echo.Increment()
- sent := stats.ICMP.V4PacketsSent
+ sent := stats.ICMP.V4.PacketsSent
if !e.protocol.stack.AllowICMPMessage() {
sent.RateLimited.Increment()
return
@@ -379,7 +379,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
}
defer route.Release()
- sent := p.stack.Stats().ICMP.V4PacketsSent
+ sent := p.stack.Stats().ICMP.V4.PacketsSent
if !p.stack.AllowICMPMessage() {
sent.RateLimited.Increment()
return nil
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
index c9bf117de..da88d65d1 100644
--- a/pkg/tcpip/network/ipv4/igmp.go
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -16,7 +16,6 @@ package ipv4
import (
"fmt"
- "sync"
"sync/atomic"
"time"
@@ -51,6 +50,19 @@ const (
UnsolicitedReportIntervalMax = 10 * time.Second
)
+// IGMPOptions holds options for IGMP.
+type IGMPOptions struct {
+ // Enabled indicates whether IGMP will be performed.
+ //
+ // When enabled, IGMP may transmit IGMP report and leave messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // IGMP packets.
+ //
+ // This field is ignored and is always assumed to be false for interfaces
+ // without neighbouring nodes (e.g. loopback).
+ Enabled bool
+}
+
var _ ip.MulticastGroupProtocol = (*igmpState)(nil)
// igmpState is the per-interface IGMP state.
@@ -60,6 +72,8 @@ type igmpState struct {
// The IPv4 endpoint this igmpState is for.
ep *endpoint
+ genericMulticastProtocol ip.GenericMulticastProtocolState
+
// igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from
// RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1
// Membership Reports in response to its Queries, and will not pay
@@ -73,20 +87,23 @@ type igmpState struct {
// when false.
igmpV1Present uint32
- mu struct {
- sync.RWMutex
-
- genericMulticastProtocol ip.GenericMulticastProtocolState
+ // igmpV1Job is scheduled when this interface receives an IGMPv1 style
+ // message, upon expiration the igmpV1Present flag is cleared.
+ // igmpV1Job may not be nil once igmpState is initialized.
+ igmpV1Job *tcpip.Job
+}
- // igmpV1Job is scheduled when this interface receives an IGMPv1 style
- // message, upon expiration the igmpV1Present flag is cleared.
- // igmpV1Job may not be nil once igmpState is initialized.
- igmpV1Job *tcpip.Job
- }
+// Enabled implements ip.MulticastGroupProtocol.
+func (igmp *igmpState) Enabled() bool {
+ // No need to perform IGMP on loopback interfaces since they don't have
+ // neighbouring nodes.
+ return igmp.ep.protocol.options.IGMP.Enabled && !igmp.ep.nic.IsLoopback() && igmp.ep.Enabled()
}
// SendReport implements ip.MulticastGroupProtocol.
-func (igmp *igmpState) SendReport(groupAddress tcpip.Address) *tcpip.Error {
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
igmpType := header.IGMPv2MembershipReport
if igmp.v1Present() {
igmpType = header.IGMPv1MembershipReport
@@ -95,6 +112,8 @@ func (igmp *igmpState) SendReport(groupAddress tcpip.Address) *tcpip.Error {
}
// SendLeave implements ip.MulticastGroupProtocol.
+//
+// Precondition: igmp.ep.mu must be read locked.
func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
// As per RFC 2236 Section 6, Page 8: "If the interface state says the
// Querier is running IGMPv1, this action SHOULD be skipped. If the flag
@@ -103,22 +122,32 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
if igmp.v1Present() {
return nil
}
- return igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup)
+ _, err := igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup)
+ return err
}
// init sets up an igmpState struct, and is required to be called before using
// a new igmpState.
+//
+// Must only be called once for the lifetime of igmp.
func (igmp *igmpState) init(ep *endpoint) {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
igmp.ep = ep
- igmp.mu.genericMulticastProtocol.Init(ep.protocol.stack.Rand(), ep.protocol.stack.Clock(), igmp, UnsolicitedReportIntervalMax)
+ igmp.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{
+ Rand: ep.protocol.stack.Rand(),
+ Clock: ep.protocol.stack.Clock(),
+ Protocol: igmp,
+ MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
+ AllNodesAddress: header.IPv4AllSystems,
+ })
igmp.igmpV1Present = igmpV1PresentDefault
- igmp.mu.igmpV1Job = igmp.ep.protocol.stack.NewJob(&igmp.mu, func() {
+ igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() {
igmp.setV1Present(false)
})
}
+// handleIGMP handles an IGMP packet.
+//
+// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) {
stats := igmp.ep.protocol.stack.Stats()
received := stats.IGMP.PacketsReceived
@@ -188,32 +217,34 @@ func (igmp *igmpState) setV1Present(v bool) {
}
}
+// handleMembershipQuery handles a membership query.
+//
+// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
-
// As per RFC 2236 Section 6, Page 10: If the maximum response time is zero
// then change the state to note that an IGMPv1 router is present and
// schedule the query received Job.
- if maxRespTime == 0 {
- igmp.mu.igmpV1Job.Cancel()
- igmp.mu.igmpV1Job.Schedule(v1RouterPresentTimeout)
+ if maxRespTime == 0 && igmp.Enabled() {
+ igmp.igmpV1Job.Cancel()
+ igmp.igmpV1Job.Schedule(v1RouterPresentTimeout)
igmp.setV1Present(true)
maxRespTime = v1MaxRespTime
}
- igmp.mu.genericMulticastProtocol.HandleQuery(groupAddress, maxRespTime)
+ igmp.genericMulticastProtocol.HandleQueryLocked(groupAddress, maxRespTime)
}
+// handleMembershipReport handles a membership report.
+//
+// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
- igmp.mu.genericMulticastProtocol.HandleReport(groupAddress)
+ igmp.genericMulticastProtocol.HandleReportLocked(groupAddress)
}
-// writePacket assembles and sends an IGMP packet with the provided fields,
-// incrementing the provided stat counter on success.
-func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) *tcpip.Error {
+// writePacket assembles and sends an IGMP packet.
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, *tcpip.Error) {
igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize))
igmpData.SetType(igmpType)
igmpData.SetGroupAddress(groupAddress)
@@ -224,36 +255,37 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
Data: buffer.View(igmpData).ToVectorisedView(),
})
- // TODO(gvisor.dev/issue/4888): We should not use the unspecified address,
- // rather we should select an appropriate local address.
- r := stack.Route{
- LocalAddress: header.IPv4Any,
- RemoteAddress: destAddress,
+ addressEndpoint := igmp.ep.acquireOutgoingPrimaryAddressRLocked(destAddress, false /* allowExpired */)
+ if addressEndpoint == nil {
+ return false, nil
}
- igmp.ep.addIPHeader(&r, pkt, stack.NetworkHeaderParams{
+ localAddr := addressEndpoint.AddressWithPrefix().Address
+ addressEndpoint.DecRef()
+ addressEndpoint = nil
+ igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.IGMPProtocolNumber,
TTL: header.IGMPTTL,
TOS: stack.DefaultTOS,
+ }, header.IPv4OptionsSerializer{
+ &header.IPv4SerializableRouterAlertOption{},
})
- // TODO(b/162198658): set the ROUTER_ALERT option when sending Host
- // Membership Reports.
- sent := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent
- if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
- sent.Dropped.Increment()
- return err
+ sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent
+ if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
+ sentStats.Dropped.Increment()
+ return false, err
}
switch igmpType {
case header.IGMPv1MembershipReport:
- sent.V1MembershipReport.Increment()
+ sentStats.V1MembershipReport.Increment()
case header.IGMPv2MembershipReport:
- sent.V2MembershipReport.Increment()
+ sentStats.V2MembershipReport.Increment()
case header.IGMPLeaveGroup:
- sent.LeaveGroup.Increment()
+ sentStats.LeaveGroup.Increment()
default:
panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType))
}
- return nil
+ return true, nil
}
// joinGroup handles adding a new group to the membership map, setting up the
@@ -262,25 +294,52 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
//
// If the group already exists in the membership map, returns
// tcpip.ErrDuplicateAddress.
-func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) *tcpip.Error {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) {
+ igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress)
+}
- // JoinGroup returns false if we have already joined the group.
- if !igmp.mu.genericMulticastProtocol.JoinGroup(groupAddress) {
- return tcpip.ErrDuplicateAddress
- }
- return nil
+// isInGroup returns true if the specified group has been joined locally.
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool {
+ return igmp.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress)
}
// leaveGroup handles removing the group from the membership map, cancels any
// delay timers associated with that group, and sends the Leave Group message
// if required.
//
-// If the group does not exist in the membership map, this function will
-// silently return.
-func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
- igmp.mu.genericMulticastProtocol.LeaveGroup(groupAddress)
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
+ // LeaveGroupLocked returns false only if the group was not joined.
+ if igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress) {
+ return nil
+ }
+
+ return tcpip.ErrBadLocalAddress
+}
+
+// softLeaveAll leaves all groups from the perspective of IGMP, but remains
+// joined locally.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) softLeaveAll() {
+ igmp.genericMulticastProtocol.MakeAllNonMemberLocked()
+}
+
+// initializeAll attempts to initialize the IGMP state for each group that has
+// been joined locally.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) initializeAll() {
+ igmp.genericMulticastProtocol.InitializeGroupsLocked()
+}
+
+// sendQueuedReports attempts to send any reports that are queued for sending.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) sendQueuedReports() {
+ igmp.genericMulticastProtocol.SendQueuedReportsLocked()
}
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
index 4873a336f..1ee573ac8 100644
--- a/pkg/tcpip/network/ipv4/igmp_test.go
+++ b/pkg/tcpip/network/ipv4/igmp_test.go
@@ -15,7 +15,6 @@
package ipv4_test
import (
- "fmt"
"testing"
"time"
@@ -30,25 +29,12 @@ import (
)
const (
- linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- // endpointAddr = tcpip.Address("\x0a\x00\x00\x02")
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ addr = tcpip.Address("\x0a\x00\x00\x01")
multicastAddr = tcpip.Address("\xe0\x00\x00\x03")
nicID = 1
)
-var (
- // unsolicitedReportIntervalMaxTenthSec is the maximum amount of time the NIC
- // will wait before sending an unsolicited report after joining a multicast
- // group, in deciseconds.
- unsolicitedReportIntervalMaxTenthSec = func() uint8 {
- const decisecond = time.Second / 10
- if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 {
- panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax))
- }
- return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond)
- }()
-)
-
// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet
// sent to the provided address with the passed fields set. Raises a t.Error if
// any field does not match.
@@ -57,7 +43,11 @@ func validateIgmpPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.
payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
checker.IPv4(t, payload,
+ checker.SrcAddr(addr),
checker.DstAddr(remoteAddress),
+ // TTL for an IGMP message must be 1 as per RFC 2236 section 2.
+ checker.TTL(1),
+ checker.IPv4RouterAlert(),
checker.IGMP(
checker.IGMPType(igmpType),
checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)),
@@ -75,14 +65,15 @@ func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stac
clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{
- IGMPEnabled: igmpEnabled,
+ IGMP: ipv4.IGMPOptions{
+ Enabled: igmpEnabled,
+ },
})},
Clock: clock,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
-
return e, s, clock
}
@@ -110,344 +101,14 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma
})
}
-// TestIgmpDisabled tests that IGMP is not enabled with a default
-// stack.Options. This also tests that this NIC does not send the IGMP Join
-// Group for the All Hosts group it automatically joins when created.
-func TestIgmpDisabled(t *testing.T) {
- e, s, _ := createStack(t, false)
-
- // This NIC will join the All Hosts group when created. Verify that does not
- // send a report.
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 0 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 0", got)
- }
- p, ok := e.Read()
- if ok {
- t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt)
- }
-
- // Test joining a specific group explicitly and verify that no reports are
- // sent.
- if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("JoinGroup(ipv4.ProtocolNumber, %d, %s) = %s", nicID, multicastAddr, err)
- }
-
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 0 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 0", got)
- }
- p, ok = e.Read()
- if ok {
- t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt)
- }
-
- // Inject a General Membership Query, which is an IGMP Membership Query with
- // a zeroed Group Address (IPv4Any) to verify that it does not reach the
- // handler.
- createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, unsolicitedReportIntervalMaxTenthSec, header.IPv4Any)
-
- if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 0 {
- t.Fatalf("got Membership Queries received = %d, want = 0", got)
- }
- p, ok = e.Read()
- if ok {
- t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt)
- }
-}
-
-// TestIgmpReceivesIGMPMessages tests that the IGMP stack increments packet
-// counters when it receives properly formatted Membership Queries, Membership
-// Reports, and LeaveGroup Messages sent to this address. Note: test includes
-// IGMP header fields that are not explicitly tested in order to inject proper
-// IGMP packets.
-func TestIgmpReceivesIGMPMessages(t *testing.T) {
- tests := []struct {
- name string
- headerType header.IGMPType
- maxRespTime byte
- groupAddress tcpip.Address
- statCounter func(tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter
- }{
- {
- name: "General Membership Query",
- headerType: header.IGMPMembershipQuery,
- maxRespTime: unsolicitedReportIntervalMaxTenthSec,
- groupAddress: header.IPv4Any,
- statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter {
- return stats.MembershipQuery
- },
- },
- {
- name: "IGMPv1 Membership Report",
- headerType: header.IGMPv1MembershipReport,
- maxRespTime: 0,
- groupAddress: header.IPv4AllSystems,
- statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter {
- return stats.V1MembershipReport
- },
- },
- {
- name: "IGMPv2 Membership Report",
- headerType: header.IGMPv2MembershipReport,
- maxRespTime: 0,
- groupAddress: header.IPv4AllSystems,
- statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter {
- return stats.V2MembershipReport
- },
- },
- {
- name: "Leave Group",
- headerType: header.IGMPLeaveGroup,
- maxRespTime: 0,
- groupAddress: header.IPv4AllRoutersGroup,
- statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter {
- return stats.LeaveGroup
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e, s, _ := createStack(t, true)
-
- createAndInjectIGMPPacket(e, test.headerType, test.maxRespTime, test.groupAddress)
-
- if got := test.statCounter(s.Stats().IGMP.PacketsReceived).Value(); got != 1 {
- t.Fatalf("got %s received = %d, want = 1", test.name, got)
- }
- })
- }
-}
-
-// TestIgmpJoinGroup tests that when explicitly joining a multicast group, the
-// IGMP stack schedules and sends correct Membership Reports.
-func TestIgmpJoinGroup(t *testing.T) {
- e, s, clock := createStack(t, true)
-
- // Test joining a specific address explicitly and verify a Membership Report
- // is sent immediately.
- if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
- }
-
- p, ok := e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
- }
-
- validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
- if t.Failed() {
- t.FailNow()
- }
-
- // Verify the second Membership Report is sent after a random interval up to
- // the maximum unsolicited report interval.
- p, ok = e.Read()
- if ok {
- t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt)
- }
- clock.Advance(ipv4.UnsolicitedReportIntervalMax)
- p, ok = e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got)
- }
- validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
-}
-
-// TestIgmpLeaveGroup tests that when leaving a previously joined multicast
-// group the IGMP enabled NIC sends the appropriate message.
-func TestIgmpLeaveGroup(t *testing.T) {
- e, s, clock := createStack(t, true)
-
- // Join a group so that it can be left, validate the immediate Membership
- // Report is sent only to the multicast address joined.
- if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
- }
- p, ok := e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
- }
- validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
- if t.Failed() {
- t.FailNow()
- }
-
- // Verify the second Membership Report is sent after a random interval up to
- // the maximum unsolicited report interval, and is sent to the multicast
- // address being joined.
- p, ok = e.Read()
- if ok {
- t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt)
- }
- clock.Advance(ipv4.UnsolicitedReportIntervalMax)
- p, ok = e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got)
- }
- validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
- if t.Failed() {
- t.FailNow()
- }
-
- // Now that there are no packets queued and none scheduled to be sent, leave
- // the group.
- if err := s.LeaveGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("LeaveGroup(ipv4, nic, %s) = %s", multicastAddr, err)
- }
-
- // Observe the Leave Group Message to verify that the Leave Group message is
- // sent to the All Routers group but that the message itself has the
- // multicast address being left.
- p, ok = e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected LeaveGroup")
- }
- if got := s.Stats().IGMP.PacketsSent.LeaveGroup.Value(); got != 1 {
- t.Fatalf("got LeaveGroup messages sent = %d, want = 1", got)
- }
- validateIgmpPacket(t, p, header.IPv4AllRoutersGroup, header.IGMPLeaveGroup, 0, multicastAddr)
-}
-
-// TestIgmpJoinLeaveGroup tests that when leaving a previously joined multicast
-// group before the Unsolicited Report Interval cancels the second membership
-// report.
-func TestIgmpJoinLeaveGroup(t *testing.T) {
- _, s, clock := createStack(t, true)
-
- if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
- }
-
- // Verify that this NIC sent a Membership Report for only the group just
- // joined.
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
- }
-
- if err := s.LeaveGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("LeaveGroup(ipv4, nic, %s) = %s", multicastAddr, err)
- }
-
- // Wait for the standard IGMP Unsolicited Report Interval duration before
- // verifying that the unsolicited Membership Report was sent after leaving
- // the group.
- clock.Advance(ipv4.UnsolicitedReportIntervalMax)
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
- }
-}
-
-// TestIgmpMembershipQueryReport tests the handling of both incoming IGMP
-// Membership Queries and outgoing Membership Reports.
-func TestIgmpMembershipQueryReport(t *testing.T) {
- e, s, clock := createStack(t, true)
-
- if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
- }
-
- p, ok := e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
- }
- validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
- if t.Failed() {
- t.FailNow()
- }
-
- p, ok = e.Read()
- if ok {
- t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt)
- }
- clock.Advance(ipv4.UnsolicitedReportIntervalMax)
- p, ok = e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got)
- }
- validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
-
- // Inject a General Membership Query, which is an IGMP Membership Query with
- // a zeroed Group Address (IPv4Any) with the shortened Max Response Time.
- const maxRespTimeDS = 10
- createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, maxRespTimeDS, header.IPv4Any)
-
- p, ok = e.Read()
- if ok {
- t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt)
- }
- clock.Advance(header.DecisecondToDuration(maxRespTimeDS))
- p, ok = e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 3 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 3", got)
- }
- validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
-}
-
-// TestIgmpMultipleHosts tests the handling of IGMP Leave when we are not the
-// most recent IGMP host to join a multicast network.
-func TestIgmpMultipleHosts(t *testing.T) {
- e, s, clock := createStack(t, true)
-
- if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
- }
-
- p, ok := e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
- }
- validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
- if t.Failed() {
- t.FailNow()
- }
-
- // Inject another Host's Join Group message so that this host is not the
- // latest to send the report. Set Max Response Time to 0 for Membership
- // Reports.
- createAndInjectIGMPPacket(e, header.IGMPv2MembershipReport, 0, multicastAddr)
-
- if err := s.LeaveGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("LeaveGroup(ipv4, nic, %s) = %s", multicastAddr, err)
- }
-
- // Wait to be sure that no Leave Group messages were sent up to the max
- // unsolicited report interval since it was not the last host to join this
- // group.
- clock.Advance(ipv4.UnsolicitedReportIntervalMax)
- if got := s.Stats().IGMP.PacketsSent.LeaveGroup.Value(); got != 0 {
- t.Fatalf("got LeaveGroup messages sent = %d, want = 0", got)
- }
-}
-
// TestIgmpV1Present tests the handling of the case where an IGMPv1 router is
// present on the network. The IGMP stack will then send IGMPv1 Membership
// reports for backwards compatibility.
func TestIgmpV1Present(t *testing.T) {
e, s, clock := createStack(t, true)
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ }
if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
@@ -498,3 +159,57 @@ func TestIgmpV1Present(t *testing.T) {
}
validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr)
}
+
+func TestSendQueuedIGMPReports(t *testing.T) {
+ e, s, clock := createStack(t, true)
+
+ // Joining a group without an assigned address should queue IGMP packets; none
+ // should be sent without an assigned address.
+ if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err)
+ }
+ reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport
+ if got := reportStat.Value(); got != 0 {
+ t.Errorf("got reportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("got unexpected packet = %#v", p)
+ }
+
+ // The initial set of IGMP reports that were queued should be sent once an
+ // address is assigned.
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ }
+ if got := reportStat.Value(); got != 1 {
+ t.Errorf("got reportStat.Value() = %d, want = 1", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Error("expected to send an IGMP membership report")
+ } else {
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ clock.Advance(ipv4.UnsolicitedReportIntervalMax)
+ if got := reportStat.Value(); got != 2 {
+ t.Errorf("got reportStat.Value() = %d, want = 2", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Error("expected to send an IGMP membership report")
+ } else {
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Should have no more packets to send after the initial set of unsolicited
+ // reports.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("got unexpected packet = %#v", p)
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 7c759be9a..e9ff70d04 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -72,7 +72,6 @@ type endpoint struct {
nic stack.NetworkInterface
dispatcher stack.TransportDispatcher
protocol *protocol
- igmp igmpState
// enabled is set to 1 when the enpoint is enabled and 0 when it is
// disabled.
@@ -84,6 +83,7 @@ type endpoint struct {
sync.RWMutex
addressableEndpointState stack.AddressableEndpointState
+ igmp igmpState
}
}
@@ -94,8 +94,10 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa
dispatcher: dispatcher,
protocol: p,
}
+ e.mu.Lock()
e.mu.addressableEndpointState.Init(e)
- e.igmp.init(e)
+ e.mu.igmp.init(e)
+ e.mu.Unlock()
return e
}
@@ -123,11 +125,22 @@ func (e *endpoint) Enable() *tcpip.Error {
// We have no need for the address endpoint.
ep.DecRef()
+ // Groups may have been joined while the endpoint was disabled, or the
+ // endpoint may have left groups from the perspective of IGMP when the
+ // endpoint was disabled. Either way, we need to let routers know to
+ // send us multicast traffic.
+ e.mu.igmp.initializeAll()
+
// As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
// multicast group. Note, the IANA calls the all-hosts multicast group the
// all-systems multicast group.
- _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems)
- return err
+ if err := e.joinGroupLocked(header.IPv4AllSystems); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv4 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv4AllSystems, err))
+ }
+
+ return nil
}
// Enabled implements stack.NetworkEndpoint.
@@ -159,19 +172,27 @@ func (e *endpoint) Disable() {
}
func (e *endpoint) disableLocked() {
- if !e.setEnabled(false) {
+ if !e.isEnabled() {
return
}
// The endpoint may have already left the multicast group.
- if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
+ if err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err))
}
+ // Leave groups from the perspective of IGMP so that routers know that
+ // we are no longer interested in the group.
+ e.mu.igmp.softLeaveAll()
+
// The address may have already been removed.
if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err))
}
+
+ if !e.setEnabled(false) {
+ panic("should have only done work to disable the endpoint if it was enabled")
+ }
}
// DefaultTTL is the default time-to-live value for this endpoint.
@@ -200,37 +221,34 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
-func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) {
hdrLen := header.IPv4MinimumSize
- var opts header.IPv4Options
- if params.Options != nil {
- var ok bool
- if opts, ok = params.Options.(header.IPv4Options); !ok {
- panic(fmt.Sprintf("want IPv4Options, got %T", params.Options))
- }
- hdrLen += opts.SizeWithPadding()
- if hdrLen > header.IPv4MaximumHeaderSize {
- // Since we have no way to report an error we must either panic or create
- // a packet which is different to what was requested. Choose panic as this
- // would be a programming error that should be caught in testing.
- panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", params.Options.SizeWithPadding(), header.IPv4MaximumOptionsSize))
- }
+ var optLen int
+ if options != nil {
+ optLen = int(options.Length())
+ }
+ hdrLen += optLen
+ if hdrLen > header.IPv4MaximumHeaderSize {
+ // Since we have no way to report an error we must either panic or create
+ // a packet which is different to what was requested. Choose panic as this
+ // would be a programming error that should be caught in testing.
+ panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", optLen, header.IPv4MaximumOptionsSize))
}
ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen))
length := uint16(pkt.Size())
// RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic
// datagrams. Since the DF bit is never being set here, all datagrams
// are non-atomic and need an ID.
- id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
+ id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1)
ip.Encode(&header.IPv4Fields{
TotalLength: length,
ID: uint16(id),
TTL: params.TTL,
TOS: params.TOS,
Protocol: uint8(params.Protocol),
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
- Options: opts,
+ SrcAddr: srcAddr,
+ DstAddr: dstAddr,
+ Options: options,
})
ip.SetChecksum(^ip.CalculateChecksum())
pkt.NetworkProtocolNumber = ProtocolNumber
@@ -261,7 +279,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- e.addIPHeader(r, pkt, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */)
// iptables filtering. All packets that reach here are locally
// generated.
@@ -349,7 +367,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.addIPHeader(r, pkt, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */)
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
@@ -463,7 +481,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// non-atomic datagrams, so assign an ID to all such datagrams
// according to the definition given in RFC 6864 section 4.
if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 {
- ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
+ ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress, r.RemoteAddress, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
}
}
@@ -706,10 +724,9 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
return
}
if p == header.IGMPProtocolNumber {
- if e.protocol.options.IGMPEnabled {
- e.igmp.handleIGMP(pkt)
- }
- // Nothing further to do with an IGMP packet, even if IGMP is not enabled.
+ e.mu.Lock()
+ e.mu.igmp.handleIGMP(pkt)
+ e.mu.Unlock()
return
}
if opts := h.Options(); len(opts) != 0 {
@@ -767,7 +784,12 @@ func (e *endpoint) Close() {
func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
e.mu.Lock()
defer e.mu.Unlock()
- return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+ if err == nil {
+ e.mu.igmp.sendQueuedReports()
+ }
+ return ep, err
}
// RemovePermanentAddress implements stack.AddressableEndpoint.
@@ -790,34 +812,26 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo
defer e.mu.Unlock()
loopback := e.nic.IsLoopback()
- addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool {
+ return e.mu.addressableEndpointState.AcquireAssignedAddressOrMatching(localAddr, func(addressEndpoint stack.AddressEndpoint) bool {
subnet := addressEndpoint.Subnet()
// IPv4 has a notion of a subnet broadcast address and considers the
// loopback interface bound to an address's whole subnet (on linux).
return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr))
- })
- if addressEndpoint != nil {
- return addressEndpoint
- }
-
- if !allowTemp {
- return nil
- }
-
- addr := localAddr.WithPrefix()
- addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(addr, tempPEB)
- if err != nil {
- // AddAddress only returns an error if the address is already assigned,
- // but we just checked above if the address exists so we expect no error.
- panic(fmt.Sprintf("e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(%s, %d): %s", addr, tempPEB, err))
- }
- return addressEndpoint
+ }, allowTemp, tempPEB)
}
// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint.
func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
e.mu.RLock()
defer e.mu.RUnlock()
+ return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired)
+}
+
+// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress
+// but with locking requirements.
+//
+// Precondition: e.mu must be read locked.
+func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
}
@@ -836,40 +850,43 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
}
// JoinGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
- if !header.IsV4MulticastAddress(addr) {
- return false, tcpip.ErrBadAddress
- }
-
+func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
+ return e.joinGroupLocked(addr)
+}
- joinedGroup, err := e.mu.addressableEndpointState.JoinGroup(addr)
- if err == nil && joinedGroup && e.protocol.options.IGMPEnabled {
- _ = e.igmp.joinGroup(addr)
+// joinGroupLocked is like JoinGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
+ if !header.IsV4MulticastAddress(addr) {
+ return tcpip.ErrBadAddress
}
- return joinedGroup, err
+ e.mu.igmp.joinGroup(addr)
+ return nil
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
+ return e.leaveGroupLocked(addr)
+}
- leftGroup, err := e.mu.addressableEndpointState.LeaveGroup(addr)
- if err == nil && leftGroup && e.protocol.options.IGMPEnabled {
- e.igmp.leaveGroup(addr)
- }
-
- return leftGroup, err
+// leaveGroupLocked is like LeaveGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+ return e.mu.igmp.leaveGroup(addr)
}
// IsInGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.mu.addressableEndpointState.IsInGroup(addr)
+ return e.mu.igmp.isInGroup(addr)
}
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
@@ -1021,20 +1038,19 @@ func addressToUint32(addr tcpip.Address) uint32 {
return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24
}
-// hashRoute calculates a hash value for the given route. It uses the source &
-// destination address, the transport protocol number and a 32-bit number to
-// generate the hash.
-func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
- a := addressToUint32(r.LocalAddress)
- b := addressToUint32(r.RemoteAddress)
+// hashRoute calculates a hash value for the given source/destination pair using
+// the addresses, transport protocol number and a 32-bit number to generate the
+// hash.
+func hashRoute(srcAddr, dstAddr tcpip.Address, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
+ a := addressToUint32(srcAddr)
+ b := addressToUint32(dstAddr)
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
// Options holds options to configure a new protocol.
type Options struct {
- // IGMPEnabled indicates whether incoming IGMP packets will be handled and if
- // this endpoint will transmit IGMP packets on IGMP related events.
- IGMPEnabled bool
+ // IGMP holds options for IGMP.
+ IGMP IGMPOptions
}
// NewProtocolWithOptions returns an IPv4 network protocol.
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 0acb7d5d1..1c4919b1e 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -15,9 +15,11 @@
package ipv4_test
import (
+ "bytes"
"context"
"encoding/hex"
"fmt"
+ "io/ioutil"
"math"
"net"
"testing"
@@ -103,105 +105,6 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
-// TestIPv4Encode checks that ipv4.Encode correctly fills out the requested
-// fields when options are supplied.
-func TestIPv4EncodeOptions(t *testing.T) {
- tests := []struct {
- name string
- options header.IPv4Options
- encodedOptions header.IPv4Options // reply should look like this
- wantIHL int
- }{
- {
- name: "valid no options",
- wantIHL: header.IPv4MinimumSize,
- },
- {
- name: "one byte options",
- options: header.IPv4Options{1},
- encodedOptions: header.IPv4Options{1, 0, 0, 0},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "two byte options",
- options: header.IPv4Options{1, 1},
- encodedOptions: header.IPv4Options{1, 1, 0, 0},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "three byte options",
- options: header.IPv4Options{1, 1, 1},
- encodedOptions: header.IPv4Options{1, 1, 1, 0},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "four byte options",
- options: header.IPv4Options{1, 1, 1, 1},
- encodedOptions: header.IPv4Options{1, 1, 1, 1},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "five byte options",
- options: header.IPv4Options{1, 1, 1, 1, 1},
- encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0},
- wantIHL: header.IPv4MinimumSize + 8,
- },
- {
- name: "thirty nine byte options",
- options: header.IPv4Options{
- 1, 2, 3, 4, 5, 6, 7, 8,
- 9, 10, 11, 12, 13, 14, 15, 16,
- 17, 18, 19, 20, 21, 22, 23, 24,
- 25, 26, 27, 28, 29, 30, 31, 32,
- 33, 34, 35, 36, 37, 38, 39,
- },
- encodedOptions: header.IPv4Options{
- 1, 2, 3, 4, 5, 6, 7, 8,
- 9, 10, 11, 12, 13, 14, 15, 16,
- 17, 18, 19, 20, 21, 22, 23, 24,
- 25, 26, 27, 28, 29, 30, 31, 32,
- 33, 34, 35, 36, 37, 38, 39, 0,
- },
- wantIHL: header.IPv4MinimumSize + 40,
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- paddedOptionLength := test.options.SizeWithPadding()
- ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength
- if ipHeaderLength > header.IPv4MaximumHeaderSize {
- t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
- }
- totalLen := uint16(ipHeaderLength)
- hdr := buffer.NewPrependable(int(totalLen))
- ip := header.IPv4(hdr.Prepend(ipHeaderLength))
- // To check the padding works, poison the last byte of the options space.
- if paddedOptionLength != len(test.options) {
- ip.SetHeaderLength(uint8(ipHeaderLength))
- ip.Options()[paddedOptionLength-1] = 0xff
- ip.SetHeaderLength(0)
- }
- ip.Encode(&header.IPv4Fields{
- Options: test.options,
- })
- options := ip.Options()
- wantOptions := test.encodedOptions
- if got, want := int(ip.HeaderLength()), test.wantIHL; got != want {
- t.Errorf("got IHL of %d, want %d", got, want)
- }
-
- // cmp.Diff does not consider nil slices equal to empty slices, but we do.
- if len(wantOptions) == 0 && len(options) == 0 {
- return
- }
-
- if diff := cmp.Diff(wantOptions, options); diff != "" {
- t.Errorf("options mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
func TestForwarding(t *testing.T) {
const (
nicID1 = 1
@@ -453,14 +356,6 @@ func TestIPv4Sanity(t *testing.T) {
replyOptions: header.IPv4Options{1, 1, 0, 0},
},
{
- name: "Check option padding",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{1, 1, 1},
- replyOptions: header.IPv4Options{1, 1, 1, 0},
- },
- {
name: "bad header length",
headerLength: header.IPv4MinimumSize - 1,
maxTotalLength: ipv4.MaxTotalSize,
@@ -583,7 +478,7 @@ func TestIPv4Sanity(t *testing.T) {
68, 7, 5, 0,
// ^ ^ Linux points here which is wrong.
// | Not a multiple of 4
- 1, 2, 3,
+ 1, 2, 3, 0,
},
shouldFail: true,
expectErrorICMP: true,
@@ -967,8 +862,10 @@ func TestIPv4Sanity(t *testing.T) {
},
})
- paddedOptionLength := test.options.SizeWithPadding()
- ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength
+ if len(test.options)%4 != 0 {
+ t.Fatalf("options must be aligned to 32 bits, invalid test options: %x (len=%d)", test.options, len(test.options))
+ }
+ ipHeaderLength := header.IPv4MinimumSize + len(test.options)
if ipHeaderLength > header.IPv4MaximumHeaderSize {
t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
}
@@ -987,11 +884,6 @@ func TestIPv4Sanity(t *testing.T) {
if test.maxTotalLength < totalLen {
totalLen = test.maxTotalLength
}
- // To check the padding works, poison the options space.
- if paddedOptionLength != len(test.options) {
- ip.SetHeaderLength(uint8(ipHeaderLength))
- ip.Options()[paddedOptionLength-1] = 0x01
- }
ip.Encode(&header.IPv4Fields{
TotalLength: totalLen,
@@ -999,10 +891,19 @@ func TestIPv4Sanity(t *testing.T) {
TTL: test.TTL,
SrcAddr: remoteIPv4Addr,
DstAddr: ipv4Addr.Address,
- Options: test.options,
})
if test.headerLength != 0 {
ip.SetHeaderLength(test.headerLength)
+ } else {
+ // Set the calculated header length, since we may manually add options.
+ ip.SetHeaderLength(uint8(ipHeaderLength))
+ }
+ if len(test.options) != 0 {
+ // Copy options manually. We do not use Encode for options so we can
+ // verify malformed options with handcrafted payloads.
+ if want, got := copy(ip.Options(), test.options), len(test.options); want != got {
+ t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want)
+ }
}
ip.SetChecksum(0)
ipHeaderChecksum := ip.CalculateChecksum()
@@ -1107,7 +1008,7 @@ func TestIPv4Sanity(t *testing.T) {
}
// If the IP options change size then the packet will change size, so
// some IP header fields will need to be adjusted for the checks.
- sizeChange := len(test.replyOptions) - paddedOptionLength
+ sizeChange := len(test.replyOptions) - len(test.options)
checker.IPv4(t, replyIPHeader,
checker.IPv4HeaderLength(ipHeaderLength+sizeChange),
@@ -2424,6 +2325,28 @@ func TestReceiveFragments(t *testing.T) {
},
expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
},
+ {
+ name: "Two fragments with MF flag reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload4Addr1ToAddr2[:65512],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 65512,
+ payload: ipv4Payload4Addr1ToAddr2[65512:],
+ },
+ },
+ expectedPayloads: nil,
+ },
}
for _, test := range tests {
@@ -2487,18 +2410,26 @@ func TestReceiveFragments(t *testing.T) {
t.Errorf("got UDP Rx Packets = %d, want = %d", got, want)
}
+ const rcvSize = 65536 // Account for reassembled packets.
for i, expectedPayload := range test.expectedPayloads {
- gotPayload, _, err := ep.Read(nil)
+ var buf bytes.Buffer
+ result, err := ep.Read(&buf, rcvSize, tcpip.ReadOptions{})
if err != nil {
- t.Fatalf("(i=%d) Read(nil): %s", i, err)
+ t.Fatalf("(i=%d) Read: %s", i, err)
+ }
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: len(expectedPayload),
+ Total: len(expectedPayload),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("(i=%d) ep.Read: unexpected result (-want +got):\n%s", i, diff)
}
- if diff := cmp.Diff(buffer.View(expectedPayload), gotPayload); diff != "" {
+ if diff := cmp.Diff(expectedPayload, buf.Bytes()); diff != "" {
t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff)
}
}
- if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ if res, err := ep.Read(ioutil.Discard, rcvSize, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
+ t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
}
})
}
@@ -2617,7 +2548,7 @@ func TestWriteStats(t *testing.T) {
test.setup(t, rt.Stack())
- nWritten, _ := writer.writePackets(&rt, pkts)
+ nWritten, _ := writer.writePackets(rt, pkts)
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
@@ -2634,7 +2565,7 @@ func TestWriteStats(t *testing.T) {
}
}
-func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
})
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index 0ac24a6fb..afa45aefe 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -8,6 +8,7 @@ go_library(
"dhcpv6configurationfromndpra_string.go",
"icmp.go",
"ipv6.go",
+ "mld.go",
"ndp.go",
],
visibility = ["//visibility:public"],
@@ -19,6 +20,7 @@ go_library(
"//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
+ "//pkg/tcpip/network/ip",
"//pkg/tcpip/stack",
],
)
@@ -49,3 +51,19 @@ go_test(
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
+
+go_test(
+ name = "ipv6_x_test",
+ size = "small",
+ srcs = ["mld_test.go"],
+ deps = [
+ ":ipv6",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 386d98a29..6ee162713 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -126,8 +126,8 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) {
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
stats := e.protocol.stack.Stats().ICMP
- sent := stats.V6PacketsSent
- received := stats.V6PacketsReceived
+ sent := stats.V6.PacketsSent
+ received := stats.V6.PacketsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
// full explanation.
@@ -163,7 +163,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
}
// TODO(b/112892170): Meaningfully handle all ICMP types.
- switch h.Type() {
+ switch icmpType := h.Type(); icmpType {
case header.ICMPv6PacketTooBig:
received.PacketTooBig.Increment()
hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize)
@@ -358,7 +358,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize))
packet.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(packet.NDPPayload())
+ na := header.NDPNeighborAdvert(packet.MessageBody())
// As per RFC 4861 section 7.2.4:
//
@@ -644,8 +644,39 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
return
}
+ case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone:
+ switch icmpType {
+ case header.ICMPv6MulticastListenerQuery:
+ received.MulticastListenerQuery.Increment()
+ case header.ICMPv6MulticastListenerReport:
+ received.MulticastListenerReport.Increment()
+ case header.ICMPv6MulticastListenerDone:
+ received.MulticastListenerDone.Increment()
+ default:
+ panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType))
+ }
+
+ if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+
+ switch icmpType {
+ case header.ICMPv6MulticastListenerQuery:
+ e.mu.Lock()
+ e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.ToView()))
+ e.mu.Unlock()
+ case header.ICMPv6MulticastListenerReport:
+ e.mu.Lock()
+ e.mu.mld.handleMulticastListenerReport(header.MLD(payload.ToView()))
+ e.mu.Unlock()
+ case header.ICMPv6MulticastListenerDone:
+ default:
+ panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType))
+ }
+
default:
- received.Invalid.Increment()
+ received.Unrecognized.Increment()
}
}
@@ -681,12 +712,12 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize))
packet.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(packet.NDPPayload())
+ ns := header.NDPNeighborSolicit(packet.MessageBody())
ns.SetTargetAddress(targetAddr)
ns.Options().Serialize(optsSerializer)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- stat := p.stack.Stats().ICMP.V6PacketsSent
+ stat := p.stack.Stats().ICMP.V6.PacketsSent
if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
@@ -833,7 +864,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
defer route.Release()
stats := p.stack.Stats().ICMP
- sent := stats.V6PacketsSent
+ sent := stats.V6.PacketsSent
if !p.stack.AllowICMPMessage() {
sent.RateLimited.Increment()
return nil
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 9bc02d851..34a6a8446 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -149,10 +149,9 @@ func (*testInterface) Promiscuous() bool {
}
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- r := stack.Route{
- NetProto: protocol,
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.NetProto = protocol
+ r.ResolveWith(remoteLinkAddr)
return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
}
@@ -271,6 +270,22 @@ func TestICMPCounts(t *testing.T) {
typ: header.ICMPv6RedirectMsg,
size: header.ICMPv6MinimumSize,
},
+ {
+ typ: header.ICMPv6MulticastListenerQuery,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerReport,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerDone,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: 255, /* Unrecognized */
+ size: 50,
+ },
}
handleIPv6Payload := func(icmp header.ICMPv6) {
@@ -280,11 +295,11 @@ func TestICMPCounts(t *testing.T) {
})
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
ep.HandlePacket(pkt)
}
@@ -301,7 +316,7 @@ func TestICMPCounts(t *testing.T) {
// Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
- icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
+ icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived
visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
if got, want := s.Value(), uint64(1); got != want {
t.Errorf("got %s = %d, want = %d", name, got, want)
@@ -413,6 +428,22 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
typ: header.ICMPv6RedirectMsg,
size: header.ICMPv6MinimumSize,
},
+ {
+ typ: header.ICMPv6MulticastListenerQuery,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerReport,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerDone,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: 255, /* Unrecognized */
+ size: 50,
+ },
}
handleIPv6Payload := func(icmp header.ICMPv6) {
@@ -422,11 +453,11 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
})
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
ep.HandlePacket(pkt)
}
@@ -443,7 +474,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
// Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
- icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
+ icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived
visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
if got, want := s.Value(), uint64(1); got != want {
t.Errorf("got %s = %d, want = %d", name, got, want)
@@ -568,7 +599,7 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.
return
}
- if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress {
+ if len(args.remoteLinkAddr) != 0 && pi.Route.RemoteLinkAddress != args.remoteLinkAddr {
t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr)
}
@@ -821,11 +852,11 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
}
ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
@@ -833,7 +864,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
e.InjectInbound(ProtocolNumber, pkt)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
routerOnly := stats.RouterOnlyPacketsDroppedByHost
typStat := typ.statCounter(stats)
@@ -898,11 +929,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
errorICMPBody := func(view buffer.View) {
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
- PayloadLength: simpleBodySize,
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: lladdr0,
- DstAddr: lladdr1,
+ PayloadLength: simpleBodySize,
+ TransportProtocol: 10,
+ HopLimit: 20,
+ SrcAddr: lladdr0,
+ DstAddr: lladdr1,
})
simpleBody(view[header.IPv6MinimumSize:])
}
@@ -1016,11 +1047,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(icmpSize),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(icmpSize),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1028,7 +1059,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
e.InjectInbound(ProtocolNumber, pkt)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
typStat := typ.statCounter(stats)
@@ -1076,11 +1107,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
errorICMPBody := func(view buffer.View) {
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
- PayloadLength: simpleBodySize,
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: lladdr0,
- DstAddr: lladdr1,
+ PayloadLength: simpleBodySize,
+ TransportProtocol: 10,
+ HopLimit: 20,
+ SrcAddr: lladdr0,
+ DstAddr: lladdr1,
})
simpleBody(view[header.IPv6MinimumSize:])
}
@@ -1195,11 +1226,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(size + payloadSize),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(size + payloadSize),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
@@ -1207,7 +1238,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
e.InjectInbound(ProtocolNumber, pkt)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
typStat := typ.statCounter(stats)
@@ -1413,11 +1444,11 @@ func TestPacketQueing(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: DefaultTTL,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1455,11 +1486,11 @@ func TestPacketQueing(t *testing.T) {
pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: DefaultTTL,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1543,7 +1574,7 @@ func TestPacketQueing(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
pkt := header.ICMPv6(hdr.Prepend(naSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address)
@@ -1554,11 +1585,11 @@ func TestPacketQueing(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: header.NDPHopLimit,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1592,7 +1623,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
return icmp
},
@@ -1612,7 +1643,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
ns.Options().Serialize(header.NDPOptionsSerializer{
header.NDPSourceLinkLayerAddressOption(linkAddr1),
@@ -1629,7 +1660,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
return icmp
},
@@ -1645,7 +1676,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
ns.Options().Serialize(header.NDPOptionsSerializer{
header.NDPSourceLinkLayerAddressOption(linkAddr1),
@@ -1662,7 +1693,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1683,7 +1714,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1702,7 +1733,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(false)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1722,7 +1753,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(false)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1796,11 +1827,11 @@ func TestCallsToNeighborCache(t *testing.T) {
})
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: test.source,
- DstAddr: test.destination,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: test.source,
+ DstAddr: test.destination,
})
ep.HandlePacket(pkt)
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 181c50cc7..f2018d073 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -19,6 +19,7 @@ import (
"encoding/binary"
"fmt"
"hash/fnv"
+ "math"
"sort"
"sync/atomic"
"time"
@@ -60,6 +61,108 @@ const (
buckets = 2048
)
+// policyTable is the default policy table defined in RFC 6724 section 2.1.
+//
+// A more human-readable version:
+//
+// Prefix Precedence Label
+// ::1/128 50 0
+// ::/0 40 1
+// ::ffff:0:0/96 35 4
+// 2002::/16 30 2
+// 2001::/32 5 5
+// fc00::/7 3 13
+// ::/96 1 3
+// fec0::/10 1 11
+// 3ffe::/16 1 12
+//
+// The table is sorted by descending prefix length so that the first entry
+// whose subnet contains an address is the longest-prefix match.
+//
+// We intentionally left out ::/96, fec0::/10 and 3ffe::/16 since those prefix
+// assignments are deprecated.
+//
+// As per RFC 4291 section 2.5.5.1 (for ::/96),
+//
+// The "IPv4-Compatible IPv6 address" is now deprecated because the
+// current IPv6 transition mechanisms no longer use these addresses.
+// New or updated implementations are not required to support this
+// address type.
+//
+// As per RFC 3879 section 4 (for fec0::/10),
+//
+// This document formally deprecates the IPv6 site-local unicast prefix
+// defined in [RFC3513], i.e., 1111111011 binary or FEC0::/10.
+//
+// As per RFC 3701 section 1 (for 3ffe::/16),
+//
+// As clearly stated in [TEST-NEW], the addresses for the 6bone are
+// temporary and will be reclaimed in the future. It further states
+// that all users of these addresses (within the 3FFE::/16 prefix) will
+// be required to renumber at some time in the future.
+//
+// and section 2,
+//
+// Thus after the pTLA allocation cutoff date January 1, 2004, it is
+// REQUIRED that no new 6bone 3FFE pTLAs be allocated.
+//
+// MUST NOT BE MODIFIED.
+var policyTable = [...]struct {
+ subnet tcpip.Subnet // prefix that an address is matched against
+
+ label uint8 // label reported by getLabel for addresses within subnet
+}{
+ // ::1/128
+ {
+ subnet: header.IPv6Loopback.WithPrefix().Subnet(),
+ label: 0,
+ },
+ // ::ffff:0:0/96
+ {
+ subnet: header.IPv4MappedIPv6Subnet,
+ label: 4,
+ },
+ // 2001::/32 (Teredo prefix as per RFC 4380 section 2.6).
+ {
+ subnet: tcpip.AddressWithPrefix{
+ Address: "\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: 32,
+ }.Subnet(),
+ label: 5,
+ },
+ // 2002::/16 (6to4 prefix as per RFC 3056 section 2).
+ {
+ subnet: tcpip.AddressWithPrefix{
+ Address: "\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: 16,
+ }.Subnet(),
+ label: 2,
+ },
+ // fc00::/7 (Unique local addresses as per RFC 4193 section 3.1).
+ {
+ subnet: tcpip.AddressWithPrefix{
+ Address: "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: 7,
+ }.Subnet(),
+ label: 13,
+ },
+ // ::/0
+ {
+ subnet: header.IPv6EmptySubnet,
+ label: 1,
+ },
+}
+
+// getLabel returns the label of the first policyTable entry whose subnet
+// contains addr.
+//
+// It panics if no entry in policyTable contains addr.
+func getLabel(addr tcpip.Address) uint8 {
+	for i := range policyTable {
+		if entry := &policyTable[i]; entry.subnet.Contains(addr) {
+			return entry.label
+		}
+	}
+
+	panic(fmt.Sprintf("should have a label for address = %s", addr))
+}
+
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
var _ stack.NetworkEndpoint = (*endpoint)(nil)
@@ -85,6 +188,7 @@ type endpoint struct {
addressableEndpointState stack.AddressableEndpointState
ndp ndpState
+ mld mldState
}
}
@@ -120,6 +224,45 @@ type OpaqueInterfaceIdentifierOptions struct {
SecretKey []byte
}
+// onAddressAssignedLocked handles addr being assigned to the endpoint. When
+// addr is a link-local address, any queued MLD reports are sent.
+//
+// Precondition: e.mu must be exclusively locked.
+func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) {
+	// Nothing to do unless the newly assigned address is link-local.
+	if !header.IsV6LinkLocalAddress(addr) {
+		return
+	}
+
+	// As per RFC 2710 section 3,
+	//
+	//   All MLD messages described in this document are sent with a link-local
+	//   IPv6 Source Address, ...
+	//
+	// We just completed DAD for a link-local address, so attempt to send any
+	// queued MLD reports. Note, we may have sent reports already for some of the
+	// groups before we had a valid link-local address to use as the source for
+	// the MLD messages, but that was only so that MLD snooping switches are aware
+	// of our membership to groups - routers would not have handled those reports.
+	//
+	// As per RFC 3590 section 4,
+	//
+	//   MLD Report and Done messages are sent with a link-local address as
+	//   the IPv6 source address, if a valid address is available on the
+	//   interface. If a valid link-local address is not available (e.g., one
+	//   has not been configured), the message is sent with the unspecified
+	//   address (::) as the IPv6 source address.
+	//
+	//   Once a valid link-local address is available, a node SHOULD generate
+	//   new MLD Report messages for all multicast addresses joined on the
+	//   interface.
+	//
+	//   Routers receiving an MLD Report or Done message with the unspecified
+	//   address as the IPv6 source address MUST silently discard the packet
+	//   without taking any action on the packets contents.
+	//
+	//   Snooping switches MUST manage multicast forwarding state based on MLD
+	//   Report and Done messages sent with the unspecified address as the
+	//   IPv6 source address.
+	e.mu.mld.sendQueuedReports()
+}
+
// InvalidateDefaultRouter implements stack.NDPEndpoint.
func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
e.mu.Lock()
@@ -226,6 +369,12 @@ func (e *endpoint) Enable() *tcpip.Error {
return nil
}
+ // Groups may have been joined when the endpoint was disabled, or the
+ // endpoint may have left groups from the perspective of MLD when the
+ // endpoint was disabled. Either way, we need to let routers know to
+ // send us multicast traffic.
+ e.mu.mld.initializeAll()
+
// Join the IPv6 All-Nodes Multicast group if the stack is configured to
// use IPv6. This is required to ensure that this node properly receives
// and responds to the various NDP messages that are destined to the
@@ -243,8 +392,10 @@ func (e *endpoint) Enable() *tcpip.Error {
// (NDP NS) messages may be sent to the All-Nodes multicast group if the
// source address of the NDP NS is the unspecified address, as per RFC 4861
// section 7.2.4.
- if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil {
- return err
+ if err := e.joinGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv6 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv6AllNodesMulticastAddress, err))
}
// Perform DAD on all the unicast IPv6 endpoints that are in the permanent
@@ -253,7 +404,7 @@ func (e *endpoint) Enable() *tcpip.Error {
// Addresses may have already completed DAD but in the time since the endpoint
// was last enabled, other devices may have acquired the same addresses.
var err *tcpip.Error
- e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
addr := addressEndpoint.AddressWithPrefix().Address
if !header.IsV6UnicastAddress(addr) {
return true
@@ -324,7 +475,7 @@ func (e *endpoint) Disable() {
}
func (e *endpoint) disableLocked() {
- if !e.setEnabled(false) {
+ if !e.Enabled() {
return
}
@@ -333,9 +484,17 @@ func (e *endpoint) disableLocked() {
e.stopDADForPermanentAddressesLocked()
// The endpoint may have already left the multicast group.
- if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ if err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err))
}
+
+ // Leave groups from the perspective of MLD so that routers know that
+ // we are no longer interested in the group.
+ e.mu.mld.softLeaveAll()
+
+ if !e.setEnabled(false) {
+ panic("should have only done work to disable the endpoint if it was enabled")
+ }
}
// stopDADForPermanentAddressesLocked stops DAD for all permanent addresses.
@@ -343,7 +502,7 @@ func (e *endpoint) disableLocked() {
// Precondition: e.mu must be write locked.
func (e *endpoint) stopDADForPermanentAddressesLocked() {
// Stop DAD for all the tentative unicast addresses.
- e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
if addressEndpoint.GetKind() != stack.PermanentTentative {
return true
}
@@ -375,19 +534,27 @@ func (e *endpoint) MTU() uint32 {
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
+ // TODO(gvisor.dev/issues/5035): The maximum header length returned here does
+ // not open the possibility for the caller to know about size required for
+ // extension headers.
return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
-func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
- length := uint16(pkt.Size())
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) {
+ extHdrsLen := extensionHeaders.Length()
+ length := pkt.Size() + extensionHeaders.Length()
+ if length > math.MaxUint16 {
+ panic(fmt.Sprintf("IPv6 payload too large: %d, must be <= %d", length, math.MaxUint16))
+ }
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen))
ip.Encode(&header.IPv6Fields{
- PayloadLength: length,
- NextHeader: uint8(params.Protocol),
- HopLimit: params.TTL,
- TrafficClass: params.TOS,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ PayloadLength: uint16(length),
+ TransportProtocol: params.Protocol,
+ HopLimit: params.TTL,
+ TrafficClass: params.TOS,
+ SrcAddr: srcAddr,
+ DstAddr: dstAddr,
+ ExtensionHeaders: extensionHeaders,
})
pkt.NetworkProtocolNumber = ProtocolNumber
}
@@ -442,7 +609,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- e.addIPHeader(r, pkt, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */)
// iptables filtering. All packets that reach here are locally
// generated.
@@ -531,7 +698,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
linkMTU := e.nic.MTU()
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
- e.addIPHeader(r, pb, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */)
networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size()))
if err != nil {
@@ -1163,11 +1330,6 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
return addressEndpoint, nil
}
- snmc := header.SolicitedNodeAddr(addr.Address)
- if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil {
- return nil, err
- }
-
addressEndpoint.SetKind(stack.PermanentTentative)
if e.Enabled() {
@@ -1176,6 +1338,13 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
}
}
+ snmc := header.SolicitedNodeAddr(addr.Address)
+ if err := e.joinGroupLocked(snmc); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv6 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err))
+ }
+
return addressEndpoint, nil
}
@@ -1221,7 +1390,8 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn
}
snmc := header.SolicitedNodeAddr(addr.Address)
- if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
+ // The endpoint may have already left the multicast group.
+ if err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
return err
}
@@ -1244,7 +1414,7 @@ func (e *endpoint) hasPermanentAddressRLocked(addr tcpip.Address) bool {
//
// Precondition: e.mu must be read or write locked.
func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpoint {
- return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr)
+ return e.mu.addressableEndpointState.GetAddress(localAddr)
}
// MainAddress implements stack.AddressableEndpoint.
@@ -1276,6 +1446,26 @@ func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allow
return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired)
}
+// getLinkLocalAddressRLocked walks the primary list of addresses and returns
+// the first assigned (non-expired) link-local address it finds. The zero
+// tcpip.Address is returned when there is no such address.
+//
+// See stack.PrimaryEndpointBehavior for more details about the primary list.
+//
+// Precondition: e.mu must be read locked.
+func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address {
+	var found tcpip.Address
+	e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(ep stack.AddressEndpoint) bool {
+		// Skip endpoints that are not currently assigned; keep iterating.
+		if !ep.IsAssigned(false /* allowExpired */) {
+			return true
+		}
+
+		candidate := ep.AddressWithPrefix().Address
+		if !header.IsV6LinkLocalAddress(candidate) {
+			return true
+		}
+
+		// Found a link-local address; stop iterating.
+		found = candidate
+		return false
+	})
+	return found
+}
+
// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress
// but with locking requirements.
//
@@ -1285,7 +1475,11 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
// RFC 6724 section 5.
type addrCandidate struct {
addressEndpoint stack.AddressEndpoint
+ addr tcpip.Address
scope header.IPv6AddressScope
+
+ label uint8
+ matchingPrefix uint8
}
if len(remoteAddr) == 0 {
@@ -1295,10 +1489,10 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
// Create a candidate set of available addresses we can potentially use as a
// source address.
var cs []addrCandidate
- e.mu.addressableEndpointState.ReadOnly().ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) {
+ e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
// If addressEndpoint is not assigned, it is not a valid candidate for
// outgoing connections.
if !addressEndpoint.IsAssigned(allowExpired) {
- return
+ return true
}
addr := addressEndpoint.AddressWithPrefix().Address
@@ -1312,8 +1506,13 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
cs = append(cs, addrCandidate{
addressEndpoint: addressEndpoint,
+ addr: addr,
scope: scope,
+ label: getLabel(addr),
+ matchingPrefix: remoteAddr.MatchingPrefix(addr),
})
+
+ return true
})
remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
@@ -1322,18 +1521,20 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err))
}
+ remoteLabel := getLabel(remoteAddr)
+
// Sort the addresses as per RFC 6724 section 5 rules 1-3 and 6-8.
//
- // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5.
+ // TODO(b/146021396): Implement rules 4, 5 of RFC 6724 section 5.
sort.Slice(cs, func(i, j int) bool {
sa := cs[i]
sb := cs[j]
// Prefer same address as per RFC 6724 section 5 rule 1.
- if sa.addressEndpoint.AddressWithPrefix().Address == remoteAddr {
+ if sa.addr == remoteAddr {
return true
}
- if sb.addressEndpoint.AddressWithPrefix().Address == remoteAddr {
+ if sb.addr == remoteAddr {
return false
}
@@ -1350,11 +1551,29 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
return sbDep
}
+ // Prefer matching label as per RFC 6724 section 5 rule 6.
+ if sa, sb := sa.label == remoteLabel, sb.label == remoteLabel; sa != sb {
+ if sa {
+ return true
+ }
+ if sb {
+ return false
+ }
+ }
+
// Prefer temporary addresses as per RFC 6724 section 5 rule 7.
if saTemp, sbTemp := sa.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp, sb.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp; saTemp != sbTemp {
return saTemp
}
+ // Use longest matching prefix as per RFC 6724 section 5 rule 8.
+ if sa.matchingPrefix > sb.matchingPrefix {
+ return true
+ }
+ if sb.matchingPrefix > sa.matchingPrefix {
+ return false
+ }
+
// sa and sb are equal, return the endpoint that is closest to the front of
// the primary endpoint list.
return i < j
@@ -1386,28 +1605,43 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
}
// JoinGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.joinGroupLocked(addr)
+}
+
+// joinGroupLocked is like JoinGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
if !header.IsV6MulticastAddress(addr) {
- return false, tcpip.ErrBadAddress
+ return tcpip.ErrBadAddress
}
- e.mu.Lock()
- defer e.mu.Unlock()
- return e.mu.addressableEndpointState.JoinGroup(addr)
+ e.mu.mld.joinGroup(addr)
+ return nil
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- return e.mu.addressableEndpointState.LeaveGroup(addr)
+ return e.leaveGroupLocked(addr)
+}
+
+// leaveGroupLocked is like LeaveGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+ return e.mu.mld.leaveGroup(addr)
}
// IsInGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.mu.addressableEndpointState.IsInGroup(addr)
+ return e.mu.mld.isInGroup(addr)
}
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
@@ -1472,16 +1706,11 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
dispatcher: dispatcher,
protocol: p,
}
+ e.mu.Lock()
e.mu.addressableEndpointState.Init(e)
- e.mu.ndp = ndpState{
- ep: e,
- configs: p.options.NDPConfigs,
- dad: make(map[tcpip.Address]dadState),
- defaultRouters: make(map[tcpip.Address]defaultRouterState),
- onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
- slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
- }
- e.mu.ndp.initializeTempAddrState()
+ e.mu.ndp.init(e)
+ e.mu.mld.init(e)
+ e.mu.Unlock()
p.mu.Lock()
defer p.mu.Unlock()
@@ -1638,6 +1867,9 @@ type Options struct {
// seed that is too small would reduce randomness and increase predictability,
// defeating the purpose of temporary SLAAC addresses.
TempIIDSeed []byte
+
+ // MLD holds options for MLD.
+ MLD MLDOptions
}
// NewProtocolWithOptions returns an IPv6 network protocol.
@@ -1699,24 +1931,25 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders hea
fragPkt.NetworkProtocolNumber = ProtocolNumber
originalIPHeadersLength := len(originalIPHeaders)
- fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize
+
+ s := header.IPv6ExtHdrSerializer{&header.IPv6SerializableFragmentExtHdr{
+ FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit),
+ M: more,
+ Identification: id,
+ }}
+
+ fragmentIPHeadersLength := originalIPHeadersLength + s.Length()
fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength))
- fragPkt.NetworkProtocolNumber = ProtocolNumber
// Copy the IPv6 header and any extension headers already populated.
if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength {
panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength))
}
- fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader)
- fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize))
- fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:])
- fragmentHeader.Encode(&header.IPv6FragmentFields{
- M: more,
- FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit),
- Identification: id,
- NextHeader: uint8(transportProto),
- })
+ nextHeader, _ := s.Serialize(transportProto, fragmentIPHeaders[originalIPHeadersLength:])
+
+ fragmentIPHeaders.SetNextHeader(nextHeader)
+ fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize))
return fragPkt, more
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 781bf7900..360025b20 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -15,8 +15,10 @@
package ipv6
import (
+ "bytes"
"encoding/hex"
"fmt"
+ "io/ioutil"
"math"
"net"
"testing"
@@ -69,18 +71,18 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: src,
- DstAddr: dst,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
}))
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
if got := stats.NeighborAdvert.Value(); got != want {
t.Fatalf("got NeighborAdvert = %d, want = %d", got, want)
@@ -127,11 +129,11 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 255,
- SrcAddr: src,
- DstAddr: dst,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -844,13 +846,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
},
}
+ const mtu = header.IPv6MinimumMTU
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- e := channel.New(1, header.IPv6MinimumMTU, linkAddr1)
+ e := channel.New(1, mtu, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -915,10 +918,12 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
- NextHeader: ipv6NextHdr,
- HopLimit: 255,
- SrcAddr: addr1,
- DstAddr: dstAddr,
+ // We're lying about transport protocol here to be able to generate
+ // raw extension headers from the test definitions.
+ TransportProtocol: tcpip.TransportProtocolNumber(ipv6NextHdr),
+ HopLimit: 255,
+ SrcAddr: addr1,
+ DstAddr: dstAddr,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -977,17 +982,24 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
if got := stats.Value(); got != 1 {
t.Errorf("got UDP Rx Packets = %d, want = 1", got)
}
- gotPayload, _, err := ep.Read(nil)
+ var buf bytes.Buffer
+ result, err := ep.Read(&buf, mtu, tcpip.ReadOptions{})
if err != nil {
- t.Fatalf("Read(nil): %s", err)
+ t.Fatalf("Read: %s", err)
+ }
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: len(udpPayload),
+ Total: len(udpPayload),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("Read: unexpected result (-want +got):\n%s", diff)
}
- if diff := cmp.Diff(buffer.View(udpPayload), gotPayload); diff != "" {
+ if diff := cmp.Diff(udpPayload, buf.Bytes()); diff != "" {
t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
}
// Should not have any more UDP packets.
- if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ if res, err := ep.Read(ioutil.Discard, mtu, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
}
})
}
@@ -1007,9 +1019,10 @@ func TestReceiveIPv6Fragments(t *testing.T) {
udpPayload2Length = 128
// Used to test cases where the fragment blocks are not a multiple of
// the fragment block size of 8 (RFC 8200 section 4.5).
- udpPayload3Length = 127
- udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize
- fragmentExtHdrLen = 8
+ udpPayload3Length = 127
+ udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize
+ udpMaximumSizeMinus15 = header.UDPMaximumSize - 15
+ fragmentExtHdrLen = 8
// Note, not all routing extension headers will be 8 bytes but this test
// uses 8 byte routing extension headers for most sub tests.
routingExtHdrLen = 8
@@ -1353,14 +1366,14 @@ func TestReceiveIPv6Fragments(t *testing.T) {
dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+65520,
+ fragmentExtHdrLen+udpMaximumSizeMinus15,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload4Addr1ToAddr2[:65520],
+ ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15],
},
),
},
@@ -1369,14 +1382,17 @@ func TestReceiveIPv6Fragments(t *testing.T) {
dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-65520,
+ fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15,
[]buffer.View{
// Fragment extension header.
//
- // Fragment offset = 8190, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 255, 240, 0, 0, 0, 1}),
+ // Fragment offset = udpMaximumSizeMinus15/8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
+ udpMaximumSizeMinus15 >> 8,
+ udpMaximumSizeMinus15 & 0xff,
+ 0, 0, 0, 1}),
- ipv6Payload4Addr1ToAddr2[65520:],
+ ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:],
},
),
},
@@ -1384,6 +1400,47 @@ func TestReceiveIPv6Fragments(t *testing.T) {
expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
},
{
+ name: "Two fragments with MF flag reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+udpMaximumSizeMinus15,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = udpMaximumSizeMinus15/8, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
+ udpMaximumSizeMinus15 >> 8,
+ (udpMaximumSizeMinus15 & 0xff) + 1,
+ 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
name: "Two fragments with per-fragment routing header with zero segments left",
fragments: []fragmentData{
{
@@ -1902,10 +1959,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(f.data.Size()),
- NextHeader: f.nextHdr,
- HopLimit: 255,
- SrcAddr: f.srcAddr,
- DstAddr: f.dstAddr,
+ // We're lying about transport protocol here so that we can generate
+ // raw extension headers for the tests.
+ TransportProtocol: tcpip.TransportProtocolNumber(f.nextHdr),
+ HopLimit: 255,
+ SrcAddr: f.srcAddr,
+ DstAddr: f.dstAddr,
})
vv := hdr.View().ToVectorisedView()
@@ -1920,18 +1979,20 @@ func TestReceiveIPv6Fragments(t *testing.T) {
t.Errorf("got UDP Rx Packets = %d, want = %d", got, want)
}
+ const rcvSize = 65536 // Account for reassembled packets.
for i, p := range test.expectedPayloads {
- gotPayload, _, err := ep.Read(nil)
+ var buf bytes.Buffer
+ _, err := ep.Read(&buf, rcvSize, tcpip.ReadOptions{})
if err != nil {
- t.Fatalf("(i=%d) Read(nil): %s", i, err)
+ t.Fatalf("(i=%d) Read: %s", i, err)
}
- if diff := cmp.Diff(buffer.View(p), gotPayload); diff != "" {
+ if diff := cmp.Diff(p, buf.Bytes()); diff != "" {
t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff)
}
}
- if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ if res, err := ep.Read(ioutil.Discard, rcvSize, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
+ t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
}
})
}
@@ -1950,7 +2011,7 @@ func TestInvalidIPv6Fragments(t *testing.T) {
type fragmentData struct {
ipv6Fields header.IPv6Fields
- ipv6FragmentFields header.IPv6FragmentFields
+ ipv6FragmentFields header.IPv6SerializableFragmentExtHdr
payload []byte
}
@@ -1969,14 +2030,13 @@ func TestInvalidIPv6Fragments(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 9,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 9,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0 >> 3,
M: true,
Identification: ident,
@@ -1996,14 +2056,13 @@ func TestInvalidIPv6Fragments(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3,
M: false,
Identification: ident,
@@ -2044,10 +2103,9 @@ func TestInvalidIPv6Fragments(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
- ip.Encode(&f.ipv6Fields)
-
- fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
- fragHDR.Encode(&f.ipv6FragmentFields)
+ encodeArgs := f.ipv6Fields
+ encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields)
+ ip.Encode(&encodeArgs)
vv := hdr.View().ToVectorisedView()
vv.AppendView(f.payload)
@@ -2109,7 +2167,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
type fragmentData struct {
ipv6Fields header.IPv6Fields
- ipv6FragmentFields header.IPv6FragmentFields
+ ipv6FragmentFields header.IPv6SerializableFragmentExtHdr
payload []byte
}
@@ -2123,14 +2181,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2145,14 +2202,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2161,14 +2217,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
},
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2183,14 +2238,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 8,
M: false,
Identification: ident,
@@ -2205,14 +2259,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2221,14 +2274,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
},
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 8,
M: false,
Identification: ident,
@@ -2243,14 +2295,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 8,
M: false,
Identification: ident,
@@ -2259,14 +2310,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
},
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2305,10 +2355,11 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
- ip.Encode(&f.ipv6Fields)
+ encodeArgs := f.ipv6Fields
+ encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields)
+ ip.Encode(&encodeArgs)
fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
- fragHDR.Encode(&f.ipv6FragmentFields)
vv := hdr.View().ToVectorisedView()
vv.AppendView(f.payload)
@@ -2464,7 +2515,7 @@ func TestWriteStats(t *testing.T) {
test.setup(t, rt.Stack())
- nWritten, _ := writer.writePackets(&rt, pkts)
+ nWritten, _ := writer.writePackets(rt, pkts)
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
@@ -2481,7 +2532,7 @@ func TestWriteStats(t *testing.T) {
}
}
-func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
@@ -2949,11 +3000,11 @@ func TestForwarding(t *testing.T) {
icmp.SetChecksum(header.ICMPv6Checksum(icmp, remoteIPv6Addr1, remoteIPv6Addr2, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: test.TTL,
- SrcAddr: remoteIPv6Addr1,
- DstAddr: remoteIPv6Addr2,
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: test.TTL,
+ SrcAddr: remoteIPv6Addr1,
+ DstAddr: remoteIPv6Addr2,
})
requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
new file mode 100644
index 000000000..e8d1e7a79
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -0,0 +1,262 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // UnsolicitedReportIntervalMax is the maximum delay between sending
+ // unsolicited MLD reports.
+ //
+ // Obtained from RFC 2710 Section 7.10.
+ UnsolicitedReportIntervalMax = 10 * time.Second
+)
+
+// MLDOptions holds options for MLD.
+type MLDOptions struct {
+ // Enabled indicates whether MLD will be performed.
+ //
+ // When enabled, MLD may transmit MLD report and done messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // MLD packets.
+ //
+ // This field is ignored and is always assumed to be false for interfaces
+ // without neighbouring nodes (e.g. loopback).
+ Enabled bool
+}
+
+var _ ip.MulticastGroupProtocol = (*mldState)(nil)
+
+// mldState is the per-interface MLD state.
+//
+// mldState.init MUST be called to initialize the MLD state.
+type mldState struct {
+ // The IPv6 endpoint this mldState is for.
+ ep *endpoint
+
+ genericMulticastProtocol ip.GenericMulticastProtocolState
+}
+
+// Enabled implements ip.MulticastGroupProtocol.
+func (mld *mldState) Enabled() bool {
+ // No need to perform MLD on loopback interfaces since they don't have
+ // neighbouring nodes.
+ return mld.ep.protocol.options.MLD.Enabled && !mld.ep.nic.IsLoopback() && mld.ep.Enabled()
+}
+
+// SendReport implements ip.MulticastGroupProtocol.
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+ return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport)
+}
+
+// SendLeave implements ip.MulticastGroupProtocol.
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+ _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
+ return err
+}
+
+// init sets up an mldState struct, and is required to be called before using
+// a new mldState.
+//
+// Must only be called once for the lifetime of mld.
+func (mld *mldState) init(ep *endpoint) {
+ mld.ep = ep
+ mld.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{
+ Rand: ep.protocol.stack.Rand(),
+ Clock: ep.protocol.stack.Clock(),
+ Protocol: mld,
+ MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
+ AllNodesAddress: header.IPv6AllNodesMulticastAddress,
+ })
+}
+
+// handleMulticastListenerQuery handles a query message.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) {
+ mld.genericMulticastProtocol.HandleQueryLocked(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay())
+}
+
+// handleMulticastListenerReport handles a report message.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) {
+ mld.genericMulticastProtocol.HandleReportLocked(mldHdr.MulticastAddress())
+}
+
+// joinGroup handles joining a new group and sending and scheduling the required
+// messages.
+//
+// This function has no return value; the join is delegated to JoinGroupLocked.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) joinGroup(groupAddress tcpip.Address) {
+ mld.genericMulticastProtocol.JoinGroupLocked(groupAddress)
+}
+
+// isInGroup returns true if the specified group has been joined locally.
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool {
+ return mld.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress)
+}
+
+// leaveGroup handles removing the group from the membership map, cancels any
+// delay timers associated with that group, and sends the Done message, if
+// required. Returns tcpip.ErrBadLocalAddress if the group was not joined.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
+ // LeaveGroupLocked returns false only if the group was not joined.
+ if mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress) {
+ return nil
+ }
+
+ return tcpip.ErrBadLocalAddress
+}
+
+// softLeaveAll leaves all groups from the perspective of MLD, but remains
+// joined locally.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) softLeaveAll() {
+ mld.genericMulticastProtocol.MakeAllNonMemberLocked()
+}
+
+// initializeAll attempts to initialize the MLD state for each group that has
+// been joined locally.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) initializeAll() {
+ mld.genericMulticastProtocol.InitializeGroupsLocked()
+}
+
+// sendQueuedReports attempts to send any reports that are queued for sending.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) sendQueuedReports() {
+ mld.genericMulticastProtocol.SendQueuedReportsLocked()
+}
+
+// writePacket assembles and sends an MLD packet.
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) {
+ sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
+ var mldStat *tcpip.StatCounter
+ switch mldType {
+ case header.ICMPv6MulticastListenerReport:
+ mldStat = sentStats.MulticastListenerReport
+ case header.ICMPv6MulticastListenerDone:
+ mldStat = sentStats.MulticastListenerDone
+ default:
+ panic(fmt.Sprintf("unrecognized mld type = %d", mldType))
+ }
+
+ icmp := header.ICMPv6(buffer.NewView(header.ICMPv6HeaderSize + header.MLDMinimumSize))
+ icmp.SetType(mldType)
+ header.MLD(icmp.MessageBody()).SetMulticastAddress(groupAddress)
+ // As per RFC 2710 section 3,
+ //
+ // All MLD messages described in this document are sent with a link-local
+ // IPv6 Source Address, an IPv6 Hop Limit of 1, and an IPv6 Router Alert
+ // option in a Hop-by-Hop Options header.
+ //
+ // However, this would cause problems with Duplicate Address Detection with
+ // the first address as MLD snooping switches may not send multicast traffic
+ // that DAD depends on to the node performing DAD without the MLD report, as
+ // documented in RFC 4862 section 5.4.2:
+ //
+ // Note that when a node joins a multicast address, it typically sends a
+ // Multicast Listener Discovery (MLD) report message [RFC2710] [RFC3810]
+ // for the multicast address. In the case of Duplicate Address
+ // Detection, the MLD report message is required in order to inform MLD-
+ // snooping switches, rather than routers, to forward multicast packets.
+ // In the above description, the delay for joining the multicast address
+ // thus means delaying transmission of the corresponding MLD report
+ // message. Since the MLD specifications do not request a random delay
+ // to avoid race conditions, just delaying Neighbor Solicitation would
+ // cause congestion by the MLD report messages. The congestion would
+ // then prevent the MLD-snooping switches from working correctly and, as
+ // a result, prevent Duplicate Address Detection from working. The
+ // requirement to include the delay for the MLD report in this case
+ // avoids this scenario. [RFC3590] also talks about some interaction
+ // issues between Duplicate Address Detection and MLD, and specifies
+ // which source address should be used for the MLD report in this case.
+ //
+ // As per RFC 3590 section 4, we should still send out MLD reports with an
+ // unspecified source address if we do not have an assigned link-local
+ // address to use as the source address to ensure DAD works as expected on
+ // networks with MLD snooping switches:
+ //
+ // MLD Report and Done messages are sent with a link-local address as
+ // the IPv6 source address, if a valid address is available on the
+ // interface. If a valid link-local address is not available (e.g., one
+ // has not been configured), the message is sent with the unspecified
+ // address (::) as the IPv6 source address.
+ //
+ // Once a valid link-local address is available, a node SHOULD generate
+ // new MLD Report messages for all multicast addresses joined on the
+ // interface.
+ //
+ // Routers receiving an MLD Report or Done message with the unspecified
+ // address as the IPv6 source address MUST silently discard the packet
+ // without taking any action on the packets contents.
+ //
+ // Snooping switches MUST manage multicast forwarding state based on MLD
+ // Report and Done messages sent with the unspecified address as the
+ // IPv6 source address.
+ localAddress := mld.ep.getLinkLocalAddressRLocked()
+ if len(localAddress) == 0 {
+ localAddress = header.IPv6Any
+ }
+
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{}))
+
+ extensionHeaders := header.IPv6ExtHdrSerializer{
+ header.IPv6SerializableHopByHopExtHdr{
+ &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD},
+ },
+ }
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()) + extensionHeaders.Length(),
+ Data: buffer.View(icmp).ToVectorisedView(),
+ })
+
+ mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.MLDHopLimit,
+ }, extensionHeaders)
+ if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
+ sentStats.Dropped.Increment()
+ return false, err
+ }
+ mldStat.Increment()
+ return localAddress != header.IPv6Any, nil
+}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
new file mode 100644
index 000000000..e2778b656
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -0,0 +1,297 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+)
+
+var (
+ linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr)
+ globalAddrSNMC = header.SolicitedNodeAddr(globalAddr)
+)
+
+func validateMLDPacket(t *testing.T, p buffer.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) {
+ t.Helper()
+
+ checker.IPv6WithExtHdr(t, p,
+ checker.IPv6ExtHdr(
+ checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)),
+ ),
+ checker.SrcAddr(localAddress),
+ checker.DstAddr(remoteAddress),
+ // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
+ checker.TTL(1),
+ checker.MLD(mldType, header.MLDMinimumSize,
+ checker.MLDMaxRespDelay(0),
+ checker.MLDMulticastAddress(groupAddress),
+ ),
+ )
+}
+
+func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ MLD: ipv6.MLDOptions{
+ Enabled: true,
+ },
+ })},
+ })
+ e := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ // The stack will join an address's solicited node multicast address when
+ // an address is added. An MLD report message should be sent for the
+ // solicited-node group.
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
+ }
+
+ // The stack will leave an address's solicited node multicast address when
+ // an address is removed. An MLD done message should be sent for the
+ // solicited-node group.
+ if err := s.RemoveAddress(nicID, linkLocalAddr); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, linkLocalAddr, err)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a done message to be sent")
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC)
+ }
+}
+
+func TestSendQueuedMLDReports(t *testing.T) {
+ const (
+ nicID = 1
+ maxReports = 2
+ )
+
+ tests := []struct {
+ name string
+ dadTransmits uint8
+ retransmitTimer time.Duration
+ }{
+ {
+ name: "DAD Disabled",
+ dadTransmits: 0,
+ retransmitTimer: 0,
+ },
+ {
+ name: "DAD Enabled",
+ dadTransmits: 1,
+ retransmitTimer: time.Second,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ DupAddrDetectTransmits: test.dadTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ },
+ MLD: ipv6.MLDOptions{
+ Enabled: true,
+ },
+ })},
+ Clock: clock,
+ })
+
+ // Allow space for an extra packet so we can observe packets that were
+ // unexpectedly sent.
+ e := channel.New(maxReports+int(test.dadTransmits)+1 /* extra */, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ resolveDAD := func(addr, snmc tcpip.Address) {
+ clock.Advance(dadResolutionTime)
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected DAD packet")
+ } else {
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(snmc),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(addr),
+ checker.NDPNSOptions(nil),
+ ))
+ }
+ }
+
+ var reportCounter uint64
+ reportStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ var doneCounter uint64
+ doneStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+ if got := doneStat.Value(); got != doneCounter {
+ t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter)
+ }
+
+ // Joining a group without an assigned address should send an MLD report
+ // with the unspecified address.
+ if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err)
+ }
+ reportCounter++
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Errorf("expected MLD report for %s", globalMulticastAddr)
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Errorf("got unexpected packet = %#v", p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Adding a global address should not send reports for the already joined
+ // group since we should only send queued reports when a link-local
+ // address is assigned.
+ //
+ // Note, we will still expect to send a report for the global address's
+ // solicited node address from the unspecified address as per RFC 3590
+ // section 4.
+ if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err)
+ }
+ reportCounter++
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Errorf("expected MLD report for %s", globalAddrSNMC)
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalAddrSNMC, header.ICMPv6MulticastListenerReport, globalAddrSNMC)
+ }
+ if dadResolutionTime != 0 {
+ // Reports should not be sent when the address resolves.
+ resolveDAD(globalAddr, globalAddrSNMC)
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ }
+ // Leave the group since we don't care about the global address's
+ // solicited node multicast group membership.
+ if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalAddrSNMC); err != nil {
+ t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalAddrSNMC, err)
+ }
+ if got := doneStat.Value(); got != doneCounter {
+ t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Errorf("got unexpected packet = %#v", p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Adding a link-local address should send a report for its solicited node
+ // address and globalMulticastAddr.
+ if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err)
+ }
+ if dadResolutionTime != 0 {
+ reportCounter++
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Errorf("expected MLD report for %s", linkLocalAddrSNMC)
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
+ }
+ resolveDAD(linkLocalAddr, linkLocalAddrSNMC)
+ }
+
+ // We expect two batches of reports to be sent (1 batch when the
+ // link-local address is assigned, and another after the maximum
+ // unsolicited report interval).
+ for i := 0; i < 2; i++ {
+ // We expect reports to be sent (one for globalMulticastAddr and another
+ // for linkLocalAddrSNMC).
+ reportCounter += maxReports
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+
+ addrs := map[tcpip.Address]bool{
+ globalMulticastAddr: false,
+ linkLocalAddrSNMC: false,
+ }
+ for _ = range addrs {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected MLD report for %s and %s; addrs = %#v", globalMulticastAddr, linkLocalAddrSNMC, addrs)
+ }
+
+ addr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())).DestinationAddress()
+ if seen, ok := addrs[addr]; !ok {
+ t.Fatalf("got unexpected packet destined to %s", addr)
+ } else if seen {
+ t.Fatalf("got another packet destined to %s", addr)
+ }
+
+ addrs[addr] = true
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, addr, header.ICMPv6MulticastListenerReport, addr)
+
+ clock.Advance(ipv6.UnsolicitedReportIntervalMax)
+ }
+ }
+
+ // Should not send any more reports.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Errorf("got unexpected packet = %#v", p)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index c138358af..d515eb622 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -20,6 +20,7 @@ import (
"math/rand"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -459,6 +460,9 @@ func (c *NDPConfigurations) validate() {
// ndpState is the per-interface NDP state.
type ndpState struct {
+ // Do not allow overwriting this state.
+ _ sync.NoCopy
+
// The IPv6 endpoint this ndpState is for.
ep *endpoint
@@ -471,17 +475,8 @@ type ndpState struct {
// The default routers discovered through Router Advertisements.
defaultRouters map[tcpip.Address]defaultRouterState
- rtrSolicit struct {
- // The timer used to send the next router solicitation message.
- timer tcpip.Timer
-
- // Used to let the Router Solicitation timer know that it has been stopped.
- //
- // Must only be read from or written to while protected by the lock of
- // the IPv6 endpoint this ndpState is associated with. MUST be set when the
- // timer is set.
- done *bool
- }
+ // The job used to send the next router solicitation message.
+ rtrSolicitJob *tcpip.Job
// The on-link prefixes discovered through Router Advertisements' Prefix
// Information option.
@@ -507,7 +502,7 @@ type ndpState struct {
// to the DAD goroutine that DAD should stop.
type dadState struct {
// The DAD timer to send the next NS message, or resolve the address.
- timer tcpip.Timer
+ job *tcpip.Job
// Used to let the DAD timer know that it has been stopped.
//
@@ -652,92 +647,69 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil)
}
+ ndp.ep.onAddressAssignedLocked(addr)
return nil
}
- var done bool
- var timer tcpip.Timer
- // We initially start a timer to fire immediately because some of the DAD work
- // cannot be done while holding the IPv6 endpoint's lock. This is effectively
- // the same as starting a goroutine but we use a timer that fires immediately
- // so we can reset it for the next DAD iteration.
- timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() {
- ndp.ep.mu.Lock()
- defer ndp.ep.mu.Unlock()
-
- if done {
- // If we reach this point, it means that the DAD timer fired after
- // another goroutine already obtained the IPv6 endpoint lock and stopped
- // DAD before this function obtained the NIC lock. Simply return here and
- // do nothing further.
- return
- }
+ state := dadState{
+ job: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ state, ok := ndp.dad[addr]
+ if !ok {
+ panic(fmt.Sprintf("ndpdad: DAD timer fired but missing state for %s on NIC(%d)", addr, ndp.ep.nic.ID()))
+ }
- if addressEndpoint.GetKind() != stack.PermanentTentative {
- // The endpoint should still be marked as tentative since we are still
- // performing DAD on it.
- panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
- }
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ // The endpoint should still be marked as tentative since we are still
+ // performing DAD on it.
+ panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
+ }
- dadDone := remaining == 0
-
- var err *tcpip.Error
- if !dadDone {
- // Use the unspecified address as the source address when performing DAD.
- addressEndpoint := ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
-
- // Do not hold the lock when sending packets which may be a long running
- // task or may block link address resolution. We know this is safe
- // because immediately after obtaining the lock again, we check if DAD
- // has been stopped before doing any work with the IPv6 endpoint. Note,
- // DAD would be stopped if the IPv6 endpoint was disabled or closed, or if
- // the address was removed.
- ndp.ep.mu.Unlock()
- err = ndp.sendDADPacket(addr, addressEndpoint)
- ndp.ep.mu.Lock()
- addressEndpoint.DecRef()
- }
+ dadDone := remaining == 0
- if done {
- // If we reach this point, it means that DAD was stopped after we released
- // the IPv6 endpoint's read lock and before we obtained the write lock.
- return
- }
+ var err *tcpip.Error
+ if !dadDone {
+ err = ndp.sendDADPacket(addr, addressEndpoint)
+ }
- if dadDone {
- // DAD has resolved.
- addressEndpoint.SetKind(stack.Permanent)
- } else if err == nil {
- // DAD is not done and we had no errors when sending the last NDP NS,
- // schedule the next DAD timer.
- remaining--
- timer.Reset(ndp.configs.RetransmitTimer)
- return
- }
+ if dadDone {
+ // DAD has resolved.
+ addressEndpoint.SetKind(stack.Permanent)
+ } else if err == nil {
+ // DAD is not done and we had no errors when sending the last NDP NS,
+ // schedule the next DAD timer.
+ remaining--
+ state.job.Schedule(ndp.configs.RetransmitTimer)
+ return
+ }
- // At this point we know that either DAD is done or we hit an error sending
- // the last NDP NS. Either way, clean up addr's DAD state and let the
- // integrator know DAD has completed.
- delete(ndp.dad, addr)
+ // At this point we know that either DAD is done or we hit an error
+ // sending the last NDP NS. Either way, clean up addr's DAD state and let
+ // the integrator know DAD has completed.
+ delete(ndp.dad, addr)
- if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
- }
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
+ }
- // If DAD resolved for a stable SLAAC address, attempt generation of a
- // temporary SLAAC address.
- if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
- // Reset the generation attempts counter as we are starting the generation
- // of a new address for the SLAAC prefix.
- ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
- }
- })
+ if dadDone {
+ if addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
+ // Reset the generation attempts counter as we are starting the
+ // generation of a new address for the SLAAC prefix.
+ ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
+ }
- ndp.dad[addr] = dadState{
- timer: timer,
- done: &done,
+ ndp.ep.onAddressAssignedLocked(addr)
+ }
+ }),
}
+ // We initially start a timer to fire immediately because some of the DAD work
+ // cannot be done while holding the IPv6 endpoint's lock. This is effectively
+ // the same as starting a goroutine but we use a timer that fires immediately
+ // so we can reset it for the next DAD iteration.
+ state.job.Schedule(0)
+ ndp.dad[addr] = state
+
return nil
}
@@ -745,55 +717,31 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// addr.
//
// addr must be a tentative IPv6 address on ndp's IPv6 endpoint.
-//
-// The IPv6 endpoint that ndp belongs to MUST NOT be locked.
func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- return err
- }
- defer r.Release()
-
- // Route should resolve immediately since snmc is a multicast address so a
- // remote link address can be calculated without a resolution process.
- if c, err := r.Resolve(nil); err != nil {
- // Do not consider the NIC being unknown or disabled as a fatal error.
- // Since this method is required to be called when the IPv6 endpoint is not
- // locked, the NIC could have been disabled or removed by another goroutine.
- if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState {
- return err
- }
-
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err))
- } else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID()))
- }
-
- icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
- icmpData.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmpData.NDPPayload())
+ icmp := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
+ icmp.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(addr)
- icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, snmc, buffer.VectorisedView{}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: buffer.View(icmpData).ToVectorisedView(),
+ ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()),
+ Data: buffer.View(icmp).ToVectorisedView(),
})
- sent := r.Stats().ICMP.V6PacketsSent
- if err := r.WritePacket(nil,
- stack.NetworkHeaderParams{
- Protocol: header.ICMPv6ProtocolNumber,
- TTL: header.NDPHopLimit,
- }, pkt,
- ); err != nil {
+ sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
+ ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ }, nil /* extensionHeaders */)
+
+ if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil {
sent.Dropped.Increment()
return err
}
sent.NeighborSolicit.Increment()
-
return nil
}
@@ -812,14 +760,7 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
return
}
- if dad.timer != nil {
- dad.timer.Stop()
- dad.timer = nil
-
- *dad.done = true
- dad.done = nil
- }
-
+ dad.job.Cancel()
delete(ndp.dad, addr)
// Let the integrator know DAD did not resolve.
@@ -1859,7 +1800,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) startSolicitingRouters() {
- if ndp.rtrSolicit.timer != nil {
+ if ndp.rtrSolicitJob != nil {
// We are already soliciting routers.
return
}
@@ -1876,56 +1817,14 @@ func (ndp *ndpState) startSolicitingRouters() {
delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
}
- var done bool
- ndp.rtrSolicit.done = &done
- ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() {
- ndp.ep.mu.Lock()
- if done {
- // If we reach this point, it means that the RS timer fired after another
- // goroutine already obtained the IPv6 endpoint lock and stopped
- // solicitations. Simply return here and do nothing further.
- ndp.ep.mu.Unlock()
- return
- }
-
+ ndp.rtrSolicitJob = ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
// As per RFC 4861 section 4.1, the source of the RS is an address assigned
// to the sending interface, or the unspecified address if no address is
// assigned to the sending interface.
- addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false)
- if addressEndpoint == nil {
- // Incase this ends up creating a new temporary address, we need to hold
- // onto the endpoint until a route is obtained. If we decrement the
- // reference count before obtaing a route, the address's resources would
- // be released and attempting to obtain a route after would fail. Once a
- // route is obtainted, it is safe to decrement the reference count since
- // obtaining a route increments the address's reference count.
- addressEndpoint = ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
- }
- ndp.ep.mu.Unlock()
-
- localAddr := addressEndpoint.AddressWithPrefix().Address
- r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */)
- addressEndpoint.DecRef()
- if err != nil {
- return
- }
- defer r.Release()
-
- // Route should resolve immediately since
- // header.IPv6AllRoutersMulticastAddress is a multicast address so a
- // remote link address can be calculated without a resolution process.
- if c, err := r.Resolve(nil); err != nil {
- // Do not consider the NIC being unknown or disabled as a fatal error.
- // Since this method is required to be called when the IPv6 endpoint is
- // not locked, the IPv6 endpoint could have been disabled or removed by
- // another goroutine.
- if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState {
- return
- }
-
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err))
- } else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID()))
+ localAddr := header.IPv6Any
+ if addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil {
+ localAddr = addressEndpoint.AddressWithPrefix().Address
+ addressEndpoint.DecRef()
}
// As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
@@ -1936,30 +1835,31 @@ func (ndp *ndpState) startSolicitingRouters() {
// TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
// LinkEndpoint.LinkAddress) before reaching this point.
var optsSerializer header.NDPOptionsSerializer
- if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) {
+ linkAddress := ndp.ep.nic.LinkAddress()
+ if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(linkAddress) {
optsSerializer = header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress),
+ header.NDPSourceLinkLayerAddressOption(linkAddress),
}
}
payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length())
icmpData := header.ICMPv6(buffer.NewView(payloadSize))
icmpData.SetType(header.ICMPv6RouterSolicit)
- rs := header.NDPRouterSolicit(icmpData.NDPPayload())
+ rs := header.NDPRouterSolicit(icmpData.MessageBody())
rs.Options().Serialize(optsSerializer)
- icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, localAddr, header.IPv6AllRoutersMulticastAddress, buffer.VectorisedView{}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()),
Data: buffer.View(icmpData).ToVectorisedView(),
})
- sent := r.Stats().ICMP.V6PacketsSent
- if err := r.WritePacket(nil,
- stack.NetworkHeaderParams{
- Protocol: header.ICMPv6ProtocolNumber,
- TTL: header.NDPHopLimit,
- }, pkt,
- ); err != nil {
+ sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
+ ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ }, nil /* extensionHeaders */)
+
+ if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
sent.Dropped.Increment()
log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err)
// Don't send any more messages if we had an error.
@@ -1969,21 +1869,12 @@ func (ndp *ndpState) startSolicitingRouters() {
remaining--
}
- ndp.ep.mu.Lock()
- if done || remaining == 0 {
- ndp.rtrSolicit.timer = nil
- ndp.rtrSolicit.done = nil
- } else if ndp.rtrSolicit.timer != nil {
- // Note, we need to explicitly check to make sure that
- // the timer field is not nil because if it was nil but
- // we still reached this point, then we know the IPv6 endpoint
- // was requested to stop soliciting routers so we don't
- // need to send the next Router Solicitation message.
- ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval)
+ if remaining != 0 {
+ ndp.rtrSolicitJob.Schedule(ndp.configs.RtrSolicitationInterval)
}
- ndp.ep.mu.Unlock()
})
+ ndp.rtrSolicitJob.Schedule(delay)
}
// stopSolicitingRouters stops soliciting routers. If routers are not currently
@@ -1991,22 +1882,28 @@ func (ndp *ndpState) startSolicitingRouters() {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) stopSolicitingRouters() {
- if ndp.rtrSolicit.timer == nil {
+ if ndp.rtrSolicitJob == nil {
// Nothing to do.
return
}
- *ndp.rtrSolicit.done = true
- ndp.rtrSolicit.timer.Stop()
- ndp.rtrSolicit.timer = nil
- ndp.rtrSolicit.done = nil
+ ndp.rtrSolicitJob.Cancel()
+ ndp.rtrSolicitJob = nil
}
-// initializeTempAddrState initializes state related to temporary SLAAC
-// addresses.
-func (ndp *ndpState) initializeTempAddrState() {
- header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID())
+func (ndp *ndpState) init(ep *endpoint) {
+ if ndp.dad != nil {
+ panic("attempted to initialize NDP state twice")
+ }
+ ndp.ep = ep
+ ndp.configs = ep.protocol.options.NDPConfigs
+ ndp.dad = make(map[tcpip.Address]dadState)
+ ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState)
+ ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState)
+ ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState)
+
+ header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID())
if MaxDesyncFactor != 0 {
ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 9fbd0d336..7ddb19c00 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -205,7 +205,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(lladdr0)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -213,14 +213,14 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
@@ -311,7 +311,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(lladdr0)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -319,14 +319,14 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
@@ -591,7 +591,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(nicAddr)
opts := ns.Options()
opts.Serialize(test.nsOpts)
@@ -599,14 +599,14 @@ func TestNeighorSolicitationResponse(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: test.nsSrc,
- DstAddr: test.nsDst,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: test.nsSrc,
+ DstAddr: test.nsDst,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
@@ -672,7 +672,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(test.nsSrc)
@@ -681,11 +681,11 @@ func TestNeighorSolicitationResponse(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: test.nsSrc,
- DstAddr: nicAddr,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: test.nsSrc,
+ DstAddr: nicAddr,
})
e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -777,7 +777,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns := header.NDPNeighborAdvert(pkt.MessageBody())
ns.SetTargetAddress(lladdr1)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -785,14 +785,14 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
@@ -890,7 +890,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns := header.NDPNeighborAdvert(pkt.MessageBody())
ns.SetTargetAddress(lladdr1)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -898,14 +898,14 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
@@ -979,29 +979,25 @@ func TestNDPValidation(t *testing.T) {
}
handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) {
- nextHdr := uint8(header.ICMPv6ProtocolNumber)
- var extensions buffer.View
+ var extHdrs header.IPv6ExtHdrSerializer
if atomicFragment {
- extensions = buffer.NewView(header.IPv6FragmentExtHdrLength)
- extensions[0] = nextHdr
- nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
+ extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{})
}
+ extHdrsLen := extHdrs.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions),
+ ReserveHeaderBytes: header.IPv6MinimumSize + extHdrsLen,
Data: payload.ToVectorisedView(),
})
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions)))
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(payload) + len(extensions)),
- NextHeader: nextHdr,
- HopLimit: hopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(len(payload) + extHdrsLen),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: hopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ ExtensionHeaders: extHdrs,
})
- if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
- t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
- }
ep.HandlePacket(pkt)
}
@@ -1122,7 +1118,7 @@ func TestNDPValidation(t *testing.T) {
s.SetForwarding(ProtocolNumber, true)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
routerOnly := stats.RouterOnlyPacketsDroppedByHost
typStat := typ.statCounter(stats)
@@ -1346,19 +1342,19 @@ func TestRouterAdvertValidation(t *testing.T) {
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
pkt.SetType(header.ICMPv6RouterAdvert)
pkt.SetCode(test.code)
- copy(pkt.NDPPayload(), test.ndpPayload)
+ copy(pkt.MessageBody(), test.ndpPayload)
payloadLength := hdr.UsedLength()
pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: test.hopLimit,
- SrcAddr: test.src,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: test.hopLimit,
+ SrcAddr: test.src,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
rxRA := stats.RouterAdvert
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
new file mode 100644
index 000000000..05d98a0a5
--- /dev/null
+++ b/pkg/tcpip/network/multicast_group_test.go
@@ -0,0 +1,1261 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip_test
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") // link address handed to channel.New in createStack
+
+ ipv4Addr = tcpip.Address("\x0a\x00\x00\x01") // 10.0.0.1; added to the NIC in createStackWithLinkEndpoint
+ ipv6Addr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") // fe80::1; added to the NIC in createStackWithLinkEndpoint
+
+ ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03") // 224.0.0.3
+ ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04") // 224.0.0.4
+ ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05") // 224.0.0.5
+ ipv6MulticastAddr1 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") // ff02::3
+ ipv6MulticastAddr2 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04") // ff02::4
+ ipv6MulticastAddr3 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05") // ff02::5
+
+ igmpMembershipQuery = uint8(header.IGMPMembershipQuery) // message types as uint8 so IGMP and MLD table entries share one field type
+ igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport)
+ igmpv2MembershipReport = uint8(header.IGMPv2MembershipReport)
+ igmpLeaveGroup = uint8(header.IGMPLeaveGroup)
+ mldQuery = uint8(header.ICMPv6MulticastListenerQuery)
+ mldReport = uint8(header.ICMPv6MulticastListenerReport)
+ mldDone = uint8(header.ICMPv6MulticastListenerDone)
+
+ maxUnsolicitedReports = 2 // channel queue size in createStack and number of reports expected after a join
+)
+
+var (
+ // unsolicitedIGMPReportIntervalMaxTenthSec is ipv4.UnsolicitedReportIntervalMax
+ // expressed in deciseconds as a uint8; the tests pass it as the max response
+ // time of injected IGMP queries. Initialization panics if the interval is not a whole number of deciseconds.
+ unsolicitedIGMPReportIntervalMaxTenthSec = func() uint8 {
+ const decisecond = time.Second / 10
+ if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 { // refuse to silently drop a fractional decisecond
+ panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax))
+ }
+ return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond)
+ }()
+
+ ipv6AddrSNMC = header.SolicitedNodeAddr(ipv6Addr) // solicited-node multicast address derived from ipv6Addr
+)
+
+// validateMLDPacket verifies that p holds an IPv6 packet from ipv6Addr to
+// remoteAddress that carries a Hop-by-Hop header with the MLD router alert
+// option, followed by an MLD message of type mldType with the given maximum
+// response delay (in milliseconds) and multicast group address.
+func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, mldType uint8, maxRespTime byte, groupAddress tcpip.Address) {
+	t.Helper()
+
+	ipHdr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
+
+	routerAlert := checker.IPv6ExtHdr(
+		checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)),
+	)
+	wantDelay := time.Duration(maxRespTime) * time.Millisecond
+	mldMessage := checker.MLD(header.ICMPv6Type(mldType), header.MLDMinimumSize,
+		checker.MLDMaxRespDelay(wantDelay),
+		checker.MLDMulticastAddress(groupAddress),
+	)
+
+	checker.IPv6WithExtHdr(t, ipHdr,
+		routerAlert,
+		checker.SrcAddr(ipv6Addr),
+		checker.DstAddr(remoteAddress),
+		// RFC 2710 section 3 requires MLD messages to use a Hop Limit of 1.
+		checker.TTL(1),
+		mldMessage,
+	)
+}
+
+// validateIGMPPacket verifies that p holds an IPv4 packet from ipv4Addr to
+// remoteAddress that carries the router alert option and an IGMP message of
+// type igmpType with the given maximum response time (in deciseconds) and
+// group address.
+func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType uint8, maxRespTime byte, groupAddress tcpip.Address) {
+	t.Helper()
+
+	ipHdr := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
+
+	igmpMessage := checker.IGMP(
+		checker.IGMPType(header.IGMPType(igmpType)),
+		checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)),
+		checker.IGMPGroupAddress(groupAddress),
+	)
+
+	checker.IPv4(t, ipHdr,
+		checker.SrcAddr(ipv4Addr),
+		checker.DstAddr(remoteAddress),
+		// RFC 2236 section 2 requires IGMP messages to use a TTL of 1.
+		checker.TTL(1),
+		checker.IPv4RouterAlert(),
+		igmpMessage,
+	)
+}
+
+// createStack builds a channel link endpoint (queue depth
+// maxUnsolicitedReports, IPv6 minimum MTU, linkAddr) and a stack on top of it
+// via createStackWithLinkEndpoint, returning the endpoint, the stack and the
+// stack's manual clock.
+func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
+	t.Helper()
+
+	linkEP := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr)
+	netStack, manualClock := createStackWithLinkEndpoint(t, v4, mgpEnabled, linkEP)
+	return linkEP, netStack, manualClock
+}
+
+// createStackWithLinkEndpoint creates a stack driven by a manual clock with
+// both IPv4 and IPv6 registered, creates NIC nicID on top of e and assigns
+// ipv4Addr and ipv6Addr to it.
+//
+// IGMP is enabled only when both v4 and mgpEnabled are true; MLD is enabled
+// only when v4 is false and mgpEnabled is true.
+func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) {
+	t.Helper()
+
+	ipv4Opts := ipv4.Options{
+		IGMP: ipv4.IGMPOptions{Enabled: v4 && mgpEnabled},
+	}
+	ipv6Opts := ipv6.Options{
+		MLD: ipv6.MLDOptions{Enabled: !v4 && mgpEnabled},
+	}
+
+	manualClock := faketime.NewManualClock()
+	netStack := stack.New(stack.Options{
+		NetworkProtocols: []stack.NetworkProtocolFactory{
+			ipv4.NewProtocolWithOptions(ipv4Opts),
+			ipv6.NewProtocolWithOptions(ipv6Opts),
+		},
+		Clock: manualClock,
+	})
+	if err := netStack.CreateNIC(nicID, e); err != nil {
+		t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+	}
+
+	// The IPv4 address is added before the IPv6 one.
+	nicAddrs := []struct {
+		proto tcpip.NetworkProtocolNumber
+		addr  tcpip.Address
+	}{
+		{proto: ipv4.ProtocolNumber, addr: ipv4Addr},
+		{proto: ipv6.ProtocolNumber, addr: ipv6Addr},
+	}
+	for _, a := range nicAddrs {
+		if err := netStack.AddAddress(nicID, a.proto, a.addr); err != nil {
+			t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, a.proto, a.addr, err)
+		}
+	}
+
+	return netStack, manualClock
+}
+
+// checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join
+// when it is created with an IPv6 address.
+//
+// It expects exactly one MLD report for ipv6Addr's solicited-node multicast
+// group to be queued on e. To not interfere with tests, it then leaves that
+// group (expecting one MLD done message) so that the tests can all assume the
+// NIC has not joined any IPv6 groups.
+//
+// The returned counters are the number of MLD report and done messages the
+// stack has sent so far; callers continue counting from them.
+func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) {
+	t.Helper()
+
+	stats := s.Stats().ICMP.V6.PacketsSent
+
+	reportCounter++
+	if got := stats.MulticastListenerReport.Value(); got != reportCounter {
+		t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter)
+	}
+	p, ok := e.Read()
+	if !ok {
+		t.Fatal("expected a report message to be sent")
+	}
+	validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC)
+
+	// Leave the group to not affect the tests. This is fine since we are not
+	// testing DAD or the solicited node address specifically.
+	if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil {
+		t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err)
+	}
+	leaveCounter++
+	if got := stats.MulticastListenerDone.Value(); got != leaveCounter {
+		t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter)
+	}
+	p, ok = e.Read()
+	if !ok {
+		// The packet expected here is the done message, not a report.
+		t.Fatal("expected a done message to be sent")
+	}
+	validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC)
+
+	// Should not send any more packets.
+	clock.Advance(time.Hour)
+	if p, ok := e.Read(); ok {
+		t.Fatalf("sent unexpected packet = %#v", p)
+	}
+
+	return reportCounter, leaveCounter
+}
+
+// createAndInjectIGMPPacket builds an IPv4 packet (unspecified source,
+// all-systems destination, IGMP TTL) carrying an IGMP message with the given
+// type, maximum response time and group address, and injects it into e as
+// inbound IPv4 traffic.
+//
+// Note, the router alert option is not included in this packet.
+//
+// TODO(b/162198658): set the router alert option.
+func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime byte, groupAddress tcpip.Address) {
+	pktBytes := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize)
+
+	// Network header, checksummed once all fields are in place.
+	ipHdr := header.IPv4(pktBytes)
+	ipHdr.Encode(&header.IPv4Fields{
+		TotalLength: uint16(len(pktBytes)),
+		TTL:         header.IGMPTTL,
+		Protocol:    uint8(header.IGMPProtocolNumber),
+		SrcAddr:     header.IPv4Any,
+		DstAddr:     header.IPv4AllSystems,
+	})
+	ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
+
+	// IGMP message occupies everything after the minimum-size IPv4 header.
+	msg := header.IGMP(pktBytes[header.IPv4MinimumSize:])
+	msg.SetType(header.IGMPType(igmpType))
+	msg.SetMaxRespTime(maxRespTime)
+	msg.SetGroupAddress(groupAddress)
+	msg.SetChecksum(header.IGMPCalculateChecksum(msg))
+
+	inbound := &stack.PacketBuffer{
+		Data: pktBytes.ToVectorisedView(),
+	}
+	e.InjectInbound(ipv4.ProtocolNumber, inbound)
+}
+
+// createAndInjectMLDPacket builds an IPv6 packet (unspecified source,
+// all-nodes multicast destination, MLD hop limit) carrying an ICMPv6 MLD
+// message with the given type, maximum response delay and multicast address,
+// and injects it into e as inbound IPv6 traffic.
+//
+// Note, the router alert option is not included in this packet.
+//
+// TODO(b/162198658): set the router alert option.
+func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay byte, groupAddress tcpip.Address) {
+	// The same addresses must be written to the IPv6 header and fed to the
+	// ICMPv6 pseudo-header checksum. The source is the IPv6 unspecified
+	// address; a 4-byte IPv4 address must not be used for an IPv6 header field.
+	srcAddr := header.IPv6Any
+	dstAddr := header.IPv6AllNodesMulticastAddress
+
+	icmpSize := header.ICMPv6HeaderSize + header.MLDMinimumSize
+	buf := buffer.NewView(header.IPv6MinimumSize + icmpSize)
+
+	ip := header.IPv6(buf)
+	ip.Encode(&header.IPv6Fields{
+		PayloadLength:     uint16(icmpSize),
+		HopLimit:          header.MLDHopLimit,
+		TransportProtocol: header.ICMPv6ProtocolNumber,
+		SrcAddr:           srcAddr,
+		DstAddr:           dstAddr,
+	})
+
+	icmp := header.ICMPv6(buf[header.IPv6MinimumSize:])
+	icmp.SetType(header.ICMPv6Type(mldType))
+	mld := header.MLD(icmp.MessageBody())
+	mld.SetMaximumResponseDelay(uint16(maxRespDelay))
+	mld.SetMulticastAddress(groupAddress)
+	// The checksum is computed last so that it covers the MLD fields set above.
+	icmp.SetChecksum(header.ICMPv6Checksum(icmp, srcAddr, dstAddr, buffer.VectorisedView{}))
+
+	e.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+		Data: buf.ToVectorisedView(),
+	})
+}
+
+// TestMGPDisabled tests that a stack created with the multicast group
+// protocol (IGMP or MLD) disabled never sends reports: not when the NIC comes
+// up, not after explicitly joining a group, and not in response to a general
+// query. The received query is still expected to be counted.
+func TestMGPDisabled(t *testing.T) {
+	tests := []struct {
+		name              string
+		protoNum          tcpip.NetworkProtocolNumber
+		multicastAddr     tcpip.Address
+		sentReportStat    func(*stack.Stack) *tcpip.StatCounter
+		receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
+		rxQuery           func(*channel.Endpoint)
+	}{
+		{
+			name:          "IGMP",
+			protoNum:      ipv4.ProtocolNumber,
+			multicastAddr: ipv4MulticastAddr1,
+			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().IGMP.PacketsSent.V2MembershipReport
+			},
+			receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().IGMP.PacketsReceived.MembershipQuery
+			},
+			rxQuery: func(e *channel.Endpoint) {
+				createAndInjectIGMPPacket(e, igmpMembershipQuery, unsolicitedIGMPReportIntervalMaxTenthSec, header.IPv4Any)
+			},
+		},
+		{
+			name:          "MLD",
+			protoNum:      ipv6.ProtocolNumber,
+			multicastAddr: ipv6MulticastAddr1,
+			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+			},
+			receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
+			},
+			rxQuery: func(e *channel.Endpoint) {
+				createAndInjectMLDPacket(e, mldQuery, 0, header.IPv6Any)
+			},
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */)
+
+			// This NIC may join multicast groups when it is enabled but since MGP is
+			// disabled, no reports should be sent.
+			sentReportStat := test.sentReportStat(s)
+			if got := sentReportStat.Value(); got != 0 {
+				t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+			}
+			clock.Advance(time.Hour)
+			if p, ok := e.Read(); ok {
+				t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p.Pkt)
+			}
+
+			// Test joining a specific group explicitly and verify that no reports are
+			// sent.
+			if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+				t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+			}
+			if got := sentReportStat.Value(); got != 0 {
+				t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+			}
+			clock.Advance(time.Hour)
+			if p, ok := e.Read(); ok {
+				t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p.Pkt)
+			}
+
+			// Inject a general query message. This should only trigger a report to be
+			// sent if the MGP was enabled.
+			test.rxQuery(e)
+			if got := test.receivedQueryStat(s).Value(); got != 1 {
+				t.Fatalf("got receivedQueryStat(_).Value() = %d, want = 1", got)
+			}
+			clock.Advance(time.Hour)
+			if p, ok := e.Read(); ok {
+				t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p.Pkt)
+			}
+		})
+	}
+}
+
+// TestMGPReceiveCounters injects one IGMP or MLD message of each kind into a
+// stack with the matching multicast group protocol enabled and checks that the
+// corresponding received-packet counter reads exactly 1 afterwards.
+func TestMGPReceiveCounters(t *testing.T) {
+	testCases := []struct {
+		name         string
+		headerType   uint8
+		maxRespTime  byte
+		groupAddress tcpip.Address
+		statCounter  func(*stack.Stack) *tcpip.StatCounter
+		rxMGPkt      func(*channel.Endpoint, byte, byte, tcpip.Address)
+	}{
+		{
+			name:         "IGMP Membership Query",
+			headerType:   igmpMembershipQuery,
+			maxRespTime:  unsolicitedIGMPReportIntervalMaxTenthSec,
+			groupAddress: header.IPv4Any,
+			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().IGMP.PacketsReceived.MembershipQuery
+			},
+			rxMGPkt: createAndInjectIGMPPacket,
+		},
+		{
+			name:         "IGMPv1 Membership Report",
+			headerType:   igmpv1MembershipReport,
+			maxRespTime:  0,
+			groupAddress: header.IPv4AllSystems,
+			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().IGMP.PacketsReceived.V1MembershipReport
+			},
+			rxMGPkt: createAndInjectIGMPPacket,
+		},
+		{
+			name:         "IGMPv2 Membership Report",
+			headerType:   igmpv2MembershipReport,
+			maxRespTime:  0,
+			groupAddress: header.IPv4AllSystems,
+			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().IGMP.PacketsReceived.V2MembershipReport
+			},
+			rxMGPkt: createAndInjectIGMPPacket,
+		},
+		{
+			name:         "IGMP Leave Group",
+			headerType:   igmpLeaveGroup,
+			maxRespTime:  0,
+			groupAddress: header.IPv4AllRoutersGroup,
+			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().IGMP.PacketsReceived.LeaveGroup
+			},
+			rxMGPkt: createAndInjectIGMPPacket,
+		},
+		{
+			name:         "MLD Query",
+			headerType:   mldQuery,
+			maxRespTime:  0,
+			groupAddress: header.IPv6Any,
+			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
+			},
+			rxMGPkt: createAndInjectMLDPacket,
+		},
+		{
+			name:         "MLD Report",
+			headerType:   mldReport,
+			maxRespTime:  0,
+			groupAddress: header.IPv6Any,
+			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerReport
+			},
+			rxMGPkt: createAndInjectMLDPacket,
+		},
+		{
+			name:         "MLD Done",
+			headerType:   mldDone,
+			maxRespTime:  0,
+			groupAddress: header.IPv6Any,
+			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerDone
+			},
+			rxMGPkt: createAndInjectMLDPacket,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			// The protocol family is inferred from the length of the group address.
+			isV4 := len(tc.groupAddress) == header.IPv4AddressSize
+			ep, netStack, _ := createStack(t, isV4, true /* mgpEnabled */)
+
+			tc.rxMGPkt(ep, tc.headerType, tc.maxRespTime, tc.groupAddress)
+
+			got := tc.statCounter(netStack).Value()
+			if got != 1 {
+				t.Fatalf("got %s received = %d, want = 1", tc.name, got)
+			}
+		})
+	}
+}
+
+// TestMGPJoinGroup tests that when explicitly joining a multicast group, the
+// stack schedules and sends correct Membership Reports: one immediately on
+// join, a second one no later than the maximum unsolicited report interval,
+// and nothing afterwards.
+func TestMGPJoinGroup(t *testing.T) {
+	tests := []struct {
+		name                        string
+		protoNum                    tcpip.NetworkProtocolNumber
+		multicastAddr               tcpip.Address
+		maxUnsolicitedResponseDelay time.Duration
+		sentReportStat              func(*stack.Stack) *tcpip.StatCounter
+		validateReport              func(*testing.T, channel.PacketInfo)
+		checkInitialGroups          func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+	}{
+		{
+			name:                        "IGMP",
+			protoNum:                    ipv4.ProtocolNumber,
+			multicastAddr:               ipv4MulticastAddr1,
+			maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
+			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().IGMP.PacketsSent.V2MembershipReport
+			},
+			validateReport: func(t *testing.T, p channel.PacketInfo) {
+				t.Helper()
+
+				validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+			},
+		},
+		{
+			name:                        "MLD",
+			protoNum:                    ipv6.ProtocolNumber,
+			multicastAddr:               ipv6MulticastAddr1,
+			maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
+			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+			},
+			validateReport: func(t *testing.T, p channel.PacketInfo) {
+				t.Helper()
+
+				validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+			},
+			checkInitialGroups: checkInitialIPv6Groups,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+			// Start counting from whatever the initial group handling already sent.
+			var reportCounter uint64
+			if test.checkInitialGroups != nil {
+				reportCounter, _ = test.checkInitialGroups(t, e, s, clock)
+			}
+
+			// Test joining a specific address explicitly and verify a Report is sent
+			// immediately.
+			if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+				t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+			}
+			reportCounter++
+			sentReportStat := test.sentReportStat(s)
+			if got := sentReportStat.Value(); got != reportCounter {
+				t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+			}
+			if p, ok := e.Read(); !ok {
+				t.Fatal("expected a report message to be sent")
+			} else {
+				test.validateReport(t, p)
+			}
+			if t.Failed() {
+				t.FailNow()
+			}
+
+			// Verify the second report is sent by the maximum unsolicited response
+			// interval.
+			if p, ok := e.Read(); ok {
+				t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt)
+			}
+			clock.Advance(test.maxUnsolicitedResponseDelay)
+			reportCounter++
+			if got := sentReportStat.Value(); got != reportCounter {
+				t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+			}
+			if p, ok := e.Read(); !ok {
+				t.Fatal("expected a report message to be sent")
+			} else {
+				test.validateReport(t, p)
+			}
+
+			// Should not send any more packets.
+			clock.Advance(time.Hour)
+			if p, ok := e.Read(); ok {
+				t.Fatalf("sent unexpected packet = %#v", p)
+			}
+		})
+	}
+}
+
+// TestMGPLeaveGroup tests that when leaving a previously joined multicast
+// group the stack sends a leave/done message to the all-routers group and
+// nothing afterwards.
+func TestMGPLeaveGroup(t *testing.T) {
+	tests := []struct {
+		name               string
+		protoNum           tcpip.NetworkProtocolNumber
+		multicastAddr      tcpip.Address
+		sentReportStat     func(*stack.Stack) *tcpip.StatCounter
+		sentLeaveStat      func(*stack.Stack) *tcpip.StatCounter
+		validateReport     func(*testing.T, channel.PacketInfo)
+		validateLeave      func(*testing.T, channel.PacketInfo)
+		checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+	}{
+		{
+			name:          "IGMP",
+			protoNum:      ipv4.ProtocolNumber,
+			multicastAddr: ipv4MulticastAddr1,
+			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().IGMP.PacketsSent.V2MembershipReport
+			},
+			sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().IGMP.PacketsSent.LeaveGroup
+			},
+			validateReport: func(t *testing.T, p channel.PacketInfo) {
+				t.Helper()
+
+				validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+			},
+			validateLeave: func(t *testing.T, p channel.PacketInfo) {
+				t.Helper()
+
+				validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1)
+			},
+		},
+		{
+			name:          "MLD",
+			protoNum:      ipv6.ProtocolNumber,
+			multicastAddr: ipv6MulticastAddr1,
+			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+			},
+			sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+			},
+			validateReport: func(t *testing.T, p channel.PacketInfo) {
+				t.Helper()
+
+				validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+			},
+			validateLeave: func(t *testing.T, p channel.PacketInfo) {
+				t.Helper()
+
+				validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1)
+			},
+			checkInitialGroups: checkInitialIPv6Groups,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+			// Start counting from whatever the initial group handling already sent.
+			var reportCounter uint64
+			var leaveCounter uint64
+			if test.checkInitialGroups != nil {
+				reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
+			}
+
+			if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+				t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+			}
+			reportCounter++
+			if got := test.sentReportStat(s).Value(); got != reportCounter {
+				t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter)
+			}
+			if p, ok := e.Read(); !ok {
+				t.Fatal("expected a report message to be sent")
+			} else {
+				test.validateReport(t, p)
+			}
+			if t.Failed() {
+				t.FailNow()
+			}
+
+			// Leaving the group should trigger a leave/done message to be sent.
+			if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+				t.Fatalf("LeaveGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+			}
+			leaveCounter++
+			if got := test.sentLeaveStat(s).Value(); got != leaveCounter {
+				t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter)
+			}
+			if p, ok := e.Read(); !ok {
+				t.Fatal("expected a leave message to be sent")
+			} else {
+				test.validateLeave(t, p)
+			}
+
+			// Should not send any more packets.
+			clock.Advance(time.Hour)
+			if p, ok := e.Read(); ok {
+				t.Fatalf("sent unexpected packet = %#v", p)
+			}
+		})
+	}
+}
+
+// TestMGPQueryMessages tests that a report is sent in response to query
+// messages.
+//
+// For each protocol, a query is injected for the unspecified address, for the
+// joined group and for a different group; only the first two are expected to
+// trigger a report, which must be sent within the query's maximum response
+// time.
+func TestMGPQueryMessages(t *testing.T) {
+	tests := []struct {
+		name                        string
+		protoNum                    tcpip.NetworkProtocolNumber
+		multicastAddr               tcpip.Address
+		maxUnsolicitedResponseDelay time.Duration
+		sentReportStat              func(*stack.Stack) *tcpip.StatCounter
+		rxQuery                     func(*channel.Endpoint, uint8, tcpip.Address)
+		validateReport              func(*testing.T, channel.PacketInfo)
+		maxRespTimeToDuration       func(uint8) time.Duration
+		checkInitialGroups          func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+	}{
+		{
+			name:                        "IGMP",
+			protoNum:                    ipv4.ProtocolNumber,
+			multicastAddr:               ipv4MulticastAddr1,
+			maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
+			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().IGMP.PacketsSent.V2MembershipReport
+			},
+			rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
+				createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress)
+			},
+			validateReport: func(t *testing.T, p channel.PacketInfo) {
+				t.Helper()
+
+				validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+			},
+			maxRespTimeToDuration: header.DecisecondToDuration,
+		},
+		{
+			name:                        "MLD",
+			protoNum:                    ipv6.ProtocolNumber,
+			multicastAddr:               ipv6MulticastAddr1,
+			maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
+			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+			},
+			rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
+				createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress)
+			},
+			validateReport: func(t *testing.T, p channel.PacketInfo) {
+				t.Helper()
+
+				validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+			},
+			maxRespTimeToDuration: func(d uint8) time.Duration {
+				return time.Duration(d) * time.Millisecond
+			},
+			checkInitialGroups: checkInitialIPv6Groups,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			subTests := []struct {
+				name          string
+				multicastAddr tcpip.Address
+				expectReport  bool
+			}{
+				{
+					// A general query: all-zero address of the protocol's address length.
+					name:          "Unspecified",
+					multicastAddr: tcpip.Address(strings.Repeat("\x00", len(test.multicastAddr))),
+					expectReport:  true,
+				},
+				{
+					name:          "Specified",
+					multicastAddr: test.multicastAddr,
+					expectReport:  true,
+				},
+				{
+					name: "Specified other address",
+					// The joined address with its last byte incremented, i.e. a group
+					// the NIC has not joined.
+					multicastAddr: func() tcpip.Address {
+						addrBytes := []byte(test.multicastAddr)
+						addrBytes[len(addrBytes)-1]++
+						return tcpip.Address(addrBytes)
+					}(),
+					expectReport: false,
+				},
+			}
+
+			for _, subTest := range subTests {
+				t.Run(subTest.name, func(t *testing.T) {
+					e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+					// Start counting from whatever the initial group handling already sent.
+					var reportCounter uint64
+					if test.checkInitialGroups != nil {
+						reportCounter, _ = test.checkInitialGroups(t, e, s, clock)
+					}
+
+					if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+						t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+					}
+					// Drain the unsolicited reports triggered by the join.
+					sentReportStat := test.sentReportStat(s)
+					for i := 0; i < maxUnsolicitedReports; i++ {
+						reportCounter++
+						if got := sentReportStat.Value(); got != reportCounter {
+							t.Errorf("(i=%d) got sentReportStat.Value() = %d, want = %d", i, got, reportCounter)
+						}
+						if p, ok := e.Read(); !ok {
+							t.Fatalf("expected %d-th report message to be sent", i)
+						} else {
+							test.validateReport(t, p)
+						}
+						clock.Advance(test.maxUnsolicitedResponseDelay)
+					}
+					if t.Failed() {
+						t.FailNow()
+					}
+
+					// Should not send any more packets until a query.
+					clock.Advance(time.Hour)
+					if p, ok := e.Read(); ok {
+						t.Fatalf("sent unexpected packet = %#v", p)
+					}
+
+					// Receive a query message which should trigger a report to be sent at
+					// some time before the maximum response time if the query is
+					// targeted at the host.
+					const maxRespTime = 100
+					test.rxQuery(e, maxRespTime, subTest.multicastAddr)
+					if p, ok := e.Read(); ok {
+						t.Fatalf("sent unexpected packet = %#v", p.Pkt)
+					}
+
+					if subTest.expectReport {
+						clock.Advance(test.maxRespTimeToDuration(maxRespTime))
+						reportCounter++
+						if got := sentReportStat.Value(); got != reportCounter {
+							t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+						}
+						if p, ok := e.Read(); !ok {
+							t.Fatal("expected a report message to be sent")
+						} else {
+							test.validateReport(t, p)
+						}
+					}
+
+					// Should not send any more packets.
+					clock.Advance(time.Hour)
+					if p, ok := e.Read(); ok {
+						t.Fatalf("sent unexpected packet = %#v", p)
+					}
+				})
+			}
+		})
+	}
+}
+
+// TestMGPReportMessages tests that no further reports or leave/done messages
+// are sent after receiving a report.
+//
+// After joining a group and sending the first report, a report for the same
+// group is injected; the stack is then expected to stay silent, both while
+// remaining in the group and when leaving it.
+func TestMGPReportMessages(t *testing.T) {
+	tests := []struct {
+		name               string
+		protoNum           tcpip.NetworkProtocolNumber
+		multicastAddr      tcpip.Address
+		sentReportStat     func(*stack.Stack) *tcpip.StatCounter
+		sentLeaveStat      func(*stack.Stack) *tcpip.StatCounter
+		rxReport           func(*channel.Endpoint)
+		validateReport     func(*testing.T, channel.PacketInfo)
+		checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+	}{
+		{
+			name:          "IGMP",
+			protoNum:      ipv4.ProtocolNumber,
+			multicastAddr: ipv4MulticastAddr1,
+			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().IGMP.PacketsSent.V2MembershipReport
+			},
+			sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().IGMP.PacketsSent.LeaveGroup
+			},
+			rxReport: func(e *channel.Endpoint) {
+				createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+			},
+			validateReport: func(t *testing.T, p channel.PacketInfo) {
+				t.Helper()
+
+				validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+			},
+		},
+		{
+			name:          "MLD",
+			protoNum:      ipv6.ProtocolNumber,
+			multicastAddr: ipv6MulticastAddr1,
+			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+			},
+			sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+			},
+			rxReport: func(e *channel.Endpoint) {
+				createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1)
+			},
+			validateReport: func(t *testing.T, p channel.PacketInfo) {
+				t.Helper()
+
+				validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+			},
+			checkInitialGroups: checkInitialIPv6Groups,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+			// Start counting from whatever the initial group handling already sent.
+			var reportCounter uint64
+			var leaveCounter uint64
+			if test.checkInitialGroups != nil {
+				reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
+			}
+
+			if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+				t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+			}
+			sentReportStat := test.sentReportStat(s)
+			reportCounter++
+			if got := sentReportStat.Value(); got != reportCounter {
+				t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+			}
+			if p, ok := e.Read(); !ok {
+				t.Fatal("expected a report message to be sent")
+			} else {
+				test.validateReport(t, p)
+			}
+			if t.Failed() {
+				t.FailNow()
+			}
+
+			// Receiving a report for a group we joined should cancel any further
+			// reports.
+			test.rxReport(e)
+			clock.Advance(time.Hour)
+			if got := sentReportStat.Value(); got != reportCounter {
+				t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+			}
+			if p, ok := e.Read(); ok {
+				t.Errorf("sent unexpected packet = %#v", p)
+			}
+			if t.Failed() {
+				t.FailNow()
+			}
+
+			// Leaving a group after getting a report should not send a leave/done
+			// message.
+			if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+				t.Fatalf("LeaveGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+			}
+			clock.Advance(time.Hour)
+			if got := test.sentLeaveStat(s).Value(); got != leaveCounter {
+				t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter)
+			}
+
+			// Should not send any more packets.
+			clock.Advance(time.Hour)
+			if p, ok := e.Read(); ok {
+				t.Fatalf("sent unexpected packet = %#v", p)
+			}
+		})
+	}
+}
+
+func TestMGPWithNICLifecycle(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddrs []tcpip.Address
+ finalMulticastAddr tcpip.Address
+ maxUnsolicitedResponseDelay time.Duration
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
+ validateReport func(*testing.T, channel.PacketInfo, tcpip.Address)
+ validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address)
+ getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddrs: []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2},
+ finalMulticastAddr: ipv4MulticastAddr3,
+ maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.LeaveGroup
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, addr)
+ },
+ getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
+ t.Helper()
+
+ ipv4 := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ if got := tcpip.TransportProtocolNumber(ipv4.Protocol()); got != header.IGMPProtocolNumber {
+ t.Fatalf("got ipv4.Protocol() = %d, want = %d", got, header.IGMPProtocolNumber)
+ }
+ addr := header.IGMP(ipv4.Payload()).GroupAddress()
+ s, ok := seen[addr]
+ if !ok {
+ t.Fatalf("unexpectedly got a packet for group %s", addr)
+ }
+ if s {
+ t.Fatalf("already saw packet for group %s", addr)
+ }
+ seen[addr] = true
+ return addr
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddrs: []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2},
+ finalMulticastAddr: ipv6MulticastAddr3,
+ maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateMLDPacket(t, p, addr, mldReport, 0, addr)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr)
+ },
+ getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
+ t.Helper()
+
+ ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
+
+ ipv6HeaderIter := header.MakeIPv6PayloadIterator(
+ header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()),
+ buffer.View(ipv6.Payload()).ToVectorisedView(),
+ )
+
+ var transport header.IPv6RawPayloadHeader
+ for {
+ h, done, err := ipv6HeaderIter.Next()
+ if err != nil {
+ t.Fatalf("ipv6HeaderIter.Next(): %s", err)
+ }
+ if done {
+ t.Fatalf("ipv6HeaderIter.Next() = (%T, %t, _), want = (_, false, _)", h, done)
+ }
+ if t, ok := h.(header.IPv6RawPayloadHeader); ok {
+ transport = t
+ break
+ }
+ }
+
+ if got := tcpip.TransportProtocolNumber(transport.Identifier); got != header.ICMPv6ProtocolNumber {
+ t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber)
+ }
+ icmpv6 := header.ICMPv6(transport.Buf.ToView())
+ if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone {
+ t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone)
+ }
+ addr := header.MLD(icmpv6.MessageBody()).MulticastAddress()
+ s, ok := seen[addr]
+ if !ok {
+ t.Fatalf("unexpectedly got a packet for group %s", addr)
+ }
+ if s {
+ t.Fatalf("already saw packet for group %s", addr)
+ }
+ seen[addr] = true
+ return addr
+ },
+ checkInitialGroups: checkInitialIPv6Groups,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ var leaveCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
+ }
+
+ sentReportStat := test.sentReportStat(s)
+ for _, a := range test.multicastAddrs {
+ if err := s.JoinGroup(test.protoNum, nicID, a); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err)
+ }
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatalf("expected a report message to be sent for %s", a)
+ } else {
+ test.validateReport(t, p, a)
+ }
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leave messages should be sent for the joined groups when the NIC is
+ // disabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ sentLeaveStat := test.sentLeaveStat(s)
+ leaveCounter += uint64(len(test.multicastAddrs))
+ if got := sentLeaveStat.Value(); got != leaveCounter {
+ t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
+ }
+ {
+ seen := make(map[tcpip.Address]bool)
+ for _, a := range test.multicastAddrs {
+ seen[a] = false
+ }
+
+ for i, _ := range test.multicastAddrs {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected (%d-th) leave message to be sent", i)
+ }
+
+ test.validateLeave(t, p, test.getAndCheckGroupAddress(t, seen, p))
+ }
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Reports should be sent for the joined groups when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("EnableNIC(%d): %s", nicID, err)
+ }
+ reportCounter += uint64(len(test.multicastAddrs))
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ {
+ seen := make(map[tcpip.Address]bool)
+ for _, a := range test.multicastAddrs {
+ seen[a] = false
+ }
+
+ for i, _ := range test.multicastAddrs {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected (%d-th) report message to be sent", i)
+ }
+
+ test.validateReport(t, p, test.getAndCheckGroupAddress(t, seen, p))
+ }
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Joining/leaving a group while disabled should not send any messages.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ leaveCounter += uint64(len(test.multicastAddrs))
+ if got := sentLeaveStat.Value(); got != leaveCounter {
+ t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
+ }
+ for i, _ := range test.multicastAddrs {
+ if _, ok := e.Read(); !ok {
+ t.Fatalf("expected (%d-th) leave message to be sent", i)
+ }
+ }
+ for _, a := range test.multicastAddrs {
+ if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil {
+ t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err)
+ }
+ if got := sentLeaveStat.Value(); got != leaveCounter {
+ t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p.Pkt)
+ }
+ }
+ if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err)
+ }
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p.Pkt)
+ }
+
+ // A report should only be sent for the group we last joined after
+ // enabling the NIC since the original groups were all left.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("EnableNIC(%d): %s", nicID, err)
+ }
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p, test.finalMulticastAddr)
+ }
+
+ clock.Advance(test.maxUnsolicitedResponseDelay)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p, test.finalMulticastAddr)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
+
+// TestMGPDisabledOnLoopback tests that the multicast group protocol is not
+// performed on loopback interfaces since they have no neighbours.
+func TestMGPDisabledOnLoopback(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New())
+
+ sentReportStat := test.sentReportStat(s)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+
+ // Test joining a specific group explicitly and verify that no reports are
+ // sent.
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 51d428049..4777163cd 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -44,6 +44,7 @@ import (
"bufio"
"fmt"
"log"
+ "math"
"math/rand"
"net"
"os"
@@ -200,7 +201,7 @@ func main() {
// connection from its side.
wq.EventRegister(&waitEntry, waiter.EventIn)
for {
- v, _, err := ep.Read(nil)
+ _, err := ep.Read(os.Stdout, math.MaxUint16, tcpip.ReadOptions{})
if err != nil {
if err == tcpip.ErrClosedForReceive {
break
@@ -213,8 +214,6 @@ func main() {
log.Fatal("Read() failed:", err)
}
-
- os.Stdout.Write(v)
}
wq.EventUnregister(&waitEntry)
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index 1c2afd554..a80fa0474 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -20,8 +20,10 @@
package main
import (
+ "bytes"
"flag"
"log"
+ "math"
"math/rand"
"net"
"os"
@@ -54,7 +56,8 @@ func echo(wq *waiter.Queue, ep tcpip.Endpoint) {
defer wq.EventUnregister(&waitEntry)
for {
- v, _, err := ep.Read(nil)
+ var buf bytes.Buffer
+ _, err := ep.Read(&buf, math.MaxUint16, tcpip.ReadOptions{})
if err != nil {
if err == tcpip.ErrWouldBlock {
<-notifyCh
@@ -64,7 +67,7 @@ func echo(wq *waiter.Queue, ep tcpip.Endpoint) {
return
}
- ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
+ ep.Write(tcpip.SlicePayload(buf.Bytes()), tcpip.WriteOptions{})
}
}
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index 1b1188ee5..f3ad40fdf 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -16,13 +16,13 @@ package tcpip
import (
"sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// SocketOptionsHandler holds methods that help define endpoint specific
-// behavior for socket options.
-// These must be implemented by endpoints to:
-// - Get notified when socket level options are set.
-// - Provide endpoint specific socket options.
+// behavior for socket level socket options. These must be implemented by
+// endpoints to get notified when socket level options are set.
type SocketOptionsHandler interface {
// OnReuseAddressSet is invoked when SO_REUSEADDR is set for an endpoint.
OnReuseAddressSet(v bool)
@@ -33,9 +33,21 @@ type SocketOptionsHandler interface {
// OnKeepAliveSet is invoked when SO_KEEPALIVE is set for an endpoint.
OnKeepAliveSet(v bool)
- // IsListening is invoked to fetch SO_ACCEPTCONN option value for an
- // endpoint. It is used to indicate if the socket is a listening socket.
- IsListening() bool
+ // OnDelayOptionSet is invoked when TCP_NODELAY is set for an endpoint.
+ // Note that v will be the inverse of TCP_NODELAY option.
+ OnDelayOptionSet(v bool)
+
+ // OnCorkOptionSet is invoked when TCP_CORK is set for an endpoint.
+ OnCorkOptionSet(v bool)
+
+ // LastError is invoked when SO_ERROR is read for an endpoint.
+ LastError() *Error
+
+ // UpdateLastError updates the endpoint specific last error field.
+ UpdateLastError(err *Error)
+
+ // HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE.
+ HasNIC(v int32) bool
}
// DefaultSocketOptionsHandler is an embeddable type that implements no-op
@@ -53,11 +65,27 @@ func (*DefaultSocketOptionsHandler) OnReusePortSet(bool) {}
// OnKeepAliveSet implements SocketOptionsHandler.OnKeepAliveSet.
func (*DefaultSocketOptionsHandler) OnKeepAliveSet(bool) {}
-// IsListening implements SocketOptionsHandler.IsListening.
-func (*DefaultSocketOptionsHandler) IsListening() bool { return false }
+// OnDelayOptionSet implements SocketOptionsHandler.OnDelayOptionSet.
+func (*DefaultSocketOptionsHandler) OnDelayOptionSet(bool) {}
+
+// OnCorkOptionSet implements SocketOptionsHandler.OnCorkOptionSet.
+func (*DefaultSocketOptionsHandler) OnCorkOptionSet(bool) {}
+
+// LastError implements SocketOptionsHandler.LastError.
+func (*DefaultSocketOptionsHandler) LastError() *Error {
+ return nil
+}
+
+// UpdateLastError implements SocketOptionsHandler.UpdateLastError.
+func (*DefaultSocketOptionsHandler) UpdateLastError(*Error) {}
+
+// HasNIC implements SocketOptionsHandler.HasNIC.
+func (*DefaultSocketOptionsHandler) HasNIC(int32) bool {
+ return false
+}
-// SocketOptions contains all the variables which store values for SOL_SOCKET
-// level options.
+// SocketOptions contains all the variables which store values for SOL_SOCKET,
+// SOL_IP, SOL_IPV6 and SOL_TCP level options.
//
// +stateify savable
type SocketOptions struct {
@@ -65,29 +93,89 @@ type SocketOptions struct {
// These fields are accessed and modified using atomic operations.
- // broadcastEnabled determines whether datagram sockets are allowed to send
- // packets to a broadcast address.
+ // broadcastEnabled determines whether datagram sockets are allowed to
+ // send packets to a broadcast address.
broadcastEnabled uint32
- // passCredEnabled determines whether SCM_CREDENTIALS socket control messages
- // are enabled.
+ // passCredEnabled determines whether SCM_CREDENTIALS socket control
+ // messages are enabled.
passCredEnabled uint32
// noChecksumEnabled determines whether UDP checksum is disabled while
// transmitting for this socket.
noChecksumEnabled uint32
- // reuseAddressEnabled determines whether Bind() should allow reuse of local
- // address.
+ // reuseAddressEnabled determines whether Bind() should allow reuse of
+ // local address.
reuseAddressEnabled uint32
- // reusePortEnabled determines whether to permit multiple sockets to be bound
- // to an identical socket address.
+ // reusePortEnabled determines whether to permit multiple sockets to be
+ // bound to an identical socket address.
reusePortEnabled uint32
// keepAliveEnabled determines whether TCP keepalive is enabled for this
// socket.
keepAliveEnabled uint32
+
+ // multicastLoopEnabled determines whether multicast packets sent over a
+ // non-loopback interface will be looped back.
+ multicastLoopEnabled uint32
+
+ // receiveTOSEnabled is used to specify if the TOS ancillary message is
+ // passed with incoming packets.
+ receiveTOSEnabled uint32
+
+ // receiveTClassEnabled is used to specify if the IPV6_TCLASS ancillary
+ // message is passed with incoming packets.
+ receiveTClassEnabled uint32
+
+ // receivePacketInfoEnabled is used to specify if more information is
+ // provided with incoming packets such as interface index and address.
+ receivePacketInfoEnabled uint32
+
+ // hdrIncludedEnabled is used to indicate for a raw endpoint that all packets
+ // being written have an IP header and the endpoint should not attach an IP
+ // header.
+ hdrIncludedEnabled uint32
+
+ // v6OnlyEnabled is used to determine whether an IPv6 socket is to be
+ // restricted to sending and receiving IPv6 packets only.
+ v6OnlyEnabled uint32
+
+ // quickAckEnabled is used to represent the value of TCP_QUICKACK option.
+ // It currently does not have any effect on the TCP endpoint.
+ quickAckEnabled uint32
+
+ // delayOptionEnabled is used to specify if data should be delayed rather than
+ // sent out immediately by the transport protocol (the inverse of TCP_NODELAY).
+ // For TCP, it determines if the Nagle algorithm is on or off.
+ delayOptionEnabled uint32
+
+ // corkOptionEnabled is used to specify if data should be held until segments
+ // are full by the TCP transport protocol.
+ corkOptionEnabled uint32
+
+ // receiveOriginalDstAddress is used to specify if the original destination of
+ // the incoming packet should be returned as an ancillary message.
+ receiveOriginalDstAddress uint32
+
+ // recvErrEnabled determines whether extended reliable error message passing
+ // is enabled.
+ recvErrEnabled uint32
+
+ // errQueue is the per-socket error queue. It is protected by errQueueMu.
+ errQueueMu sync.Mutex `state:"nosave"`
+ errQueue sockErrorList
+
+ // bindToDevice determines the device to which the socket is bound.
+ bindToDevice int32
+
+ // mu protects the access to the below fields.
+ mu sync.Mutex `state:"nosave"`
+
+ // linger determines the amount of time the socket should linger before
+ // close. We currently implement this option for TCP socket only.
+ linger LingerOption
}
// InitHandler initializes the handler. This must be called before using the
@@ -104,6 +192,11 @@ func storeAtomicBool(addr *uint32, v bool) {
atomic.StoreUint32(addr, val)
}
+// SetLastError sets the last error for a socket.
+func (so *SocketOptions) SetLastError(err *Error) {
+ so.handler.UpdateLastError(err)
+}
+
// GetBroadcast gets value for SO_BROADCAST option.
func (so *SocketOptions) GetBroadcast() bool {
return atomic.LoadUint32(&so.broadcastEnabled) != 0
@@ -167,8 +260,261 @@ func (so *SocketOptions) SetKeepAlive(v bool) {
so.handler.OnKeepAliveSet(v)
}
-// GetAcceptConn gets value for SO_ACCEPTCONN option.
-func (so *SocketOptions) GetAcceptConn() bool {
- // This option is completely endpoint dependent and unsettable.
- return so.handler.IsListening()
+// GetMulticastLoop gets value for IP_MULTICAST_LOOP option.
+func (so *SocketOptions) GetMulticastLoop() bool {
+ return atomic.LoadUint32(&so.multicastLoopEnabled) != 0
+}
+
+// SetMulticastLoop sets value for IP_MULTICAST_LOOP option.
+func (so *SocketOptions) SetMulticastLoop(v bool) {
+ storeAtomicBool(&so.multicastLoopEnabled, v)
+}
+
+// GetReceiveTOS gets value for IP_RECVTOS option.
+func (so *SocketOptions) GetReceiveTOS() bool {
+ return atomic.LoadUint32(&so.receiveTOSEnabled) != 0
+}
+
+// SetReceiveTOS sets value for IP_RECVTOS option.
+func (so *SocketOptions) SetReceiveTOS(v bool) {
+ storeAtomicBool(&so.receiveTOSEnabled, v)
+}
+
+// GetReceiveTClass gets value for IPV6_RECVTCLASS option.
+func (so *SocketOptions) GetReceiveTClass() bool {
+ return atomic.LoadUint32(&so.receiveTClassEnabled) != 0
+}
+
+// SetReceiveTClass sets value for IPV6_RECVTCLASS option.
+func (so *SocketOptions) SetReceiveTClass(v bool) {
+ storeAtomicBool(&so.receiveTClassEnabled, v)
+}
+
+// GetReceivePacketInfo gets value for IP_PKTINFO option.
+func (so *SocketOptions) GetReceivePacketInfo() bool {
+ return atomic.LoadUint32(&so.receivePacketInfoEnabled) != 0
+}
+
+// SetReceivePacketInfo sets value for IP_PKTINFO option.
+func (so *SocketOptions) SetReceivePacketInfo(v bool) {
+ storeAtomicBool(&so.receivePacketInfoEnabled, v)
+}
+
+// GetHeaderIncluded gets value for IP_HDRINCL option.
+func (so *SocketOptions) GetHeaderIncluded() bool {
+ return atomic.LoadUint32(&so.hdrIncludedEnabled) != 0
+}
+
+// SetHeaderIncluded sets value for IP_HDRINCL option.
+func (so *SocketOptions) SetHeaderIncluded(v bool) {
+ storeAtomicBool(&so.hdrIncludedEnabled, v)
+}
+
+// GetV6Only gets value for IPV6_V6ONLY option.
+func (so *SocketOptions) GetV6Only() bool {
+ return atomic.LoadUint32(&so.v6OnlyEnabled) != 0
+}
+
+// SetV6Only sets value for IPV6_V6ONLY option.
+//
+// Preconditions: the backing TCP or UDP endpoint must be in initial state.
+func (so *SocketOptions) SetV6Only(v bool) {
+ storeAtomicBool(&so.v6OnlyEnabled, v)
+}
+
+// GetQuickAck gets value for TCP_QUICKACK option.
+func (so *SocketOptions) GetQuickAck() bool {
+ return atomic.LoadUint32(&so.quickAckEnabled) != 0
+}
+
+// SetQuickAck sets value for TCP_QUICKACK option.
+func (so *SocketOptions) SetQuickAck(v bool) {
+ storeAtomicBool(&so.quickAckEnabled, v)
+}
+
+// GetDelayOption gets inverted value for TCP_NODELAY option.
+func (so *SocketOptions) GetDelayOption() bool {
+ return atomic.LoadUint32(&so.delayOptionEnabled) != 0
+}
+
+// SetDelayOption sets inverted value for TCP_NODELAY option.
+func (so *SocketOptions) SetDelayOption(v bool) {
+ storeAtomicBool(&so.delayOptionEnabled, v)
+ so.handler.OnDelayOptionSet(v)
+}
+
+// GetCorkOption gets value for TCP_CORK option.
+func (so *SocketOptions) GetCorkOption() bool {
+ return atomic.LoadUint32(&so.corkOptionEnabled) != 0
+}
+
+// SetCorkOption sets value for TCP_CORK option.
+func (so *SocketOptions) SetCorkOption(v bool) {
+ storeAtomicBool(&so.corkOptionEnabled, v)
+ so.handler.OnCorkOptionSet(v)
+}
+
+// GetReceiveOriginalDstAddress gets value for IP(V6)_RECVORIGDSTADDR option.
+func (so *SocketOptions) GetReceiveOriginalDstAddress() bool {
+ return atomic.LoadUint32(&so.receiveOriginalDstAddress) != 0
+}
+
+// SetReceiveOriginalDstAddress sets value for IP(V6)_RECVORIGDSTADDR option.
+func (so *SocketOptions) SetReceiveOriginalDstAddress(v bool) {
+ storeAtomicBool(&so.receiveOriginalDstAddress, v)
+}
+
+// GetRecvError gets value for IP*_RECVERR option.
+func (so *SocketOptions) GetRecvError() bool {
+ return atomic.LoadUint32(&so.recvErrEnabled) != 0
+}
+
+// SetRecvError sets value for IP*_RECVERR option.
+func (so *SocketOptions) SetRecvError(v bool) {
+ storeAtomicBool(&so.recvErrEnabled, v)
+ if !v {
+ so.pruneErrQueue()
+ }
+}
+
+// GetLastError gets value for SO_ERROR option.
+func (so *SocketOptions) GetLastError() *Error {
+ return so.handler.LastError()
+}
+
+// GetOutOfBandInline gets value for SO_OOBINLINE option.
+func (*SocketOptions) GetOutOfBandInline() bool {
+ return true
+}
+
+// SetOutOfBandInline sets value for SO_OOBINLINE option. We currently do not
+// support disabling this option.
+func (*SocketOptions) SetOutOfBandInline(bool) {}
+
+// GetLinger gets value for SO_LINGER option.
+func (so *SocketOptions) GetLinger() LingerOption {
+ so.mu.Lock()
+ linger := so.linger
+ so.mu.Unlock()
+ return linger
+}
+
+// SetLinger sets value for SO_LINGER option.
+func (so *SocketOptions) SetLinger(linger LingerOption) {
+ so.mu.Lock()
+ so.linger = linger
+ so.mu.Unlock()
+}
+
+// SockErrOrigin represents the constants for error origin.
+type SockErrOrigin uint8
+
+const (
+ // SockExtErrorOriginNone represents an unknown error origin.
+ SockExtErrorOriginNone SockErrOrigin = iota
+
+ // SockExtErrorOriginLocal indicates a local error.
+ SockExtErrorOriginLocal
+
+ // SockExtErrorOriginICMP indicates an IPv4 ICMP error.
+ SockExtErrorOriginICMP
+
+ // SockExtErrorOriginICMP6 indicates an IPv6 ICMP error.
+ SockExtErrorOriginICMP6
+)
+
+// IsICMPErr indicates if the error originated from an ICMP error.
+func (origin SockErrOrigin) IsICMPErr() bool {
+ return origin == SockExtErrorOriginICMP || origin == SockExtErrorOriginICMP6
+}
+
+// SockError represents a queue entry in the per-socket error queue.
+//
+// +stateify savable
+type SockError struct {
+ sockErrorEntry
+
+ // Err is the error caused by the errant packet.
+ Err *Error
+ // ErrOrigin indicates the error origin.
+ ErrOrigin SockErrOrigin
+ // ErrType is the type in the ICMP header.
+ ErrType uint8
+ // ErrCode is the code in the ICMP header.
+ ErrCode uint8
+ // ErrInfo is additional info about the error.
+ ErrInfo uint32
+
+ // Payload is the errant packet's payload.
+ Payload []byte
+ // Dst is the original destination address of the errant packet.
+ Dst FullAddress
+ // Offender is the original sender address of the errant packet.
+ Offender FullAddress
+ // NetProto is the network protocol being used to transmit the packet.
+ NetProto NetworkProtocolNumber
+}
+
+// pruneErrQueue resets the queue.
+func (so *SocketOptions) pruneErrQueue() {
+ so.errQueueMu.Lock()
+ so.errQueue.Reset()
+ so.errQueueMu.Unlock()
+}
+
+// DequeueErr dequeues a socket extended error from the error queue and returns
+// it. Returns nil if queue is empty.
+func (so *SocketOptions) DequeueErr() *SockError {
+ so.errQueueMu.Lock()
+ defer so.errQueueMu.Unlock()
+
+ err := so.errQueue.Front()
+ if err != nil {
+ so.errQueue.Remove(err)
+ }
+ return err
+}
+
+// PeekErr returns the error in the front of the error queue. Returns nil if
+// the error queue is empty.
+func (so *SocketOptions) PeekErr() *SockError {
+ so.errQueueMu.Lock()
+ defer so.errQueueMu.Unlock()
+ return so.errQueue.Front()
+}
+
+// QueueErr inserts the error at the back of the error queue.
+//
+// Preconditions: so.GetRecvError() == true.
+func (so *SocketOptions) QueueErr(err *SockError) {
+ so.errQueueMu.Lock()
+ defer so.errQueueMu.Unlock()
+ so.errQueue.PushBack(err)
+}
+
+// QueueLocalErr queues a local error onto the socket's error queue.
+func (so *SocketOptions) QueueLocalErr(err *Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) {
+ so.QueueErr(&SockError{
+ Err: err,
+ ErrOrigin: SockExtErrorOriginLocal,
+ ErrInfo: info,
+ Payload: payload,
+ Dst: dst,
+ NetProto: net,
+ })
+}
+
+// GetBindToDevice gets value for SO_BINDTODEVICE option.
+func (so *SocketOptions) GetBindToDevice() int32 {
+ return atomic.LoadInt32(&so.bindToDevice)
+}
+
+// SetBindToDevice sets value for SO_BINDTODEVICE option.
+func (so *SocketOptions) SetBindToDevice(bindToDevice int32) *Error {
+ if !so.handler.HasNIC(bindToDevice) {
+ return ErrUnknownDevice
+ }
+
+ atomic.StoreInt32(&so.bindToDevice, bindToDevice)
+ return nil
}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index d09ebe7fa..bb30556cf 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test", "most_shards")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -112,7 +112,7 @@ go_test(
"transport_demuxer_test.go",
"transport_test.go",
],
- shard_count = 20,
+ shard_count = most_shards,
deps = [
":stack",
"//pkg/rand",
@@ -120,6 +120,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
@@ -131,7 +132,6 @@ go_test(
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
- "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
@@ -148,7 +148,6 @@ go_test(
],
library = ":stack",
deps = [
- "//pkg/sleep",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index adeebfe37..cd423bf71 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil)
var _ AddressableEndpoint = (*AddressableEndpointState)(nil)
// AddressableEndpointState is an implementation of an AddressableEndpoint.
@@ -37,10 +36,6 @@ type AddressableEndpointState struct {
endpoints map[tcpip.Address]*addressState
primary []*addressState
-
- // groups holds the mapping between group addresses and the number of times
- // they have been joined.
- groups map[tcpip.Address]uint32
}
}
@@ -53,65 +48,33 @@ func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) {
a.mu.Lock()
defer a.mu.Unlock()
a.mu.endpoints = make(map[tcpip.Address]*addressState)
- a.mu.groups = make(map[tcpip.Address]uint32)
-}
-
-// ReadOnlyAddressableEndpointState provides read-only access to an
-// AddressableEndpointState.
-type ReadOnlyAddressableEndpointState struct {
- inner *AddressableEndpointState
}
-// AddrOrMatching returns an endpoint for the passed address that is consisdered
-// bound to the wrapped AddressableEndpointState.
-//
-// If addr is an exact match with an existing address, that address is returned.
-// Otherwise, f is called with each address and the address that f returns true
-// for is returned.
+// GetAddress returns the AddressEndpoint for the passed address.
//
-// Returns nil of no address matches.
-func (m ReadOnlyAddressableEndpointState) AddrOrMatching(addr tcpip.Address, spoofingOrPrimiscuous bool, f func(AddressEndpoint) bool) AddressEndpoint {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
-
- if ep, ok := m.inner.mu.endpoints[addr]; ok {
- if ep.IsAssigned(spoofingOrPrimiscuous) && ep.IncRef() {
- return ep
- }
- }
-
- for _, ep := range m.inner.mu.endpoints {
- if ep.IsAssigned(spoofingOrPrimiscuous) && f(ep) && ep.IncRef() {
- return ep
- }
- }
-
- return nil
-}
-
-// Lookup returns the AddressEndpoint for the passed address.
+// GetAddress does not increment the address's reference count or check if the
+// address is considered bound to the endpoint.
//
-// Returns nil if the passed address is not associated with the
-// AddressableEndpointState.
-func (m ReadOnlyAddressableEndpointState) Lookup(addr tcpip.Address) AddressEndpoint {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
+// Returns nil if the passed address is not associated with the endpoint.
+func (a *AddressableEndpointState) GetAddress(addr tcpip.Address) AddressEndpoint {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
- ep, ok := m.inner.mu.endpoints[addr]
+ ep, ok := a.mu.endpoints[addr]
if !ok {
return nil
}
return ep
}
-// ForEach calls f for each address pair.
+// ForEachEndpoint calls f for each address.
//
-// If f returns false, f is no longer be called.
-func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
+// Once f returns false, f will no longer be called.
+func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool) {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
- for _, ep := range m.inner.mu.endpoints {
+ for _, ep := range a.mu.endpoints {
if !f(ep) {
return
}
@@ -120,18 +83,16 @@ func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool)
// ForEachPrimaryEndpoint calls f for each primary address.
//
-// If f returns false, f is no longer be called.
-func (m ReadOnlyAddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
- for _, ep := range m.inner.mu.primary {
- f(ep)
- }
-}
+// Once f returns false, f will no longer be called.
+func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint) bool) {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
-// ReadOnly returns a readonly reference to a.
-func (a *AddressableEndpointState) ReadOnly() ReadOnlyAddressableEndpointState {
- return ReadOnlyAddressableEndpointState{inner: a}
+ for _, ep := range a.mu.primary {
+ if !f(ep) {
+ return
+ }
+ }
}
func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) {
@@ -335,11 +296,6 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
a.mu.Lock()
defer a.mu.Unlock()
-
- if _, ok := a.mu.groups[addr]; ok {
- panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr))
- }
-
return a.removePermanentAddressLocked(addr)
}
@@ -471,8 +427,19 @@ func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*ad
return deprecatedEndpoint
}
-// AcquireAssignedAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
+// AcquireAssignedAddressOrMatching returns an address endpoint that is
+// considered assigned to the addressable endpoint.
+//
+// If the address is an exact match with an existing address, that address is
+// returned. Otherwise, if f is provided, f is called with each address and
+// the address that f returns true for is returned.
+//
+// If there is no matching address, a temporary address will be returned if
+// allowTemp is true.
+//
+// Regardless of how the address was obtained, it will be acquired before it
+// is returned.
+func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
a.mu.Lock()
defer a.mu.Unlock()
@@ -488,6 +455,14 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres
return addrState
}
+ if f != nil {
+ for _, addrState := range a.mu.endpoints {
+ if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() {
+ return addrState
+ }
+ }
+ }
+
if !allowTemp {
return nil
}
@@ -520,6 +495,11 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres
return ep
}
+// AcquireAssignedAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
+ return a.AcquireAssignedAddressOrMatching(localAddr, nil, allowTemp, tempPEB)
+}
+
// AcquireOutgoingPrimaryAddress implements AddressableEndpoint.
func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint {
a.mu.RLock()
@@ -588,50 +568,11 @@ func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefi
return addrs
}
-// JoinGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) {
- a.mu.Lock()
- defer a.mu.Unlock()
-
- joins, ok := a.mu.groups[group]
- a.mu.groups[group] = joins + 1
- return !ok, nil
-}
-
-// LeaveGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) {
- a.mu.Lock()
- defer a.mu.Unlock()
-
- joins, ok := a.mu.groups[group]
- if !ok {
- return false, tcpip.ErrBadLocalAddress
- }
-
- if joins == 1 {
- delete(a.mu.groups, group)
- return true, nil
- }
-
- a.mu.groups[group] = joins - 1
- return false, nil
-}
-
-// IsInGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool {
- a.mu.RLock()
- defer a.mu.RUnlock()
- _, ok := a.mu.groups[group]
- return ok
-}
-
// Cleanup forcefully leaves all groups and removes all permanent addresses.
func (a *AddressableEndpointState) Cleanup() {
a.mu.Lock()
defer a.mu.Unlock()
- a.mu.groups = make(map[tcpip.Address]uint32)
-
for _, ep := range a.mu.endpoints {
// removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is
// not a permanent address.
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
index 26787d0a3..140f146f6 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -53,25 +53,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
ep.DecRef()
}
- group := tcpip.Address("\x02")
- if added, err := s.JoinGroup(group); err != nil {
- t.Fatalf("s.JoinGroup(%s): %s", group, err)
- } else if !added {
- t.Fatalf("got s.JoinGroup(%s) = false, want = true", group)
- }
- if !s.IsInGroup(group) {
- t.Fatalf("got s.IsInGroup(%s) = false, want = true", group)
- }
-
s.Cleanup()
- {
- ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
- if ep != nil {
- ep.DecRef()
- t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
- }
- }
- if s.IsInGroup(group) {
- t.Fatalf("got s.IsInGroup(%s) = true, want = false", group)
+ if ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint); ep != nil {
+ ep.DecRef()
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
}
}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index cb7dec1ea..93e8e1c51 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -309,7 +309,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
p := fwdTestPacketInfo{
- RemoteLinkAddress: r.RemoteLinkAddress,
+ RemoteLinkAddress: r.RemoteLinkAddress(),
LocalLinkAddress: r.LocalLinkAddress,
Pkt: pkt,
}
@@ -560,6 +560,38 @@ func TestForwardingWithNoResolver(t *testing.T) {
}
}
+func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) {
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 50 * time.Millisecond,
+ onLinkAddressResolved: func(*linkAddrCache, *neighborCache, tcpip.Address, tcpip.LinkAddress) {
+ // Don't resolve the link address.
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto, true /* useNeighborCache */)
+
+ const numPackets int = 5
+ // These packets will all be enqueued in the packet queue to wait for link
+ // address resolution.
+ for i := 0; i < numPackets; i++ {
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+ }
+
+ // All packets should fail resolution.
+ // TODO(gvisor.dev/issue/5141): Use a fake clock.
+ for i := 0; i < numPackets; i++ {
+ select {
+ case got := <-ep2.C:
+ t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got)
+ case <-time.After(100 * time.Millisecond):
+ }
+ }
+}
+
func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
tests := []struct {
name string
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index c9b13cd0e..792f4f170 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -18,7 +18,6 @@ import (
"fmt"
"time"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -58,9 +57,6 @@ const (
incomplete entryState = iota
// ready means that the address has been resolved and can be used.
ready
- // failed means that address resolution timed out and the address
- // could not be resolved.
- failed
)
// String implements Stringer.
@@ -70,8 +66,6 @@ func (s entryState) String() string {
return "incomplete"
case ready:
return "ready"
- case failed:
- return "failed"
default:
return fmt.Sprintf("unknown(%d)", s)
}
@@ -80,40 +74,48 @@ func (s entryState) String() string {
// A linkAddrEntry is an entry in the linkAddrCache.
// This struct is thread-compatible.
type linkAddrEntry struct {
+ // linkAddrEntryEntry access is synchronized by the linkAddrCache lock.
linkAddrEntryEntry
+ // TODO(gvisor.dev/issue/5150): move these fields under mu.
+ // mu protects the fields below.
+ mu sync.RWMutex
+
addr tcpip.FullAddress
linkAddr tcpip.LinkAddress
expiration time.Time
s entryState
- // wakers is a set of waiters for address resolution result. Anytime
- // state transitions out of incomplete these waiters are notified.
- wakers map[*sleep.Waker]struct{}
-
- // done is used to allow callers to wait on address resolution. It is nil iff
- // s is incomplete and resolution is not yet in progress.
+ // done is closed when address resolution is complete. It is nil iff s is
+ // incomplete and resolution is not yet in progress.
done chan struct{}
+
+ // onResolve is called with the result of address resolution.
+ onResolve []func(tcpip.LinkAddress, bool)
}
-// changeState sets the entry's state to ns, notifying any waiters.
+func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) {
+ for _, callback := range e.onResolve {
+ callback(linkAddr, len(linkAddr) != 0)
+ }
+ e.onResolve = nil
+ if ch := e.done; ch != nil {
+ close(ch)
+ e.done = nil
+ }
+}
+
+// changeStateLocked sets the entry's state to ns.
//
// The entry's expiration is bumped up to the greater of itself and the passed
// expiration; the zero value indicates immediate expiration, and is set
// unconditionally - this is an implementation detail that allows for entries
// to be reused.
-func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) {
- // Notify whoever is waiting on address resolution when transitioning
- // out of incomplete.
- if e.s == incomplete && ns != incomplete {
- for w := range e.wakers {
- w.Assert()
- }
- e.wakers = nil
- if ch := e.done; ch != nil {
- close(ch)
- }
- e.done = nil
+//
+// Precondition: e.mu must be locked.
+func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) {
+ if e.s == incomplete && ns == ready {
+ e.notifyCompletionLocked(e.linkAddr)
}
if expiration.IsZero() || expiration.After(e.expiration) {
@@ -122,10 +124,6 @@ func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) {
e.s = ns
}
-func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
- delete(e.wakers, w)
-}
-
// add adds a k -> v mapping to the cache.
func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
// Calculate expiration time before acquiring the lock, since expiration is
@@ -135,10 +133,12 @@ func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
c.cache.Lock()
entry := c.getOrCreateEntryLocked(k)
- entry.linkAddr = v
-
- entry.changeState(ready, expiration)
c.cache.Unlock()
+
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+ entry.linkAddr = v
+ entry.changeStateLocked(ready, expiration)
}
// getOrCreateEntryLocked retrieves a cache entry associated with k. The
@@ -159,13 +159,14 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt
var entry *linkAddrEntry
if len(c.cache.table) == linkAddrCacheSize {
entry = c.cache.lru.Back()
+ entry.mu.Lock()
delete(c.cache.table, entry.addr)
c.cache.lru.Remove(entry)
- // Wake waiters and mark the soon-to-be-reused entry as expired. Note
- // that the state passed doesn't matter when the zero time is passed.
- entry.changeState(failed, time.Time{})
+ // Wake waiters and mark the soon-to-be-reused entry as expired.
+ entry.notifyCompletionLocked("" /* linkAddr */)
+ entry.mu.Unlock()
} else {
entry = new(linkAddrEntry)
}
@@ -180,9 +181,12 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt
}
// get reports any known link address for k.
-func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
if linkRes != nil {
if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok {
+ if onResolve != nil {
+ onResolve(addr, true)
+ }
return addr, nil, nil
}
}
@@ -190,56 +194,35 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo
c.cache.Lock()
defer c.cache.Unlock()
entry := c.getOrCreateEntryLocked(k)
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+
switch s := entry.s; s {
- case ready, failed:
+ case ready:
if !time.Now().After(entry.expiration) {
// Not expired.
- switch s {
- case ready:
- return entry.linkAddr, nil, nil
- case failed:
- return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
- default:
- panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ if onResolve != nil {
+ onResolve(entry.linkAddr, true)
}
+ return entry.linkAddr, nil, nil
}
- entry.changeState(incomplete, time.Time{})
+ entry.changeStateLocked(incomplete, time.Time{})
fallthrough
case incomplete:
- if waker != nil {
- if entry.wakers == nil {
- entry.wakers = make(map[*sleep.Waker]struct{})
- }
- entry.wakers[waker] = struct{}{}
+ if onResolve != nil {
+ entry.onResolve = append(entry.onResolve, onResolve)
}
-
if entry.done == nil {
- // Address resolution needs to be initiated.
- if linkRes == nil {
- return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
- }
-
entry.done = make(chan struct{})
go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
}
-
return entry.linkAddr, entry.done, tcpip.ErrWouldBlock
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
}
-// removeWaker removes a waker previously added through get().
-func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
- c.cache.Lock()
- defer c.cache.Unlock()
-
- if entry, ok := c.cache.table[k]; ok {
- entry.removeWaker(waker)
- }
-}
-
func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) {
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
@@ -257,9 +240,9 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
}
}
-// checkLinkRequest checks whether previous attempt to resolve address has succeeded
-// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request
-// can stop, false if another request should be sent.
+// checkLinkRequest checks whether the previous attempt to resolve the address
+// has succeeded and marks the entry accordingly. Returns true if the request
+// can stop, false if another request should be sent.
func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool {
c.cache.Lock()
defer c.cache.Unlock()
@@ -268,16 +251,20 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att
// Entry was evicted from the cache.
return true
}
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+
switch s := entry.s; s {
- case ready, failed:
- // Entry was made ready by resolver or failed. Either way we're done.
+ case ready:
+ // Entry was made ready by resolver.
case incomplete:
if attempt+1 < c.resolutionAttempts {
// No response yet, need to send another ARP request.
return false
}
- // Max number of retries reached, mark entry as failed.
- entry.changeState(failed, now.Add(c.ageLimit))
+ // Max number of retries reached, delete entry.
+ entry.notifyCompletionLocked("" /* linkAddr */)
+ delete(c.cache.table, k)
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index d2e37f38d..6883045b5 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -21,7 +21,6 @@ import (
"testing"
"time"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -50,6 +49,7 @@ type testLinkAddressResolver struct {
}
func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
+ // TODO(gvisor.dev/issue/5141): Use a fake clock.
time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) })
if f := r.onLinkAddressRequest; f != nil {
f()
@@ -78,16 +78,18 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe
}
func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) {
- w := sleep.Waker{}
- s := sleep.Sleeper{}
- s.AddWaker(&w, 123)
- defer s.Done()
-
+ var attemptedResolution bool
for {
- if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
- return got, err
+ got, ch, err := c.get(addr, linkRes, "", nil, nil)
+ if err == tcpip.ErrWouldBlock {
+ if attemptedResolution {
+ return got, tcpip.ErrNoLinkAddress
+ }
+ attemptedResolution = true
+ <-ch
+ continue
}
- s.Fetch(true)
+ return got, err
}
}
@@ -116,16 +118,19 @@ func TestCacheOverflow(t *testing.T) {
}
}
// The earliest entries should no longer be in the cache.
+ c.cache.Lock()
+ defer c.cache.Unlock()
for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- {
e := testAddrs[i]
- if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
- t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
+ if entry, ok := c.cache.table[e.addr]; ok {
+ t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry)
}
}
}
func TestCacheConcurrent(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+ linkRes := &testLinkAddressResolver{cache: c}
var wg sync.WaitGroup
for r := 0; r < 16; r++ {
@@ -133,7 +138,6 @@ func TestCacheConcurrent(t *testing.T) {
go func() {
for _, e := range testAddrs {
c.add(e.addr, e.linkAddr)
- c.get(e.addr, nil, "", nil, nil) // make work for gotsan
}
wg.Done()
}()
@@ -144,7 +148,7 @@ func TestCacheConcurrent(t *testing.T) {
// can fit in the cache, so our eviction strategy requires that
// the last entry be present and the first be missing.
e := testAddrs[len(testAddrs)-1]
- got, _, err := c.get(e.addr, nil, "", nil, nil)
+ got, _, err := c.get(e.addr, linkRes, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
}
@@ -153,18 +157,22 @@ func TestCacheConcurrent(t *testing.T) {
}
e = testAddrs[0]
- if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ c.cache.Lock()
+ defer c.cache.Unlock()
+ if entry, ok := c.cache.table[e.addr]; ok {
+ t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry)
}
}
func TestCacheAgeLimit(t *testing.T) {
c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
+ linkRes := &testLinkAddressResolver{cache: c}
+
e := testAddrs[0]
c.add(e.addr, e.linkAddr)
time.Sleep(50 * time.Millisecond)
- if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ if _, _, err := c.get(e.addr, linkRes, "", nil, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.get(%q) = %s, want = ErrWouldBlock", string(e.addr.Addr), err)
}
}
@@ -282,71 +290,3 @@ func TestStaticResolution(t *testing.T) {
t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want))
}
}
-
-// TestCacheWaker verifies that RemoveWaker removes a waker previously added
-// through get().
-func TestCacheWaker(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
-
- // First, sanity check that wakers are working.
- {
- linkRes := &testLinkAddressResolver{cache: c}
- s := sleep.Sleeper{}
- defer s.Done()
-
- const wakerID = 1
- w := sleep.Waker{}
- s.AddWaker(&w, wakerID)
-
- e := testAddrs[0]
-
- if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock)
- }
- id, ok := s.Fetch(true /* block */)
- if !ok {
- t.Fatal("got s.Fetch(true) = (_, false), want = (_, true)")
- }
- if id != wakerID {
- t.Fatalf("got s.Fetch(true) = (%d, %t), want = (%d, true)", id, ok, wakerID)
- }
-
- if got, _, err := c.get(e.addr, linkRes, "", nil, nil); err != nil {
- t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err)
- } else if got != e.linkAddr {
- t.Fatalf("got c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr)
- }
- }
-
- // Check that RemoveWaker works.
- {
- linkRes := &testLinkAddressResolver{cache: c}
- s := sleep.Sleeper{}
- defer s.Done()
-
- const wakerID = 2 // different than the ID used in the sanity check
- w := sleep.Waker{}
- s.AddWaker(&w, wakerID)
-
- e := testAddrs[1]
- linkRes.onLinkAddressRequest = func() {
- // Remove the waker before the linkAddrCache has the opportunity to send
- // a notification.
- c.removeWaker(e.addr, &w)
- }
-
- if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock)
- }
-
- if got, err := getBlocking(c, e.addr, linkRes); err != nil {
- t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err)
- } else if got != e.linkAddr {
- t.Fatalf("c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr)
- }
-
- if id, ok := s.Fetch(false /* block */); ok {
- t.Fatalf("unexpected notification from waker with id %d", id)
- }
- }
-}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 7b0cd58f6..61636cae5 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -352,7 +353,7 @@ func TestDADDisabled(t *testing.T) {
}
// We should not have sent any NDP NS messages.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != 0 {
t.Fatalf("got NeighborSolicit = %d, want = 0", got)
}
}
@@ -465,14 +466,18 @@ func TestDADResolve(t *testing.T) {
if err != tcpip.ErrNoRoute {
t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
{
r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
if err != tcpip.ErrNoRoute {
t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
if t.Failed() {
@@ -510,7 +515,9 @@ func TestDADResolve(t *testing.T) {
} else if r.LocalAddress != addr1 {
t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
if t.Failed() {
@@ -518,7 +525,7 @@ func TestDADResolve(t *testing.T) {
}
// Should not have sent any more NS messages.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits)
}
@@ -563,18 +570,18 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(tgt)
snmc := header.SolicitedNodeAddr(tgt)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: 255,
- SrcAddr: header.IPv6Any,
- DstAddr: snmc,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: 255,
+ SrcAddr: header.IPv6Any,
+ DstAddr: snmc,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
}
@@ -605,7 +612,7 @@ func TestDADFail(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
pkt := header.ICMPv6(hdr.Prepend(naSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(tgt)
@@ -616,11 +623,11 @@ func TestDADFail(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: 255,
- SrcAddr: tgt,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: 255,
+ SrcAddr: tgt,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
},
@@ -666,7 +673,7 @@ func TestDADFail(t *testing.T) {
// Receive a packet to simulate an address conflict.
test.rxPkt(e, addr1)
- stat := test.getStat(s.Stats().ICMP.V6PacketsReceived)
+ stat := test.getStat(s.Stats().ICMP.V6.PacketsReceived)
if got := stat.Value(); got != 1 {
t.Fatalf("got stat = %d, want = 1", got)
}
@@ -803,7 +810,7 @@ func TestDADStop(t *testing.T) {
}
// Should not have sent more than 1 NS message.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got > 1 {
t.Errorf("got NeighborSolicit = %d, want <= 1", got)
}
})
@@ -982,7 +989,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
pkt.SetType(header.ICMPv6RouterAdvert)
pkt.SetCode(0)
- raPayload := pkt.NDPPayload()
+ raPayload := pkt.MessageBody()
ra := header.NDPRouterAdvert(raPayload)
// Populate the Router Lifetime.
binary.BigEndian.PutUint16(raPayload[2:], rl)
@@ -1004,11 +1011,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
payloadLength := hdr.UsedLength()
iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
iph.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: header.NDPHopLimit,
- SrcAddr: ip,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: ip,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
return stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -2843,9 +2850,7 @@ func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Connect(addr); err != nil {
t.Fatalf("ep.Connect(%+v): %s", addr, err)
}
@@ -2879,9 +2884,7 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Bind(addr); err != nil {
t.Fatalf("ep.Bind(%+v): %s", addr, err)
}
@@ -3250,9 +3253,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute)
@@ -5174,113 +5175,99 @@ func TestRouterSolicitation(t *testing.T) {
},
}
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
- e := channelLinkWithHeaderLength{
- Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
- headerLength: test.linkHeaderLen,
+ clock.Advance(timeout)
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("expected router solicitation packet")
}
- e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- waitForPkt := func(timeout time.Duration) {
- t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- p, ok := e.ReadContext(ctx)
- if !ok {
- t.Fatal("timed out waiting for packet")
- return
- }
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
- // Make sure the right remote link address is used.
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
- }
+ // Make sure the right remote link address is used.
+ if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(test.expectedSrcAddr),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
- )
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(test.expectedSrcAddr),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
+ )
- if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
- t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
- }
- }
- waitForNothing := func(timeout time.Duration) {
- t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet")
- }
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- MaxRtrSolicitations: test.maxRtrSolicit,
- RtrSolicitationInterval: test.rtrSolicitInt,
- MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
- },
- })},
- })
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
}
+ }
+ waitForNothing := func(timeout time.Duration) {
+ t.Helper()
- if addr := test.nicAddr; addr != "" {
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
- }
+ clock.Advance(timeout)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("unexpectedly got a packet = %#v", p)
}
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ MaxRtrSolicitations: test.maxRtrSolicit,
+ RtrSolicitationInterval: test.rtrSolicitInt,
+ MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
+ },
+ })},
+ Clock: clock,
+ })
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- // Make sure each RS is sent at the right time.
- remaining := test.maxRtrSolicit
- if remaining > 0 {
- waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout)
- remaining--
+ if addr := test.nicAddr; addr != "" {
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
}
+ }
- for ; remaining > 0; remaining-- {
- if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
- waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout)
- waitForPkt(defaultAsyncPositiveEventTimeout)
- } else {
- waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout)
- }
- }
+ // Make sure each RS is sent at the right time.
+ remaining := test.maxRtrSolicit
+ if remaining > 0 {
+ waitForPkt(test.effectiveMaxRtrSolicitDelay)
+ remaining--
+ }
- // Make sure no more RS.
- if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
- waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout)
+ for ; remaining > 0; remaining-- {
+ if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
+ waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond)
+ waitForPkt(time.Nanosecond)
} else {
- waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout)
+ waitForPkt(test.effectiveRtrSolicitInt)
}
+ }
- // Make sure the counter got properly
- // incremented.
- if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
- t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
- }
- })
- }
- })
+ // Make sure no more RS.
+ if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
+ waitForNothing(test.effectiveRtrSolicitInt)
+ } else {
+ waitForNothing(test.effectiveMaxRtrSolicitDelay)
+ }
+
+ if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
+ t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
+ }
+ })
+ }
}
func TestStopStartSolicitingRouters(t *testing.T) {
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 0d3f626cf..c15f10e76 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -17,7 +17,6 @@ package stack
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -99,9 +98,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA
n.dynamic.lru.Remove(e)
n.dynamic.count--
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Unknown)
- e.notifyWakersLocked()
+ e.removeLocked()
e.mu.Unlock()
}
n.cache[remoteAddr] = entry
@@ -110,21 +107,27 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA
return entry
}
-// entry looks up the neighbor cache for translating address to link address
-// (e.g. IP -> MAC). If the LinkEndpoint requests address resolution and there
-// is a LinkAddressResolver registered with the network protocol, the cache
-// attempts to resolve the address and returns ErrWouldBlock. If a Waker is
-// provided, it will be notified when address resolution is complete (success
-// or not).
+// entry looks up neighbor information matching the remote address, and returns
+// it if readily available.
+//
+// Returns ErrWouldBlock if the link address is not readily available, along
+// with a notification channel for the caller to block on. Triggers address
+// resolution asynchronously.
+//
+// If onResolve is provided, it will be called either immediately, if resolution
+// is not required, or when address resolution is complete, with the resolved
+// link address and whether resolution succeeded. After any callbacks have been
+// called, the returned notification channel is closed.
+//
+// NB: if a callback is provided, it should not call into the neighbor cache.
//
// If specified, the local address must be an address local to the interface the
// neighbor cache belongs to. The local address is the source address of a
// packet prompting NUD/link address resolution.
//
-// If address resolution is required, ErrNoLinkAddress and a notification
-// channel is returned for the top level caller to block. Channel is closed
-// once address resolution is complete (success or not).
-func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
+// TODO(gvisor.dev/issue/5151): Don't return the neighbor entry.
+func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
+ // TODO(gvisor.dev/issue/5149): Handle static resolution in route.Resolve.
if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok {
e := NeighborEntry{
Addr: remoteAddr,
@@ -132,6 +135,9 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
State: Static,
UpdatedAtNanos: 0,
}
+ if onResolve != nil {
+ onResolve(linkAddr, true)
+ }
return e, nil, nil
}
@@ -149,47 +155,36 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
// of packets to a neighbor. While reasserting a neighbor's reachability,
// a node continues sending packets to that neighbor using the cached
// link-layer address."
+ if onResolve != nil {
+ onResolve(entry.neigh.LinkAddr, true)
+ }
return entry.neigh, nil, nil
- case Unknown, Incomplete:
- entry.addWakerLocked(w)
-
+ case Unknown, Incomplete, Failed:
+ if onResolve != nil {
+ entry.onResolve = append(entry.onResolve, onResolve)
+ }
if entry.done == nil {
// Address resolution needs to be initiated.
- if linkRes == nil {
- return entry.neigh, nil, tcpip.ErrNoLinkAddress
- }
entry.done = make(chan struct{})
}
-
entry.handlePacketQueuedLocked(localAddr)
return entry.neigh, entry.done, tcpip.ErrWouldBlock
- case Failed:
- return entry.neigh, nil, tcpip.ErrNoLinkAddress
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", s))
}
}
-// removeWaker removes a waker that has been added when link resolution for
-// addr was requested.
-func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) {
- n.mu.Lock()
- if entry, ok := n.cache[addr]; ok {
- delete(entry.wakers, waker)
- }
- n.mu.Unlock()
-}
-
// entries returns all entries in the neighbor cache.
func (n *neighborCache) entries() []NeighborEntry {
- entries := make([]NeighborEntry, 0, len(n.cache))
n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ entries := make([]NeighborEntry, 0, len(n.cache))
for _, entry := range n.cache {
entry.mu.RLock()
entries = append(entries, entry.neigh)
entry.mu.RUnlock()
}
- n.mu.RUnlock()
return entries
}
@@ -221,34 +216,13 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd
return
}
- // Notify that resolution has been interrupted, just in case the entry was
- // in the Incomplete or Probe state.
- entry.dispatchRemoveEventLocked()
- entry.setStateLocked(Unknown)
- entry.notifyWakersLocked()
+ entry.removeLocked()
entry.mu.Unlock()
}
n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state)
}
-// removeEntryLocked removes the specified entry from the neighbor cache.
-//
-// Prerequisite: n.mu and entry.mu MUST be locked.
-func (n *neighborCache) removeEntryLocked(entry *neighborEntry) {
- if entry.neigh.State != Static {
- n.dynamic.lru.Remove(entry)
- n.dynamic.count--
- }
- if entry.neigh.State != Failed {
- entry.dispatchRemoveEventLocked()
- }
- entry.setStateLocked(Unknown)
- entry.notifyWakersLocked()
-
- delete(n.cache, entry.neigh.Addr)
-}
-
// removeEntry removes a dynamic or static entry by address from the neighbor
// cache. Returns true if the entry was found and deleted.
func (n *neighborCache) removeEntry(addr tcpip.Address) bool {
@@ -263,7 +237,13 @@ func (n *neighborCache) removeEntry(addr tcpip.Address) bool {
entry.mu.Lock()
defer entry.mu.Unlock()
- n.removeEntryLocked(entry)
+ if entry.neigh.State != Static {
+ n.dynamic.lru.Remove(entry)
+ n.dynamic.count--
+ }
+
+ entry.removeLocked()
+ delete(n.cache, entry.neigh.Addr)
return true
}
@@ -274,9 +254,7 @@ func (n *neighborCache) clear() {
for _, entry := range n.cache {
entry.mu.Lock()
- entry.dispatchRemoveEventLocked()
- entry.setStateLocked(Unknown)
- entry.notifyWakersLocked()
+ entry.removeLocked()
entry.mu.Unlock()
}
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index 732a299f7..a2ed6ae2a 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -28,7 +28,6 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
)
@@ -190,15 +189,18 @@ type testNeighborResolver struct {
entries *testEntryStore
delay time.Duration
onLinkAddressRequest func()
+ dropReplies bool
}
var _ LinkAddressResolver = (*testNeighborResolver)(nil)
func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
- // Delay handling the request to emulate network latency.
- r.clock.AfterFunc(r.delay, func() {
- r.fakeRequest(targetAddr)
- })
+ if !r.dropReplies {
+ // Delay handling the request to emulate network latency.
+ r.clock.AfterFunc(r.delay, func() {
+ r.fakeRequest(targetAddr)
+ })
+ }
// Execute post address resolution action, if available.
if f := r.onLinkAddressRequest; f != nil {
@@ -291,10 +293,10 @@ func TestNeighborCacheEntry(t *testing.T) {
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -327,7 +329,7 @@ func TestNeighborCacheEntry(t *testing.T) {
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
}
// No more events should have been dispatched.
@@ -354,11 +356,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -413,7 +415,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
}
@@ -461,7 +463,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
return fmt.Errorf("c.store.entry(%d) not found", i)
}
if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- return fmt.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
@@ -513,7 +515,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
}
// Expect to find only the most recent entries. The order of entries reported
- // by entries() is undeterministic, so entries have to be sorted before
+ // by entries() is nondeterministic, so entries have to be sorted before
// comparison.
wantUnsortedEntries := opts.wantStaticEntries
for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ {
@@ -575,10 +577,10 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
// Add a dynamic entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
wantEvents := []testEntryEventInfo{
@@ -650,7 +652,7 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
// Add a static entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
staticLinkAddr := entry.LinkAddr + "static"
c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
@@ -694,7 +696,7 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
// Add a static entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
staticLinkAddr := entry.LinkAddr + "static"
c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
@@ -756,7 +758,7 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
// Add a static entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
staticLinkAddr := entry.LinkAddr + "static"
c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
@@ -826,10 +828,10 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
// Add a dynamic entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
@@ -907,150 +909,6 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
}
}
-func TestNeighborCacheNotifiesWaker(t *testing.T) {
- config := DefaultNUDConfigurations()
-
- nudDisp := testNUDDispatcher{}
- clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(&nudDisp, config, clock)
- store := newTestEntryStore()
- linkRes := &testNeighborResolver{
- clock: clock,
- neigh: neigh,
- entries: store,
- delay: typicalLatency,
- }
-
- w := sleep.Waker{}
- s := sleep.Sleeper{}
- const wakerID = 1
- s.AddWaker(&w, wakerID)
-
- entry, ok := store.entry(0)
- if !ok {
- t.Fatalf("store.entry(0) not found")
- }
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w)
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _ = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
- }
- if doneCh == nil {
- t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr)
- }
- clock.Advance(typicalLatency)
-
- select {
- case <-doneCh:
- default:
- t.Fatal("expected notification from done channel")
- }
-
- id, ok := s.Fetch(false /* block */)
- if !ok {
- t.Errorf("expected waker to be notified after neigh.entry(%s, '', _, _)", entry.Addr)
- }
- if id != wakerID {
- t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID)
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- State: Incomplete,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- },
- },
- }
- nudDisp.mu.Lock()
- defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
- }
-}
-
-func TestNeighborCacheRemoveWaker(t *testing.T) {
- config := DefaultNUDConfigurations()
-
- nudDisp := testNUDDispatcher{}
- clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(&nudDisp, config, clock)
- store := newTestEntryStore()
- linkRes := &testNeighborResolver{
- clock: clock,
- neigh: neigh,
- entries: store,
- delay: typicalLatency,
- }
-
- w := sleep.Waker{}
- s := sleep.Sleeper{}
- const wakerID = 1
- s.AddWaker(&w, wakerID)
-
- entry, ok := store.entry(0)
- if !ok {
- t.Fatalf("store.entry(0) not found")
- }
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w)
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
- }
- if doneCh == nil {
- t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr)
- }
-
- // Remove the waker before the neighbor cache has the opportunity to send a
- // notification.
- neigh.removeWaker(entry.Addr, &w)
- clock.Advance(typicalLatency)
-
- select {
- case <-doneCh:
- default:
- t.Fatal("expected notification from done channel")
- }
-
- if id, ok := s.Fetch(false /* block */); ok {
- t.Errorf("unexpected notification from waker with id %d", id)
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- State: Incomplete,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- },
- },
- }
- nudDisp.mu.Lock()
- defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
- }
-}
-
func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
config := DefaultNUDConfigurations()
// Stay in Reachable so the cache can overflow
@@ -1062,12 +920,12 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr)
e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
if err != nil {
- t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err)
+ t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1075,7 +933,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
State: Static,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
wantEvents := []testEntryEventInfo{
@@ -1129,10 +987,10 @@ func TestNeighborCacheClear(t *testing.T) {
// Add a dynamic entry.
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -1187,7 +1045,7 @@ func TestNeighborCacheClear(t *testing.T) {
}
}
- // Clear shoud remove both dynamic and static entries.
+ // Clear should remove both dynamic and static entries.
neigh.clear()
// Remove events dispatched from clear() have no deterministic order so they
@@ -1234,10 +1092,10 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
// Add a dynamic entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
@@ -1318,7 +1176,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
frequentlyUsedEntry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
// The following logic is very similar to overflowCache, but
@@ -1330,15 +1188,22 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
- case <-doneCh:
+ case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
wantEvents := []testEntryEventInfo{
{
@@ -1373,7 +1238,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
// Periodically refresh the frequently used entry
if i%(neighborCacheSize/2) == 0 {
if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", frequentlyUsedEntry.Addr, err)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", frequentlyUsedEntry.Addr, err)
}
}
@@ -1381,15 +1246,23 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
- case <-doneCh:
+ case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
// An entry should have been removed, as per the LRU eviction strategy
@@ -1435,7 +1308,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
}
// Expect to find only the frequently used entry and the most recent entries.
- // The order of entries reported by entries() is undeterministic, so entries
+ // The order of entries reported by entries() is nondeterministic, so entries
// have to be sorted before comparison.
wantUnsortedEntries := []NeighborEntry{
{
@@ -1494,12 +1367,12 @@ func TestNeighborCacheConcurrent(t *testing.T) {
go func(entry NeighborEntry) {
defer wg.Done()
if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock)
}
}(entry)
}
- // Wait for all gorountines to send a request
+ // Wait for all goroutines to send a request
wg.Wait()
// Process all the requests for a single entry concurrently
@@ -1509,7 +1382,7 @@ func TestNeighborCacheConcurrent(t *testing.T) {
// All goroutines add in the same order and add more values than can fit in
// the cache. Our eviction strategy requires that the last entries are
// present, up to the size of the neighbor cache, and the rest are missing.
- // The order of entries reported by entries() is undeterministic, so entries
+ // The order of entries reported by entries() is nondeterministic, so entries
// have to be sorted before comparison.
var wantUnsortedEntries []NeighborEntry
for i := store.size() - neighborCacheSize; i < store.size(); i++ {
@@ -1547,27 +1420,32 @@ func TestNeighborCacheReplace(t *testing.T) {
// Add an entry
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
- case <-doneCh:
+ case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
// Verify the entry exists
{
- e, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
- }
- if doneCh != nil {
- t.Errorf("unexpected done channel from neigh.entry(%s, '', _, nil): %v", entry.Addr, doneCh)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err)
}
if t.Failed() {
t.FailNow()
@@ -1578,7 +1456,7 @@ func TestNeighborCacheReplace(t *testing.T) {
State: Reachable,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
}
@@ -1587,7 +1465,7 @@ func TestNeighborCacheReplace(t *testing.T) {
{
entry, ok := store.entry(1)
if !ok {
- t.Fatalf("store.entry(1) not found")
+ t.Fatal("store.entry(1) not found")
}
updatedLinkAddr = entry.LinkAddr
}
@@ -1604,7 +1482,7 @@ func TestNeighborCacheReplace(t *testing.T) {
{
e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Fatalf("neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
+ t.Fatalf("neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1612,7 +1490,7 @@ func TestNeighborCacheReplace(t *testing.T) {
State: Delay,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
clock.Advance(config.DelayFirstProbeTime + typicalLatency)
}
@@ -1622,7 +1500,7 @@ func TestNeighborCacheReplace(t *testing.T) {
e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
clock.Advance(typicalLatency)
if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1630,7 +1508,7 @@ func TestNeighborCacheReplace(t *testing.T) {
State: Reachable,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
}
}
@@ -1654,18 +1532,35 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
},
}
- // First, sanity check that resolution is working
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+
+ // First, sanity check that resolution is working
+ {
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ }
+ clock.Advance(typicalLatency)
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
}
- clock.Advance(typicalLatency)
+
got, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1673,20 +1568,35 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
State: Reachable,
}
if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
- // Verify that address resolution for an unknown address returns ErrNoLinkAddress
+ // Verify address resolution fails for an unknown address.
before := atomic.LoadUint32(&requestCount)
entry.Addr += "2"
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
- }
- waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
- clock.Advance(waitFor)
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress)
+ {
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if ok {
+ t.Error("expected unsuccessful address resolution")
+ }
+ if len(linkAddr) != 0 {
+ t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ }
+ waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
+ clock.Advance(waitFor)
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
}
maxAttempts := neigh.config().MaxUnicastProbes
@@ -1714,15 +1624,129 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if ok {
+ t.Error("expected unsuccessful address resolution")
+ }
+ if len(linkAddr) != 0 {
+ t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress)
+
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
+}
+
+// TestNeighborCacheRetryResolution simulates retrying communication after
+// failing to perform address resolution.
+func TestNeighborCacheRetryResolution(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ clock := faketime.NewManualClock()
+ neigh := newTestNeighborCache(nil, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ // Simulate a faulty link.
+ dropReplies: true,
+ }
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatal("store.entry(0) not found")
+ }
+
+ // Perform address resolution with a faulty link, which will fail.
+ {
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if ok {
+ t.Error("expected unsuccessful address resolution")
+ }
+ if len(linkAddr) != 0 {
+ t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ }
+ waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
+ clock.Advance(waitFor)
+
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
+ }
+
+ // Verify the entry is in Failed state.
+ wantEntries := []NeighborEntry{
+ {
+ Addr: entry.Addr,
+ LinkAddr: "",
+ State: Failed,
+ },
+ }
+ if diff := cmp.Diff(neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" {
+ t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Retry address resolution with a working link.
+ linkRes.dropReplies = false
+ {
+ incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ }
+ if incompleteEntry.State != Incomplete {
+ t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete)
+ }
+ clock.Advance(typicalLatency)
+
+ select {
+ case <-ch:
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ reachableEntry, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ if err != nil {
+ t.Fatalf("neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err)
+ }
+ if reachableEntry.Addr != entry.Addr {
+ t.Fatalf("got entry.Addr = %s, want = %s", reachableEntry.Addr, entry.Addr)
+ }
+ if reachableEntry.LinkAddr != entry.LinkAddr {
+ t.Fatalf("got entry.LinkAddr = %s, want = %s", reachableEntry.LinkAddr, entry.LinkAddr)
+ }
+ if reachableEntry.State != Reachable {
+ t.Fatalf("got entry.State = %s, want = %s", reachableEntry.State.String(), Reachable.String())
+ }
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
}
}
@@ -1742,7 +1766,7 @@ func TestNeighborCacheStaticResolution(t *testing.T) {
got, _, err := neigh.entry(testEntryBroadcastAddr, "", linkRes, nil)
if err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", testEntryBroadcastAddr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", testEntryBroadcastAddr, err)
}
want := NeighborEntry{
Addr: testEntryBroadcastAddr,
@@ -1750,7 +1774,7 @@ func TestNeighborCacheStaticResolution(t *testing.T) {
State: Static,
}
if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff)
}
}
@@ -1775,12 +1799,23 @@ func BenchmarkCacheClear(b *testing.B) {
if !ok {
b.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ b.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ b.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
if err != tcpip.ErrWouldBlock {
- b.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
- if doneCh != nil {
- <-doneCh
+
+ select {
+ case <-ch:
+ default:
+ b.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
}
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 32399b4f5..75afb3001 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -19,7 +19,6 @@ import (
"sync"
"time"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -67,8 +66,7 @@ const (
// Static describes entries that have been explicitly added by the user. They
// do not expire and are not deleted until explicitly removed.
Static
- // Failed means traffic should not be sent to this neighbor since attempts of
- // reachability have returned inconclusive.
+ // Failed means recent attempts to confirm reachability have been inconclusive.
Failed
)
@@ -93,16 +91,13 @@ type neighborEntry struct {
neigh NeighborEntry
- // wakers is a set of waiters for address resolution result. Anytime state
- // transitions out of incomplete these waiters are notified. It is nil iff
- // address resolution is ongoing and no clients are waiting for the result.
- wakers map[*sleep.Waker]struct{}
-
- // done is used to allow callers to wait on address resolution. It is nil
- // iff nudState is not Reachable and address resolution is not yet in
- // progress.
+ // done is closed when address resolution is complete. It is nil iff s is
+ // incomplete and resolution is not yet in progress.
done chan struct{}
+ // onResolve is called with the result of address resolution.
+ onResolve []func(tcpip.LinkAddress, bool)
+
isRouter bool
job *tcpip.Job
}
@@ -143,25 +138,15 @@ func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAdd
}
}
-// addWaker adds w to the list of wakers waiting for address resolution.
-// Assumes the entry has already been appropriately locked.
-func (e *neighborEntry) addWakerLocked(w *sleep.Waker) {
- if w == nil {
- return
- }
- if e.wakers == nil {
- e.wakers = make(map[*sleep.Waker]struct{})
- }
- e.wakers[w] = struct{}{}
-}
-
-// notifyWakersLocked notifies those waiting for address resolution, whether it
-// succeeded or failed. Assumes the entry has already been appropriately locked.
-func (e *neighborEntry) notifyWakersLocked() {
- for w := range e.wakers {
- w.Assert()
+// notifyCompletionLocked notifies those waiting for address resolution, with
+// the link address if resolution completed successfully.
+//
+// Precondition: e.mu MUST be locked.
+func (e *neighborEntry) notifyCompletionLocked(succeeded bool) {
+ for _, callback := range e.onResolve {
+ callback(e.neigh.LinkAddr, succeeded)
}
- e.wakers = nil
+ e.onResolve = nil
if ch := e.done; ch != nil {
close(ch)
e.done = nil
@@ -170,6 +155,8 @@ func (e *neighborEntry) notifyWakersLocked() {
// dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has
// been added.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchAddEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
nudDisp.OnNeighborAdded(e.nic.id, e.neigh)
@@ -178,6 +165,8 @@ func (e *neighborEntry) dispatchAddEventLocked() {
// dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry
// has changed state or link-layer address.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchChangeEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
nudDisp.OnNeighborChanged(e.nic.id, e.neigh)
@@ -186,23 +175,41 @@ func (e *neighborEntry) dispatchChangeEventLocked() {
// dispatchRemoveEventLocked signals to stack's NUD Dispatcher that the entry
// has been removed.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchRemoveEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
nudDisp.OnNeighborRemoved(e.nic.id, e.neigh)
}
}
+// cancelJobLocked cancels the currently scheduled action, if there is one.
+// Entries in Unknown, Stale, or Static state do not have a scheduled action.
+//
+// Precondition: e.mu MUST be locked.
+func (e *neighborEntry) cancelJobLocked() {
+ if job := e.job; job != nil {
+ job.Cancel()
+ }
+}
+
+// removeLocked prepares the entry for removal.
+//
+// Precondition: e.mu MUST be locked.
+func (e *neighborEntry) removeLocked() {
+ e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds()
+ e.dispatchRemoveEventLocked()
+ e.cancelJobLocked()
+ e.notifyCompletionLocked(false /* succeeded */)
+}
+
// setStateLocked transitions the entry to the specified state immediately.
//
// Follows the logic defined in RFC 4861 section 7.3.3.
//
-// e.mu MUST be locked.
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) setStateLocked(next NeighborState) {
- // Cancel the previously scheduled action, if there is one. Entries in
- // Unknown, Stale, or Static state do not have scheduled actions.
- if timer := e.job; timer != nil {
- timer.Cancel()
- }
+ e.cancelJobLocked()
prev := e.neigh.State
e.neigh.State = next
@@ -257,11 +264,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
e.job.Schedule(immediateDuration)
case Failed:
- e.notifyWakersLocked()
- e.job = e.nic.stack.newJob(&doubleLock{first: &e.nic.neigh.mu, second: &e.mu}, func() {
- e.nic.neigh.removeEntryLocked(e)
- })
- e.job.Schedule(config.UnreachableTime)
+ e.notifyCompletionLocked(false /* succeeded */)
case Unknown, Stale, Static:
// Do nothing
@@ -275,8 +278,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
// being queued for outgoing transmission.
//
// Follows the logic defined in RFC 4861 section 7.3.3.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
switch e.neigh.State {
+ case Failed:
+ e.nic.stats.Neighbor.FailedEntryLookups.Increment()
+
+ fallthrough
case Unknown:
e.neigh.State = Incomplete
e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds()
@@ -309,7 +318,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
// implementation may find it convenient in some cases to return errors
// to the sender by taking the offending packet, generating an ICMP
// error message, and then delivering it (locally) through the generic
- // error-handling routines.' - RFC 4861 section 2.1
+ // error-handling routines." - RFC 4861 section 2.1
e.dispatchRemoveEventLocked()
e.setStateLocked(Failed)
return
@@ -349,8 +358,6 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
case Incomplete, Reachable, Delay, Probe, Static:
// Do nothing
- case Failed:
- e.nic.stats.Neighbor.FailedEntryLookups.Increment()
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
}
@@ -360,18 +367,30 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
// Neighbor Solicitation for ARP or NDP, respectively).
//
// Follows the logic defined in RFC 4861 section 7.2.3.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
// Probes MUST be silently discarded if the target address is tentative, does
// not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These
// checks MUST be done by the NetworkEndpoint.
switch e.neigh.State {
- case Unknown, Incomplete, Failed:
+ case Unknown, Failed:
e.neigh.LinkAddr = remoteLinkAddr
e.setStateLocked(Stale)
- e.notifyWakersLocked()
e.dispatchAddEventLocked()
+ case Incomplete:
+ // "If an entry already exists, and the cached link-layer address
+ // differs from the one in the received Source Link-Layer option, the
+ // cached address should be replaced by the received address, and the
+ // entry's reachability state MUST be set to STALE."
+ // - RFC 4861 section 7.2.3
+ e.neigh.LinkAddr = remoteLinkAddr
+ e.setStateLocked(Stale)
+ e.notifyCompletionLocked(true /* succeeded */)
+ e.dispatchChangeEventLocked()
+
case Reachable, Delay, Probe:
if e.neigh.LinkAddr != remoteLinkAddr {
e.neigh.LinkAddr = remoteLinkAddr
@@ -404,6 +423,8 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
// not be possible. SEND uses RSA key pairs to produce Cryptographically
// Generated Addresses (CGA), as defined in RFC 3972. This ensures that the
// claimed source of an NDP message is the owner of the claimed address.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
switch e.neigh.State {
case Incomplete:
@@ -422,7 +443,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
}
e.dispatchChangeEventLocked()
e.isRouter = flags.IsRouter
- e.notifyWakersLocked()
+ e.notifyCompletionLocked(true /* succeeded */)
// "Note that the Override flag is ignored if the entry is in the
// INCOMPLETE state." - RFC 4861 section 7.2.5
@@ -457,7 +478,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
wasReachable := e.neigh.State == Reachable
// Set state to Reachable again to refresh timers.
e.setStateLocked(Reachable)
- e.notifyWakersLocked()
+ e.notifyCompletionLocked(true /* succeeded */)
if !wasReachable {
e.dispatchChangeEventLocked()
}
@@ -495,6 +516,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
// handleUpperLevelConfirmationLocked processes an incoming upper-level protocol
// (e.g. TCP acknowledgements) reachability confirmation.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
switch e.neigh.State {
case Reachable, Stale, Delay, Probe:
@@ -512,23 +535,3 @@ func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
}
}
-
-// doubleLock combines two locks into one while maintaining lock ordering.
-//
-// TODO(gvisor.dev/issue/4796): Remove this once subsequent traffic to a Failed
-// neighbor is allowed.
-type doubleLock struct {
- first, second sync.Locker
-}
-
-// Lock locks both locks in order: first then second.
-func (l *doubleLock) Lock() {
- l.first.Lock()
- l.second.Lock()
-}
-
-// Unlock unlocks both locks in reverse order: second then first.
-func (l *doubleLock) Unlock() {
- l.second.Unlock()
- l.first.Unlock()
-}
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index c497d3932..ec34ffa5a 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -25,7 +25,6 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -73,36 +72,36 @@ func eventDiffOptsWithSort() []cmp.Option {
// The following unit tests exercise every state transition and verify its
// behavior with RFC 4861.
//
-// | From | To | Cause | Action | Event |
-// | ========== | ========== | ========================================== | =============== | ======= |
-// | Unknown | Unknown | Confirmation w/ unknown address | | Added |
-// | Unknown | Incomplete | Packet queued to unknown address | Send probe | Added |
-// | Unknown | Stale | Probe w/ unknown address | | Added |
-// | Incomplete | Incomplete | Retransmit timer expired | Send probe | Changed |
-// | Incomplete | Reachable | Solicited confirmation | Notify wakers | Changed |
-// | Incomplete | Stale | Unsolicited confirmation | Notify wakers | Changed |
-// | Incomplete | Failed | Max probes sent without reply | Notify wakers | Removed |
-// | Reachable | Reachable | Confirmation w/ different isRouter flag | Update IsRouter | |
-// | Reachable | Stale | Reachable timer expired | | Changed |
-// | Reachable | Stale | Probe or confirmation w/ different address | | Changed |
-// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
-// | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
-// | Stale | Stale | Override confirmation | Update LinkAddr | Changed |
-// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed |
-// | Stale | Delay | Packet queued | | Changed |
-// | Delay | Reachable | Upper-layer confirmation | | Changed |
-// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
-// | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
-// | Delay | Stale | Probe or confirmation w/ different address | | Changed |
-// | Delay | Probe | Delay timer expired | Send probe | Changed |
-// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
-// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed |
-// | Probe | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
-// | Probe | Stale | Probe or confirmation w/ different address | | Changed |
-// | Probe | Probe | Retransmit timer expired | Send probe | Changed |
-// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed |
-// | Failed | Failed | Packet queued | | |
-// | Failed | | Unreachability timer expired | Delete entry | |
+// | From | To | Cause | Update | Action | Event |
+// | ========== | ========== | ========================================== | ======== | ===========| ======= |
+// | Unknown | Unknown | Confirmation w/ unknown address | | | Added |
+// | Unknown | Incomplete | Packet queued to unknown address | | Send probe | Added |
+// | Unknown | Stale | Probe | | | Added |
+// | Incomplete | Incomplete | Retransmit timer expired | | Send probe | Changed |
+// | Incomplete | Reachable | Solicited confirmation | LinkAddr | Notify | Changed |
+// | Incomplete | Stale | Unsolicited confirmation | LinkAddr | Notify | Changed |
+// | Incomplete | Stale | Probe | LinkAddr | Notify | Changed |
+// | Incomplete | Failed | Max probes sent without reply | | Notify | Removed |
+// | Reachable | Reachable | Confirmation w/ different isRouter flag | IsRouter | | |
+// | Reachable | Stale | Reachable timer expired | | | Changed |
+// | Reachable | Stale | Probe or confirmation w/ different address | | | Changed |
+// | Stale | Reachable | Solicited override confirmation | LinkAddr | | Changed |
+// | Stale | Reachable | Solicited confirmation w/o address | | Notify | Changed |
+// | Stale | Stale | Override confirmation | LinkAddr | | Changed |
+// | Stale | Stale | Probe w/ different address | LinkAddr | | Changed |
+// | Stale | Delay | Packet sent | | | Changed |
+// | Delay | Reachable | Upper-layer confirmation | | | Changed |
+// | Delay | Reachable | Solicited override confirmation | LinkAddr | | Changed |
+// | Delay | Reachable | Solicited confirmation w/o address | | Notify | Changed |
+// | Delay | Stale | Probe or confirmation w/ different address | | | Changed |
+// | Delay | Probe | Delay timer expired | | Send probe | Changed |
+// | Probe | Reachable | Solicited override confirmation | LinkAddr | | Changed |
+// | Probe | Reachable | Solicited confirmation w/ same address | | Notify | Changed |
+// | Probe | Reachable | Solicited confirmation w/o address | | Notify | Changed |
+// | Probe | Stale | Probe or confirmation w/ different address | | | Changed |
+// | Probe | Probe | Retransmit timer expired | | Send probe | Changed |
+// | Probe | Failed | Max probes sent without reply | | Notify | Removed |
+// | Failed | Incomplete | Packet queued | | Send probe | Added |
type testEntryEventType uint8
@@ -258,8 +257,8 @@ func TestEntryInitiallyUnknown(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- if got, want := e.neigh.State, Unknown; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Unknown {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown)
}
e.mu.Unlock()
@@ -291,8 +290,8 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
Override: false,
IsRouter: false,
})
- if got, want := e.neigh.State, Unknown; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Unknown {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown)
}
e.mu.Unlock()
@@ -320,8 +319,8 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
@@ -367,8 +366,8 @@ func TestEntryUnknownToStale(t *testing.T) {
e.mu.Lock()
e.handleProbeLocked(entryTestLinkAddr1)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
@@ -406,8 +405,8 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
updatedAtNanos := e.neigh.UpdatedAtNanos
e.mu.Unlock()
@@ -560,21 +559,15 @@ func TestEntryIncompleteToReachable(t *testing.T) {
nudDisp.mu.Unlock()
}
-// TestEntryAddsAndClearsWakers verifies that wakers are added when
-// addWakerLocked is called and cleared when address resolution finishes. In
-// this case, address resolution will finish when transitioning from Incomplete
-// to Reachable.
-func TestEntryAddsAndClearsWakers(t *testing.T) {
+func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
c := DefaultNUDConfigurations()
e, nudDisp, linkRes, clock := entryTestSetup(c)
- w := sleep.Waker{}
- s := sleep.Sleeper{}
- s.AddWaker(&w, 123)
- defer s.Done()
-
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
+ }
e.mu.Unlock()
runImmediatelyScheduledJobs(clock)
@@ -593,26 +586,16 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
}
e.mu.Lock()
- if got := e.wakers; got != nil {
- t.Errorf("got e.wakers = %v, want = nil", got)
- }
- e.addWakerLocked(&w)
- if got, want := w.IsAsserted(), false; got != want {
- t.Errorf("waker.IsAsserted() = %t, want = %t", got, want)
- }
- if e.wakers == nil {
- t.Error("expected e.wakers to be non-nil")
- }
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
- IsRouter: false,
+ IsRouter: true,
})
- if e.wakers != nil {
- t.Errorf("got e.wakers = %v, want = nil", e.wakers)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
- if got, want := w.IsAsserted(), true; got != want {
- t.Errorf("waker.IsAsserted() = %t, want = %t", got, want)
+ if !e.isRouter {
+ t.Errorf("got e.isRouter = %t, want = true", e.isRouter)
}
e.mu.Unlock()
@@ -643,7 +626,7 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
nudDisp.mu.Unlock()
}
-func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
+func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) {
c := DefaultNUDConfigurations()
e, nudDisp, linkRes, clock := entryTestSetup(c)
@@ -663,22 +646,20 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
},
}
linkRes.mu.Lock()
- if diff := cmp.Diff(linkRes.probes, wantProbes); diff != "" {
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- linkRes.mu.Unlock()
e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
+ Solicited: false,
Override: false,
- IsRouter: true,
+ IsRouter: false,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
- }
- if !e.isRouter {
- t.Errorf("got e.isRouter = %t, want = true", e.isRouter)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
@@ -698,7 +679,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
Entry: NeighborEntry{
Addr: entryTestAddr1,
LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ State: Stale,
},
},
}
@@ -709,7 +690,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
nudDisp.mu.Unlock()
}
-func TestEntryIncompleteToStale(t *testing.T) {
+func TestEntryIncompleteToStaleWhenProbe(t *testing.T) {
c := DefaultNUDConfigurations()
e, nudDisp, linkRes, clock := entryTestSetup(c)
@@ -736,11 +717,7 @@ func TestEntryIncompleteToStale(t *testing.T) {
}
e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
+ e.handleProbeLocked(entryTestLinkAddr1)
if e.neigh.State != Stale {
t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
@@ -780,8 +757,8 @@ func TestEntryIncompleteToFailed(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
@@ -841,8 +818,8 @@ func TestEntryIncompleteToFailed(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, want := e.neigh.State, Failed; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Failed {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed)
}
e.mu.Unlock()
}
@@ -885,8 +862,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
Override: false,
IsRouter: true,
})
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
if got, want := e.isRouter, true; got != want {
t.Errorf("got e.isRouter = %t, want = %t", got, want)
@@ -932,8 +909,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
e.mu.Unlock()
}
@@ -1083,8 +1060,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
}
@@ -2381,8 +2358,8 @@ func TestEntryDelayToProbe(t *testing.T) {
IsRouter: false,
})
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
}
e.mu.Unlock()
@@ -2447,8 +2424,8 @@ func TestEntryDelayToProbe(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.mu.Unlock()
}
@@ -2505,12 +2482,12 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleProbeLocked(entryTestLinkAddr2)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
@@ -2620,16 +2597,16 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: false,
Override: true,
IsRouter: false,
})
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
@@ -2740,16 +2717,16 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: true,
IsRouter: false,
})
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
@@ -2836,16 +2813,16 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: true,
Override: true,
IsRouter: false,
})
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
@@ -2964,16 +2941,16 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: true,
Override: true,
IsRouter: false,
})
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
@@ -3101,16 +3078,16 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
IsRouter: false,
})
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
e.mu.Unlock()
@@ -3435,212 +3412,61 @@ func TestEntryProbeToFailed(t *testing.T) {
nudDisp.mu.Unlock()
}
-func TestEntryFailedToFailed(t *testing.T) {
+func TestEntryFailedToIncomplete(t *testing.T) {
c := DefaultNUDConfigurations()
c.MaxMulticastProbes = 3
- c.MaxUnicastProbes = 3
e, nudDisp, linkRes, clock := entryTestSetup(c)
- // Verify the cache contains the entry.
- if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok {
- t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1)
- }
-
// TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in
// their expected state.
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.mu.Unlock()
-
- runImmediatelyScheduledJobs(clock)
- {
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.probes = nil
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
-
- e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
- waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes)
+ waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes)
clock.Advance(waitFor)
- {
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
- }
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- },
- },
+ wantProbes := []entryTestProbeInfo{
+ // The Incomplete-to-Incomplete state transition is tested here by
+ // verifying that 3 reachability probes were sent.
{
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
- },
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
},
{
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
- },
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
},
{
- EventType: entryTestRemoved,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
- },
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
},
}
- nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-
- failedLookups := e.nic.stats.Neighbor.FailedEntryLookups
- if got := failedLookups.Value(); got != 0 {
- t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 0", got)
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
e.mu.Lock()
- // Verify queuing a packet to the entry immediately fails.
- e.handlePacketQueuedLocked(entryTestAddr2)
- state := e.neigh.State
- e.mu.Unlock()
- if state != Failed {
- t.Errorf("got e.neigh.State = %q, want = %q", state, Failed)
- }
-
- if got := failedLookups.Value(); got != 1 {
- t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 1", got)
- }
-}
-
-func TestEntryFailedGetsDeleted(t *testing.T) {
- c := DefaultNUDConfigurations()
- c.MaxMulticastProbes = 3
- c.MaxUnicastProbes = 3
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- // Verify the cache contains the entry.
- if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok {
- t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1)
+ if e.neigh.State != Failed {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed)
}
-
- e.mu.Lock()
- e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
- runImmediatelyScheduledJobs(clock)
- {
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.probes = nil
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
- }
-
e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
e.handlePacketQueuedLocked(entryTestAddr2)
- e.mu.Unlock()
-
- waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime
- clock.Advance(waitFor)
- {
- wantProbes := []entryTestProbeInfo{
- // The next three probe are sent in Probe.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
+ e.mu.Unlock()
wantEvents := []testEntryEventInfo{
{
@@ -3653,39 +3479,21 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
},
},
{
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
- },
- },
- {
- EventType: entryTestChanged,
+ EventType: entryTestRemoved,
NICID: entryTestNICID,
Entry: NeighborEntry{
Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
},
},
{
- EventType: entryTestRemoved,
+ EventType: entryTestAdded,
NICID: entryTestNICID,
Entry: NeighborEntry{
Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
},
},
}
@@ -3694,9 +3502,4 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- // Verify the cache no longer contains the entry.
- if _, ok := e.nic.neigh.cache[entryTestAddr1]; ok {
- t.Errorf("entry %q should have been deleted from the neighbor cache", entryTestAddr1)
- }
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 43696ba14..4a34805b5 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -20,7 +20,6 @@ import (
"reflect"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -54,9 +53,9 @@ type NIC struct {
sync.RWMutex
spoofing bool
promiscuous bool
- // packetEPs is protected by mu, but the contained PacketEndpoint
- // values are not.
- packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
+ // packetEPs is protected by mu, but the contained packetEndpointList
+ // values are not.
+ packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList
}
}
@@ -82,6 +81,39 @@ type DirectionStats struct {
Bytes *tcpip.StatCounter
}
+type packetEndpointList struct {
+ mu sync.RWMutex
+
+ // eps is protected by mu, but the contained PacketEndpoint values are not.
+ eps []PacketEndpoint
+}
+
+func (p *packetEndpointList) add(ep PacketEndpoint) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.eps = append(p.eps, ep)
+}
+
+func (p *packetEndpointList) remove(ep PacketEndpoint) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ for i, epOther := range p.eps {
+ if epOther == ep {
+ p.eps = append(p.eps[:i], p.eps[i+1:]...)
+ break
+ }
+ }
+}
+
+// forEach calls fn with each endpoint in p while holding the read lock on p.
+func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ for _, ep := range p.eps {
+ fn(ep)
+ }
+}
+
// newNIC returns a new NIC using the default NDP configurations from stack.
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC {
// TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
@@ -102,7 +134,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
}
- nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint)
+ nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
// Check for Neighbor Unreachability Detection support.
var nud NUDHandler
@@ -125,11 +157,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// Register supported packet and network endpoint protocols.
for _, netProto := range header.Ethertypes {
- nic.mu.packetEPs[netProto] = []PacketEndpoint{}
+ nic.mu.packetEPs[netProto] = new(packetEndpointList)
}
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
- nic.mu.packetEPs[netNum] = nil
+ nic.mu.packetEPs[netNum] = new(packetEndpointList)
nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
}
@@ -172,7 +204,7 @@ func (n *NIC) disable() {
//
// n MUST be locked.
func (n *NIC) disableLocked() {
- if !n.setEnabled(false) {
+ if !n.Enabled() {
return
}
@@ -184,6 +216,10 @@ func (n *NIC) disableLocked() {
for _, ep := range n.networkEndpoints {
ep.Disable()
}
+
+ if !n.setEnabled(false) {
+ panic("should have only done work to disable the NIC if it was enabled")
+ }
}
// enable enables n.
@@ -258,16 +294,18 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
// the same unresolved IP address, and transmit the saved
// packet when the address has been resolved.
//
- // RFC 4861 section 5.2 (for IPv6):
- // Once the IP address of the next-hop node is known, the sender
- // examines the Neighbor Cache for link-layer information about that
- // neighbor. If no entry exists, the sender creates one, sets its state
- // to INCOMPLETE, initiates Address Resolution, and then queues the data
- // packet pending completion of address resolution.
+ // RFC 4861 section 7.2.2 (for IPv6):
+ // While waiting for address resolution to complete, the sender MUST, for
+ // each neighbor, retain a small queue of packets waiting for address
+ // resolution to complete. The queue MUST hold at least one packet, and MAY
+ // contain more. However, the number of queued packets per neighbor SHOULD
+ // be limited to some small value. When a queue overflows, the new arrival
+ // SHOULD replace the oldest entry. Once address resolution completes, the
+ // node transmits any queued packets.
if ch, err := r.Resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
- r := r.Clone()
- n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt)
+ r.Acquire()
+ n.stack.linkResQueue.enqueue(ch, r, protocol, pkt)
return nil
}
return err
@@ -279,9 +317,11 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
// WritePacketToRemote implements NetworkInterface.
func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
r := Route{
- NetProto: protocol,
- RemoteLinkAddress: remoteLinkAddr,
+ routeInfo: routeInfo{
+ NetProto: protocol,
+ },
}
+ r.ResolveWith(remoteLinkAddr)
return n.writePacket(&r, gso, protocol, pkt)
}
@@ -508,14 +548,6 @@ func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) {
return n.neigh.entries(), nil
}
-func (n *NIC) removeWaker(addr tcpip.Address, w *sleep.Waker) {
- if n.neigh == nil {
- return
- }
-
- n.neigh.removeWaker(addr, w)
-}
-
func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error {
if n.neigh == nil {
return tcpip.ErrNotSupported
@@ -563,8 +595,7 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
return tcpip.ErrNotSupported
}
- _, err := gep.JoinGroup(addr)
- return err
+ return gep.JoinGroup(addr)
}
// leaveGroup decrements the count for the given multicast address, and when it
@@ -580,11 +611,7 @@ func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres
return tcpip.ErrNotSupported
}
- if _, err := gep.LeaveGroup(addr); err != nil {
- return err
- }
-
- return nil
+ return gep.LeaveGroup(addr)
}
// isInGroup returns true if n has joined the multicast group addr.
@@ -639,15 +666,23 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0
// Are any packet type sockets listening for this network protocol?
- packetEPs := n.mu.packetEPs[protocol]
- // Add any other packet type sockets that may be listening for all protocols.
- packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
+ protoEPs := n.mu.packetEPs[protocol]
+ // Other packet type sockets that are listening for all protocols.
+ anyEPs := n.mu.packetEPs[header.EthernetProtocolAll]
n.mu.RUnlock()
- for _, ep := range packetEPs {
+
+ // Deliver to interested packet endpoints without holding NIC lock.
+ deliverPacketEPs := func(ep PacketEndpoint) {
p := pkt.Clone()
p.PktType = tcpip.PacketHost
ep.HandlePacket(n.id, local, protocol, p)
}
+ if protoEPs != nil {
+ protoEPs.forEach(deliverPacketEPs)
+ }
+ if anyEPs != nil {
+ anyEPs.forEach(deliverPacketEPs)
+ }
// Parse headers.
netProto := n.stack.NetworkProtocolInstance(protocol)
@@ -688,16 +723,17 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
// We do not deliver to protocol specific packet endpoints as on Linux
// only ETH_P_ALL endpoints get outbound packets.
// Add any other packet sockets that maybe listening for all protocols.
- packetEPs := n.mu.packetEPs[header.EthernetProtocolAll]
+ eps := n.mu.packetEPs[header.EthernetProtocolAll]
n.mu.RUnlock()
- for _, ep := range packetEPs {
+
+ eps.forEach(func(ep PacketEndpoint) {
p := pkt.Clone()
p.PktType = tcpip.PacketOutgoing
// Add the link layer header as outgoing packets are intercepted
// before the link layer header is created.
n.LinkEndpoint.AddHeader(local, remote, protocol, p)
ep.HandlePacket(n.id, local, protocol, p)
- }
+ })
}
// DeliverTransportPacket delivers the packets to the appropriate transport
@@ -850,7 +886,7 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa
if !ok {
return tcpip.ErrNotSupported
}
- n.mu.packetEPs[netProto] = append(eps, ep)
+ eps.add(ep)
return nil
}
@@ -863,13 +899,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
if !ok {
return
}
-
- for i, epOther := range eps {
- if epOther == ep {
- n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...)
- return
- }
- }
+ eps.remove(ep)
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go
index ab629b3a4..12d67409a 100644
--- a/pkg/tcpip/stack/nud.go
+++ b/pkg/tcpip/stack/nud.go
@@ -109,14 +109,6 @@ const (
//
// Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10.
defaultMaxReachbilityConfirmations = 3
-
- // defaultUnreachableTime is the default duration for how long an entry will
- // remain in the FAILED state before being removed from the neighbor cache.
- //
- // Note, there is no equivalent protocol constant defined in RFC 4861. It
- // leaves the specifics of any garbage collection mechanism up to the
- // implementation.
- defaultUnreachableTime = 5 * time.Second
)
// NUDDispatcher is the interface integrators of netstack must implement to
@@ -278,10 +270,6 @@ type NUDConfigurations struct {
// TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD
// configuration option is necessary.
MaxReachabilityConfirmations uint32
-
- // UnreachableTime describes how long an entry will remain in the FAILED
- // state before being removed from the neighbor cache.
- UnreachableTime time.Duration
}
// DefaultNUDConfigurations returns a NUDConfigurations populated with default
@@ -299,7 +287,6 @@ func DefaultNUDConfigurations() NUDConfigurations {
MaxUnicastProbes: defaultMaxUnicastProbes,
MaxAnycastDelayTime: defaultMaxAnycastDelayTime,
MaxReachabilityConfirmations: defaultMaxReachbilityConfirmations,
- UnreachableTime: defaultUnreachableTime,
}
}
@@ -329,9 +316,6 @@ func (c *NUDConfigurations) resetInvalidFields() {
if c.MaxUnicastProbes == 0 {
c.MaxUnicastProbes = defaultMaxUnicastProbes
}
- if c.UnreachableTime == 0 {
- c.UnreachableTime = defaultUnreachableTime
- }
}
// calcMaxRandomFactor calculates the maximum value of the random factor used
@@ -416,7 +400,7 @@ func (s *NUDState) ReachableTime() time.Duration {
s.config.BaseReachableTime != s.prevBaseReachableTime ||
s.config.MinRandomFactor != s.prevMinRandomFactor ||
s.config.MaxRandomFactor != s.prevMaxRandomFactor {
- return s.recomputeReachableTimeLocked()
+ s.recomputeReachableTimeLocked()
}
return s.reachableTime
}
@@ -442,7 +426,7 @@ func (s *NUDState) ReachableTime() time.Duration {
// random value gets re-computed at least once every few hours.
//
// s.mu MUST be locked for writing.
-func (s *NUDState) recomputeReachableTimeLocked() time.Duration {
+func (s *NUDState) recomputeReachableTimeLocked() {
s.prevBaseReachableTime = s.config.BaseReachableTime
s.prevMinRandomFactor = s.config.MinRandomFactor
s.prevMaxRandomFactor = s.config.MaxRandomFactor
@@ -462,5 +446,4 @@ func (s *NUDState) recomputeReachableTimeLocked() time.Duration {
}
s.expiration = time.Now().Add(2 * time.Hour)
- return s.reachableTime
}
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go
index 8cffb9fc6..7bca1373e 100644
--- a/pkg/tcpip/stack/nud_test.go
+++ b/pkg/tcpip/stack/nud_test.go
@@ -37,7 +37,6 @@ const (
defaultMaxUnicastProbes = 3
defaultMaxAnycastDelayTime = time.Second
defaultMaxReachbilityConfirmations = 3
- defaultUnreachableTime = 5 * time.Second
defaultFakeRandomNum = 0.5
)
@@ -565,58 +564,6 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) {
}
}
-func TestNUDConfigurationsUnreachableTime(t *testing.T) {
- tests := []struct {
- name string
- unreachableTime time.Duration
- want time.Duration
- }{
- // Invalid cases
- {
- name: "EqualToZero",
- unreachableTime: 0,
- want: defaultUnreachableTime,
- },
- // Valid cases
- {
- name: "MoreThanZero",
- unreachableTime: time.Millisecond,
- want: time.Millisecond,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- const nicID = 1
-
- c := stack.DefaultNUDConfigurations()
- c.UnreachableTime = test.unreachableTime
-
- e := channel.New(0, 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- s := stack.New(stack.Options{
- // A neighbor cache is required to store NUDConfigurations. The
- // networking stack will only allocate neighbor caches if a protocol
- // providing link address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NUDConfigs: c,
- UseNeighborCache: true,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- sc, err := s.NUDConfigurations(nicID)
- if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
- }
- if got := sc.UnreachableTime; got != test.want {
- t.Errorf("got UnreachableTime = %q, want = %q", got, test.want)
- }
- })
- }
-}
-
// TestNUDStateReachableTime verifies the correctness of the ReachableTime
// computation.
func TestNUDStateReachableTime(t *testing.T) {
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index 5d364a2b0..4a3adcf33 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -103,7 +103,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro
for _, p := range packets {
if cancelled {
p.route.Stats().IP.OutgoingPacketErrors.Increment()
- } else if _, err := p.route.Resolve(nil); err != nil {
+ } else if p.route.IsResolutionRequired() {
p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else {
p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 43ca03ada..7e83b7fbb 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -17,7 +17,6 @@ package stack
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -259,15 +258,6 @@ const (
PacketLoop
)
-// NetOptions is an interface that allows us to pass network protocol specific
-// options through the Stack layer code.
-type NetOptions interface {
- // SizeWithPadding returns the amount of memory that must be allocated to
- // hold the options given that the value must be rounded up to the next
- // multiple of 4 bytes.
- SizeWithPadding() int
-}
-
// NetworkHeaderParams are the header parameters given as input by the
// transport endpoint to the network.
type NetworkHeaderParams struct {
@@ -279,10 +269,6 @@ type NetworkHeaderParams struct {
// TOS refers to TypeOfService or TrafficClass field of the IP-header.
TOS uint8
-
- // Options is a set of options to add to a network header (or nil).
- // It will be protocol specific opaque information from higher layers.
- Options NetOptions
}
// GroupAddressableEndpoint is an endpoint that supports group addressing.
@@ -291,14 +277,10 @@ type NetworkHeaderParams struct {
// endpoints may associate themselves with the same identifier (group address).
type GroupAddressableEndpoint interface {
// JoinGroup joins the specified group.
- //
- // Returns true if the group was newly joined.
- JoinGroup(group tcpip.Address) (bool, *tcpip.Error)
+ JoinGroup(group tcpip.Address) *tcpip.Error
// LeaveGroup attempts to leave the specified group.
- //
- // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group.
- LeaveGroup(group tcpip.Address) (bool, *tcpip.Error)
+ LeaveGroup(group tcpip.Address) *tcpip.Error
// IsInGroup returns true if the endpoint is a member of the specified group.
IsInGroup(group tcpip.Address) bool
@@ -816,19 +798,26 @@ type LinkAddressCache interface {
// AddLinkAddress adds a link address to the cache.
AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
- // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC).
- // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver
- // registered with the network protocol, the cache attempts to resolve the address
- // and returns ErrWouldBlock. Waker is notified when address resolution is
- // complete (success or not).
+ // GetLinkAddress finds the link address corresponding to the remote address
+ // (e.g. IP -> MAC).
//
- // If address resolution is required, ErrNoLinkAddress and a notification channel is
- // returned for the top level caller to block. Channel is closed once address resolution
- // is complete (success or not).
- GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
-
- // RemoveWaker removes a waker that has been added in GetLinkAddress().
- RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
+ // Returns a link address for the remote address, if readily available.
+ //
+ // Returns ErrWouldBlock if the link address is not readily available, along
+ // with a notification channel for the caller to block on. Triggers address
+ // resolution asynchronously.
+ //
+ // If onResolve is provided, it will be called either immediately, if
+ // resolution is not required, or when address resolution is complete, with
+ // the resolved link address and whether resolution succeeded. After any
+ // callbacks have been called, the returned notification channel is closed.
+ //
+ // If specified, the local address must be an address local to the interface
+ // the neighbor cache belongs to. The local address is the source address of
+ // a packet prompting NUD/link address resolution.
+ //
+ // TODO(gvisor.dev/issue/5151): Don't return the link address.
+ GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
}
// RawFactory produces endpoints for writing various types of raw packets.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index f0b256507..b0251d0b4 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -17,20 +17,53 @@ package stack
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// Route represents a route through the networking stack to a given destination.
+//
+// It is safe to call Route's methods from multiple goroutines.
+//
+// The exported fields are immutable.
+//
+// TODO(gvisor.dev/issue/4902): Unexpose immutable fields.
type Route struct {
+ routeInfo
+
+ // localAddressNIC is the interface the address is associated with.
+ // TODO(gvisor.dev/issue/4548): Remove this field once we can query the
+ // address's assigned status without the NIC.
+ localAddressNIC *NIC
+
+ mu struct {
+ sync.RWMutex
+
+ // localAddressEndpoint is the local address this route is associated with.
+ localAddressEndpoint AssignableAddressEndpoint
+
+ // remoteLinkAddress is the link-layer (MAC) address of the next hop in the
+ // route.
+ remoteLinkAddress tcpip.LinkAddress
+ }
+
+ // outgoingNIC is the interface this route uses to write packets.
+ outgoingNIC *NIC
+
+ // linkCache is set if link address resolution is enabled for this protocol on
+ // the route's NIC.
+ linkCache LinkAddressCache
+
+ // linkRes is set if link address resolution is enabled for this protocol on
+ // the route's NIC.
+ linkRes LinkAddressResolver
+}
+
+type routeInfo struct {
// RemoteAddress is the final destination of the route.
RemoteAddress tcpip.Address
- // RemoteLinkAddress is the link-layer (MAC) address of the
- // final destination of the route.
- RemoteLinkAddress tcpip.LinkAddress
-
// LocalAddress is the local address where the route starts.
LocalAddress tcpip.Address
@@ -46,39 +79,38 @@ type Route struct {
// Loop controls where WritePacket should send packets.
Loop PacketLooping
+}
- // localAddressNIC is the interface the address is associated with.
- // TODO(gvisor.dev/issue/4548): Remove this field once we can query the
- // address's assigned status without the NIC.
- localAddressNIC *NIC
-
- // localAddressEndpoint is the local address this route is associated with.
- localAddressEndpoint AssignableAddressEndpoint
-
- // outgoingNIC is the interface this route uses to write packets.
- outgoingNIC *NIC
+// RouteInfo contains all of Route's exported fields.
+type RouteInfo struct {
+ routeInfo
- // linkCache is set if link address resolution is enabled for this protocol on
- // the route's NIC.
- linkCache LinkAddressCache
+ // RemoteLinkAddress is the link-layer (MAC) address of the next hop in the
+ // route.
+ RemoteLinkAddress tcpip.LinkAddress
+}
- // linkRes is set if link address resolution is enabled for this protocol on
- // the route's NIC.
- linkRes LinkAddressResolver
+// GetFields returns a RouteInfo with all of r's exported fields. This allows
+// callers to store the route's fields without retaining a reference to it.
+func (r *Route) GetFields() RouteInfo {
+ return RouteInfo{
+ routeInfo: r.routeInfo,
+ RemoteLinkAddress: r.RemoteLinkAddress(),
+ }
}
// constructAndValidateRoute validates and initializes a route. It takes
// ownership of the provided local address.
//
// Returns an empty route if validation fails.
-func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route {
+func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route {
if len(localAddr) == 0 {
localAddr = addressEndpoint.AddressWithPrefix().Address
}
if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) {
addressEndpoint.DecRef()
- return Route{}
+ return nil
}
// If no remote address is provided, use the local address.
@@ -102,7 +134,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp
if len(gateway) > 0 {
r.NextHop = gateway
} else if subnet := addressEndpoint.Subnet(); subnet.IsBroadcast(remoteAddr) {
- r.RemoteLinkAddress = header.EthernetBroadcastAddress
+ r.ResolveWith(header.EthernetBroadcastAddress)
}
return r
@@ -110,7 +142,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp
// makeRoute initializes a new route. It takes ownership of the provided
// AssignableAddressEndpoint.
-func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route {
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route {
if localAddressNIC.stack != outgoingNIC.stack {
panic(fmt.Sprintf("cannot create a route with NICs from different stacks"))
}
@@ -139,18 +171,23 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
}
-func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route {
- r := Route{
- NetProto: netProto,
- LocalAddress: localAddr,
- LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
- RemoteAddress: remoteAddr,
- localAddressNIC: localAddressNIC,
- localAddressEndpoint: localAddressEndpoint,
- outgoingNIC: outgoingNIC,
- Loop: loop,
+func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route {
+ r := &Route{
+ routeInfo: routeInfo{
+ NetProto: netProto,
+ LocalAddress: localAddr,
+ LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
+ RemoteAddress: remoteAddr,
+ Loop: loop,
+ },
+ localAddressNIC: localAddressNIC,
+ outgoingNIC: outgoingNIC,
}
+ r.mu.Lock()
+ r.mu.localAddressEndpoint = localAddressEndpoint
+ r.mu.Unlock()
+
if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
@@ -165,7 +202,7 @@ func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr
// provided AssignableAddressEndpoint.
//
// A local route is a route to a destination that is local to the stack.
-func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route {
+func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) *Route {
loop := PacketLoop
// TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
// link endpoint level. We can remove this check once loopback interfaces
@@ -176,6 +213,14 @@ func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr
return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
}
+// RemoteLinkAddress returns the link-layer (MAC) address of the next hop in
+// the route.
+func (r *Route) RemoteLinkAddress() tcpip.LinkAddress {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.mu.remoteLinkAddress
+}
+
// NICID returns the id of the NIC from which this route originates.
func (r *Route) NICID() tcpip.NICID {
return r.outgoingNIC.ID()
@@ -237,22 +282,26 @@ func (r *Route) GSOMaxSize() uint32 {
// ResolveWith immediately resolves a route with the specified remote link
// address.
func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
- r.RemoteLinkAddress = addr
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.mu.remoteLinkAddress = addr
}
-// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
-// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
-// notified when address resolution is complete (success or not).
-//
-// If address resolution is required, ErrNoLinkAddress and a notification channel is
-// returned for the top level caller to block. Channel is closed once address resolution
-// is complete (success or not).
+// Resolve attempts to resolve the link address if necessary.
//
-// The NIC r uses must not be locked.
-func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
- if !r.IsResolutionRequired() {
+// Returns tcpip.ErrWouldBlock if address resolution requires blocking (e.g.
+// waiting for ARP reply). If address resolution is required, a notification
+// channel is also returned for the caller to block on. The channel is closed
+// once address resolution is complete (successful or not). If a callback is
+// provided, it will be called when address resolution is complete, regardless
+// of success or failure.
+func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) {
+ r.mu.Lock()
+
+ if !r.isResolutionRequiredRLocked() {
// Nothing to do if there is no cache (which does the resolution on cache miss) or
// link address is already known.
+ r.mu.Unlock()
return nil, nil
}
@@ -260,7 +309,8 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if nextAddr == "" {
// Local link address is already known.
if r.RemoteAddress == r.LocalAddress {
- r.RemoteLinkAddress = r.LocalLinkAddress
+ r.mu.remoteLinkAddress = r.LocalLinkAddress
+ r.mu.Unlock()
return nil, nil
}
nextAddr = r.RemoteAddress
@@ -273,38 +323,36 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
linkAddressResolutionRequestLocalAddr = r.LocalAddress
}
+ // Increment the route's reference count because finishResolution retains a
+ // reference to the route and releases it when called.
+ r.acquireLocked()
+ r.mu.Unlock()
+
+ finishResolution := func(linkAddress tcpip.LinkAddress, ok bool) {
+ if ok {
+ r.ResolveWith(linkAddress)
+ }
+ if afterResolve != nil {
+ afterResolve()
+ }
+ r.Release()
+ }
+
if neigh := r.outgoingNIC.neigh; neigh != nil {
- entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker)
+ _, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution)
if err != nil {
return ch, err
}
- r.RemoteLinkAddress = entry.LinkAddr
return nil, nil
}
- linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker)
+ _, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, finishResolution)
if err != nil {
return ch, err
}
- r.RemoteLinkAddress = linkAddr
return nil, nil
}
-// RemoveWaker removes a waker that has been added in Resolve().
-func (r *Route) RemoveWaker(waker *sleep.Waker) {
- nextAddr := r.NextHop
- if nextAddr == "" {
- nextAddr = r.RemoteAddress
- }
-
- if neigh := r.outgoingNIC.neigh; neigh != nil {
- neigh.removeWaker(nextAddr, waker)
- return
- }
-
- r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker)
-}
-
// local returns true if the route is a local route.
func (r *Route) local() bool {
return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback()
@@ -315,7 +363,13 @@ func (r *Route) local() bool {
//
// The NICs the route is associated with must not be locked.
func (r *Route) IsResolutionRequired() bool {
- if !r.isValidForOutgoing() || r.RemoteLinkAddress != "" || r.local() {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.isResolutionRequiredRLocked()
+}
+
+func (r *Route) isResolutionRequiredRLocked() bool {
+ if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() {
return false
}
@@ -323,11 +377,18 @@ func (r *Route) IsResolutionRequired() bool {
}
func (r *Route) isValidForOutgoing() bool {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.isValidForOutgoingRLocked()
+}
+
+func (r *Route) isValidForOutgoingRLocked() bool {
if !r.outgoingNIC.Enabled() {
return false
}
- if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) {
+ localAddressEndpoint := r.mu.localAddressEndpoint
+ if localAddressEndpoint == nil || !r.localAddressNIC.isValidForOutgoing(localAddressEndpoint) {
return false
}
@@ -379,22 +440,31 @@ func (r *Route) MTU() uint32 {
return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU()
}
-// Release frees all resources associated with the route.
+// Release decrements the reference counter of the resources associated with the
+// route.
func (r *Route) Release() {
- if r.localAddressEndpoint != nil {
- r.localAddressEndpoint.DecRef()
- r.localAddressEndpoint = nil
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if ep := r.mu.localAddressEndpoint; ep != nil {
+ ep.DecRef()
}
}
-// Clone clones the route.
-func (r *Route) Clone() Route {
- if r.localAddressEndpoint != nil {
- if !r.localAddressEndpoint.IncRef() {
+// Acquire increments the reference counter of the resources associated with the
+// route.
+func (r *Route) Acquire() {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ r.acquireLocked()
+}
+
+func (r *Route) acquireLocked() {
+ if ep := r.mu.localAddressEndpoint; ep != nil {
+ if !ep.IncRef() {
panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress))
}
}
- return *r
}
// Stack returns the instance of the Stack that owns this route.
@@ -407,7 +477,14 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
return true
}
- subnet := r.localAddressEndpoint.Subnet()
+ r.mu.RLock()
+ localAddressEndpoint := r.mu.localAddressEndpoint
+ r.mu.RUnlock()
+ if localAddressEndpoint == nil {
+ return false
+ }
+
+ subnet := localAddressEndpoint.Subnet()
return subnet.IsBroadcast(addr)
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index a2d234e7d..114643b03 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -29,7 +29,6 @@ import (
"golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/rand"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -171,6 +170,9 @@ type TCPSenderState struct {
// Outstanding is the number of packets in flight.
Outstanding int
+ // SackedOut is the number of packets which have been selectively acked.
+ SackedOut int
+
// SndWnd is the send window size in bytes.
SndWnd seqnum.Size
@@ -1218,10 +1220,10 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP
// from the specified NIC.
//
// Precondition: s.mu must be read locked.
-func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route {
localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint)
if localAddressEndpoint == nil {
- return Route{}, false
+ return nil
}
var outgoingNIC *NIC
@@ -1245,7 +1247,7 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
// route.
if outgoingNIC == nil {
localAddressEndpoint.DecRef()
- return Route{}, false
+ return nil
}
r := makeLocalRoute(
@@ -1259,10 +1261,10 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
if r.IsOutboundBroadcast() {
r.Release()
- return Route{}, false
+ return nil
}
- return r, true
+ return r
}
// findLocalRouteRLocked returns a local route.
@@ -1271,26 +1273,26 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
// is, a local route is a route where packets never have to leave the stack.
//
// Precondition: s.mu must be read locked.
-func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route {
if len(localAddr) == 0 {
localAddr = remoteAddr
}
if localAddressNICID == 0 {
for _, localAddressNIC := range s.nics {
- if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok {
- return r, true
+ if r := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); r != nil {
+ return r
}
}
- return Route{}, false
+ return nil
}
if localAddressNIC, ok := s.nics[localAddressNICID]; ok {
return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto)
}
- return Route{}, false
+ return nil
}
// FindRoute creates a route to the given destination address, leaving through
@@ -1304,7 +1306,7 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr,
// If no local address is provided, the stack will select a local address. If no
// remote address is provided, the stack will use a remote address equal to the
// local address.
-func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) {
+func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1315,7 +1317,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback)
if s.handleLocal && !isMulticast && !isLocalBroadcast {
- if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok {
+ if r := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); r != nil {
return r, nil
}
}
@@ -1339,9 +1341,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
if isLoopback {
- return Route{}, tcpip.ErrBadLocalAddress
+ return nil, tcpip.ErrBadLocalAddress
}
- return Route{}, tcpip.ErrNetworkUnreachable
+ return nil, tcpip.ErrNetworkUnreachable
}
canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal
@@ -1365,7 +1367,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
gateway = route.Gateway
}
r := constructAndValidateRoute(netProto, addressEndpoint, nic /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop)
- if r == (Route{}) {
+ if r == nil {
panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr))
}
return r, nil
@@ -1401,13 +1403,13 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if id != 0 {
if aNIC, ok := s.nics[id]; ok {
if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil {
return r, nil
}
}
}
- return Route{}, tcpip.ErrNoRoute
+ return nil, tcpip.ErrNoRoute
}
if id == 0 {
@@ -1419,7 +1421,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
continue
}
- if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil {
return r, nil
}
}
@@ -1427,12 +1429,12 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
if needRoute {
- return Route{}, tcpip.ErrNoRoute
+ return nil, tcpip.ErrNoRoute
}
if header.IsV6LoopbackAddress(remoteAddr) {
- return Route{}, tcpip.ErrBadLocalAddress
+ return nil, tcpip.ErrBadLocalAddress
}
- return Route{}, tcpip.ErrNetworkUnreachable
+ return nil, tcpip.ErrNetworkUnreachable
}
// CheckNetworkProtocol checks if a given network protocol is enabled in the
@@ -1517,7 +1519,7 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr t
}
// GetLinkAddress implements LinkAddressCache.GetLinkAddress.
-func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
s.mu.RLock()
nic := s.nics[nicID]
if nic == nil {
@@ -1528,7 +1530,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
linkRes := s.linkAddrResolvers[protocol]
- return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker)
+ return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, onResolve)
}
// Neighbors returns all IP to MAC address associations.
@@ -1544,29 +1546,6 @@ func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) {
return nic.neighbors()
}
-// RemoveWaker removes a waker that has been added when link resolution for
-// addr was requested.
-func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
- if s.useNeighborCache {
- s.mu.RLock()
- nic, ok := s.nics[nicID]
- s.mu.RUnlock()
-
- if ok {
- nic.removeWaker(addr, waker)
- }
- return
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- if nic := s.nics[nicID]; nic == nil {
- fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
- s.linkAddrCache.removeWaker(fullAddr, waker)
- }
-}
-
// AddStaticNeighbor statically associates an IP address to a MAC address.
func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error {
s.mu.RLock()
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 9d2d0aa84..856ebf6d4 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -27,7 +27,6 @@ import (
"time"
"github.com/google/go-cmp/cmp"
- "github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -407,7 +406,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
return send(r, payload)
}
-func send(r stack.Route, payload buffer.View) *tcpip.Error {
+func send(r *stack.Route, payload buffer.View) *tcpip.Error {
return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: payload.ToVectorisedView(),
@@ -425,7 +424,7 @@ func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.En
}
}
-func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) {
+func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View) {
t.Helper()
ep.Drain()
if err := send(r, payload); err != nil {
@@ -436,7 +435,7 @@ func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.
}
}
-func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
t.Helper()
if gotErr := send(r, payload); gotErr != wantErr {
t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
@@ -1563,15 +1562,15 @@ func TestSpoofingNoAddress(t *testing.T) {
// testSendTo(t, s, remoteAddr, ep, nil)
}
-func verifyRoute(gotRoute, wantRoute stack.Route) error {
+func verifyRoute(gotRoute, wantRoute *stack.Route) error {
if gotRoute.LocalAddress != wantRoute.LocalAddress {
return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress)
}
if gotRoute.RemoteAddress != wantRoute.RemoteAddress {
return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress)
}
- if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress {
- return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress)
+ if got, want := gotRoute.RemoteLinkAddress(), wantRoute.RemoteLinkAddress(); got != want {
+ return fmt.Errorf("bad remote link address: got %s, want = %s", got, want)
}
if gotRoute.NextHop != wantRoute.NextHop {
return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop)
@@ -1603,7 +1602,10 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ var wantRoute stack.Route
+ wantRoute.LocalAddress = header.IPv4Any
+ wantRoute.RemoteAddress = header.IPv4Broadcast
+ if err := verifyRoute(r, &wantRoute); err != nil {
t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1657,7 +1659,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ var wantRoute stack.Route
+ wantRoute.LocalAddress = nic1Addr.Address
+ wantRoute.RemoteAddress = header.IPv4Broadcast
+ if err := verifyRoute(r, &wantRoute); err != nil {
t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1667,7 +1672,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ wantRoute = stack.Route{}
+ wantRoute.LocalAddress = nic2Addr.Address
+ wantRoute.RemoteAddress = header.IPv4Broadcast
+ if err := verifyRoute(r, &wantRoute); err != nil {
t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1683,7 +1691,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ wantRoute = stack.Route{}
+ wantRoute.LocalAddress = nic1Addr.Address
+ wantRoute.RemoteAddress = header.IPv4Broadcast
+ if err := verifyRoute(r, &wantRoute); err != nil {
t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
}
}
@@ -2727,8 +2738,16 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- nicID = 1
- lifetimeSeconds = 9999
+ globalAddr3 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ ipv4MappedIPv6Addr1 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01")
+ ipv4MappedIPv6Addr2 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x02")
+ toredoAddr1 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ toredoAddr2 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ ipv6ToIPv4Addr1 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ ipv6ToIPv4Addr2 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+
+ nicID = 1
+ lifetimeSeconds = 9999
)
prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1)
@@ -2745,139 +2764,191 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
slaacPrefixForTempAddrBeforeNICAddrAdd tcpip.AddressWithPrefix
nicAddrs []tcpip.Address
slaacPrefixForTempAddrAfterNICAddrAdd tcpip.AddressWithPrefix
- connectAddr tcpip.Address
+ remoteAddr tcpip.Address
expectedLocalAddr tcpip.Address
}{
- // Test Rule 1 of RFC 6724 section 5.
+ // Test Rule 1 of RFC 6724 section 5 (prefer same address).
{
name: "Same Global most preferred (last address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: globalAddr1,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: globalAddr1,
expectedLocalAddr: globalAddr1,
},
{
name: "Same Global most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: globalAddr1,
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1},
+ remoteAddr: globalAddr1,
expectedLocalAddr: globalAddr1,
},
{
name: "Same Link Local most preferred (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
- connectAddr: linkLocalAddr1,
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
+ remoteAddr: linkLocalAddr1,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Same Link Local most preferred (first address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: linkLocalAddr1,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: linkLocalAddr1,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Same Unique Local most preferred (last address)",
- nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
- connectAddr: uniqueLocalAddr1,
+ nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1},
+ remoteAddr: uniqueLocalAddr1,
expectedLocalAddr: uniqueLocalAddr1,
},
{
name: "Same Unique Local most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: uniqueLocalAddr1,
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1},
+ remoteAddr: uniqueLocalAddr1,
expectedLocalAddr: uniqueLocalAddr1,
},
- // Test Rule 2 of RFC 6724 section 5.
+ // Test Rule 2 of RFC 6724 section 5 (prefer appropriate scope).
{
name: "Global most preferred (last address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: globalAddr2,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: globalAddr2,
expectedLocalAddr: globalAddr1,
},
{
name: "Global most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: globalAddr2,
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
+ remoteAddr: globalAddr2,
expectedLocalAddr: globalAddr1,
},
{
name: "Link Local most preferred (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
- connectAddr: linkLocalAddr2,
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
+ remoteAddr: linkLocalAddr2,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Link Local most preferred (first address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: linkLocalAddr2,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: linkLocalAddr2,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Link Local most preferred for link local multicast (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
- connectAddr: linkLocalMulticastAddr,
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
+ remoteAddr: linkLocalMulticastAddr,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Link Local most preferred for link local multicast (first address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: linkLocalMulticastAddr,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: linkLocalMulticastAddr,
expectedLocalAddr: linkLocalAddr1,
},
+
+ // Test Rule 6 of RFC 6724 section 5 (prefer matching label).
{
name: "Unique Local most preferred (last address)",
- nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
- connectAddr: uniqueLocalAddr2,
+ nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1},
+ remoteAddr: uniqueLocalAddr2,
expectedLocalAddr: uniqueLocalAddr1,
},
{
name: "Unique Local most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: uniqueLocalAddr2,
+ nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1},
+ remoteAddr: uniqueLocalAddr2,
expectedLocalAddr: uniqueLocalAddr1,
},
+ {
+ name: "Toredo most preferred (first address)",
+ nicAddrs: []tcpip.Address{toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1},
+ remoteAddr: toredoAddr2,
+ expectedLocalAddr: toredoAddr1,
+ },
+ {
+ name: "Toredo most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1},
+ remoteAddr: toredoAddr2,
+ expectedLocalAddr: toredoAddr1,
+ },
+ {
+ name: "6To4 most preferred (first address)",
+ nicAddrs: []tcpip.Address{ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1},
+ remoteAddr: ipv6ToIPv4Addr2,
+ expectedLocalAddr: ipv6ToIPv4Addr1,
+ },
+ {
+ name: "6To4 most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, uniqueLocalAddr1, toredoAddr1, ipv6ToIPv4Addr1},
+ remoteAddr: ipv6ToIPv4Addr2,
+ expectedLocalAddr: ipv6ToIPv4Addr1,
+ },
+ {
+ name: "IPv4 mapped IPv6 most preferred (first address)",
+ nicAddrs: []tcpip.Address{ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1},
+ remoteAddr: ipv4MappedIPv6Addr2,
+ expectedLocalAddr: ipv4MappedIPv6Addr1,
+ },
+ {
+ name: "IPv4 mapped IPv6 most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1, ipv4MappedIPv6Addr1},
+ remoteAddr: ipv4MappedIPv6Addr2,
+ expectedLocalAddr: ipv4MappedIPv6Addr1,
+ },
- // Test Rule 7 of RFC 6724 section 5.
+ // Test Rule 7 of RFC 6724 section 5 (prefer temporary addresses).
{
name: "Temp Global most preferred (last address)",
slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1,
nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: globalAddr2,
+ remoteAddr: globalAddr2,
expectedLocalAddr: tempGlobalAddr1,
},
{
name: "Temp Global most preferred (first address)",
nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
slaacPrefixForTempAddrAfterNICAddrAdd: prefix1,
- connectAddr: globalAddr2,
+ remoteAddr: globalAddr2,
expectedLocalAddr: tempGlobalAddr1,
},
+ // Test Rule 8 of RFC 6724 section 5 (use longest matching prefix).
+ {
+ name: "Longest prefix matched most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr2, globalAddr1},
+ remoteAddr: globalAddr3,
+ expectedLocalAddr: globalAddr2,
+ },
+ {
+ name: "Longest prefix matched most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, globalAddr2},
+ remoteAddr: globalAddr3,
+ expectedLocalAddr: globalAddr2,
+ },
+
// Test returning the endpoint that is closest to the front when
// candidate addresses are "equal" from the perspective of RFC 6724
// section 5.
{
name: "Unique Local for Global",
nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2},
- connectAddr: globalAddr2,
+ remoteAddr: globalAddr2,
expectedLocalAddr: uniqueLocalAddr1,
},
{
name: "Link Local for Global",
nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
- connectAddr: globalAddr2,
+ remoteAddr: globalAddr2,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Link Local for Unique Local",
nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
- connectAddr: uniqueLocalAddr2,
+ remoteAddr: uniqueLocalAddr2,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Temp Global for Global",
slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1,
slaacPrefixForTempAddrAfterNICAddrAdd: prefix2,
- connectAddr: globalAddr1,
+ remoteAddr: globalAddr1,
expectedLocalAddr: tempGlobalAddr2,
},
}
@@ -2899,12 +2970,6 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv6EmptySubnet,
- Gateway: llAddr3,
- NIC: nicID,
- }})
- s.AddLinkAddress(nicID, llAddr3, linkAddr3)
if test.slaacPrefixForTempAddrBeforeNICAddrAdd != (tcpip.AddressWithPrefix{}) {
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrBeforeNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds))
@@ -2924,7 +2989,23 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
t.FailNow()
}
- if got := addrForNewConnectionTo(t, s, tcpip.FullAddress{Addr: test.connectAddr, NIC: nicID, Port: 1234}); got != test.expectedLocalAddr {
+ netEP, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ }
+
+ addressableEndpoint, ok := netEP.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatal("network endpoint is not addressable")
+ }
+
+ addressEP := addressableEndpoint.AcquireOutgoingPrimaryAddress(test.remoteAddr, false /* allowExpired */)
+ if addressEP == nil {
+ t.Fatal("expected a non-nil address endpoint")
+ }
+ defer addressEP.DecRef()
+
+ if got := addressEP.AddressWithPrefix().Address; got != test.expectedLocalAddr {
t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr)
}
})
@@ -3351,11 +3432,16 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
remNetSubnetBcast := remNetSubnet.Broadcast()
tests := []struct {
- name string
- nicAddr tcpip.ProtocolAddress
- routes []tcpip.Route
- remoteAddr tcpip.Address
- expectedRoute stack.Route
+ name string
+ nicAddr tcpip.ProtocolAddress
+ routes []tcpip.Route
+ remoteAddr tcpip.Address
+ expectedLocalAddress tcpip.Address
+ expectedRemoteAddress tcpip.Address
+ expectedRemoteLinkAddress tcpip.LinkAddress
+ expectedNextHop tcpip.Address
+ expectedNetProto tcpip.NetworkProtocolNumber
+ expectedLoop stack.PacketLooping
}{
// Broadcast to a locally attached subnet populates the broadcast MAC.
{
@@ -3370,14 +3456,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4SubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: ipv4SubnetBcast,
- RemoteLinkAddress: header.EthernetBroadcastAddress,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut | stack.PacketLoop,
- },
+ remoteAddr: ipv4SubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: ipv4SubnetBcast,
+ expectedRemoteLinkAddress: header.EthernetBroadcastAddress,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut | stack.PacketLoop,
},
// Broadcast to a locally attached /31 subnet does not populate the
// broadcast MAC.
@@ -3393,13 +3477,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4Subnet31Bcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4AddrPrefix31.Address,
- RemoteAddress: ipv4Subnet31Bcast,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv4Subnet31Bcast,
+ expectedLocalAddress: ipv4AddrPrefix31.Address,
+ expectedRemoteAddress: ipv4Subnet31Bcast,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to a locally attached /32 subnet does not populate the
// broadcast MAC.
@@ -3415,13 +3497,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4Subnet32Bcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4AddrPrefix32.Address,
- RemoteAddress: ipv4Subnet32Bcast,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv4Subnet32Bcast,
+ expectedLocalAddress: ipv4AddrPrefix32.Address,
+ expectedRemoteAddress: ipv4Subnet32Bcast,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// IPv6 has no notion of a broadcast.
{
@@ -3436,13 +3516,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv6SubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv6Addr.Address,
- RemoteAddress: ipv6SubnetBcast,
- NetProto: header.IPv6ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv6SubnetBcast,
+ expectedLocalAddress: ipv6Addr.Address,
+ expectedRemoteAddress: ipv6SubnetBcast,
+ expectedNetProto: header.IPv6ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to a remote subnet in the route table is send to the next-hop
// gateway.
@@ -3459,14 +3537,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: remNetSubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: remNetSubnetBcast,
- NextHop: ipv4Gateway,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: remNetSubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: remNetSubnetBcast,
+ expectedNextHop: ipv4Gateway,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to an unknown subnet follows the default route. Note that this
// is essentially just routing an unknown destination IP, because w/o any
@@ -3484,14 +3560,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: remNetSubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: remNetSubnetBcast,
- NextHop: ipv4Gateway,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: remNetSubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: remNetSubnetBcast,
+ expectedNextHop: ipv4Gateway,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
}
@@ -3520,10 +3594,27 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
t.Fatalf("got unexpected address length = %d bytes", l)
}
- if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil {
+ r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */)
+ if err != nil {
t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err)
- } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" {
- t.Errorf("route mismatch (-want +got):\n%s", diff)
+ }
+ if r.LocalAddress != test.expectedLocalAddress {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.expectedLocalAddress)
+ }
+ if r.RemoteAddress != test.expectedRemoteAddress {
+ t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.expectedRemoteAddress)
+ }
+ if got := r.RemoteLinkAddress(); got != test.expectedRemoteLinkAddress {
+ t.Errorf("got r.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddress)
+ }
+ if r.NextHop != test.expectedNextHop {
+ t.Errorf("got r.NextHop = %s, want = %s", r.NextHop, test.expectedNextHop)
+ }
+ if r.NetProto != test.expectedNetProto {
+ t.Errorf("got r.NetProto = %d, want = %d", r.NetProto, test.expectedNetProto)
+ }
+ if r.Loop != test.expectedLoop {
+ t.Errorf("got r.Loop = %x, want = %x", r.Loop, test.expectedLoop)
}
})
}
@@ -4091,10 +4182,12 @@ func TestFindRouteWithForwarding(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
+ if r != nil {
+ defer r.Release()
+ }
if err != test.findRouteErr {
t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr)
}
- defer r.Release()
if test.findRouteErr != nil {
return
@@ -4193,8 +4286,8 @@ func TestWritePacketToRemote(t *testing.T) {
if got, want := pkt.Proto, test.protocol; got != want {
t.Fatalf("pkt.Proto = %d, want %d", got, want)
}
- if got, want := pkt.Route.RemoteLinkAddress, linkAddr2; got != want {
- t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", got, want)
+ if pkt.Route.RemoteLinkAddress != linkAddr2 {
+ t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2)
}
if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" {
t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff)
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 2cdb5ca79..859278f0b 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -15,6 +15,7 @@
package stack_test
import (
+ "io/ioutil"
"math"
"math/rand"
"testing"
@@ -141,11 +142,11 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: testSrcAddrV6,
- DstAddr: testDstAddrV6,
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: testSrcAddrV6,
+ DstAddr: testDstAddrV6,
})
// Initialize the UDP header.
@@ -308,9 +309,8 @@ func TestBindToDeviceDistribution(t *testing.T) {
defer ep.Close()
ep.SocketOptions().SetReusePort(endpoint.reuse)
- bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
- if err := ep.SetSockOpt(&bindToDeviceOption); err != nil {
- t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err)
+ if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil {
+ t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err)
}
var dstAddr tcpip.Address
@@ -352,7 +352,7 @@ func TestBindToDeviceDistribution(t *testing.T) {
}
ep := <-pollChannel
- if _, _, err := ep.Read(nil); err != nil {
+ if _, err := ep.Read(ioutil.Discard, math.MaxUint16, tcpip.ReadOptions{}); err != nil {
t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err)
}
stats[ep]++
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index fbac66993..a5facf578 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -15,6 +15,7 @@
package stack_test
import (
+ "io"
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -42,7 +43,7 @@ type fakeTransportEndpoint struct {
proto *fakeTransportProtocol
peerAddr tcpip.Address
- route stack.Route
+ route *stack.Route
uniqueID uint64
// acceptQueue is non-nil iff bound.
@@ -65,6 +66,7 @@ func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions {
return &f.ops
}
+
func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
ep.ops.InitHandler(ep)
@@ -76,6 +78,7 @@ func (f *fakeTransportEndpoint) Abort() {
}
func (f *fakeTransportEndpoint) Close() {
+ // TODO(gvisor.dev/issue/5153): Consider retaining the route.
f.route.Release()
}
@@ -83,8 +86,8 @@ func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask
return mask
}
-func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- return buffer.View{}, tcpip.ControlMessages{}, nil
+func (*fakeTransportEndpoint) Read(io.Writer, int, tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+ return tcpip.ReadResult{}, nil
}
func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
@@ -108,30 +111,16 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
return int64(len(v)), nil, nil
}
-func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
-}
-
// SetSockOpt sets a socket option. Currently not supported.
func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
-// SetSockOptBool sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
-}
-
// SetSockOptInt sets a socket option. Currently not supported.
func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrUnknownProtocolOption
-}
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
@@ -155,16 +144,16 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return tcpip.ErrNoRoute
}
- defer r.Release()
// Try to register so that we can start receiving packets.
f.ID.RemoteAddress = addr.Addr
err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
if err != nil {
+ r.Release()
return err
}
- f.route = r.Clone()
+ f.route = r
return nil
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 40f6e8aa9..f798056c0 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -31,6 +31,7 @@ package tcpip
import (
"errors"
"fmt"
+ "io"
"math/bits"
"reflect"
"strconv"
@@ -39,7 +40,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -49,8 +49,9 @@ const ipv4AddressSize = 4
// Error represents an error in the netstack error space. Using a special type
// ensures that errors outside of this space are not accidentally introduced.
//
-// Note: to support save / restore, it is important that all tcpip errors have
-// distinct error messages.
+// All errors must have unique msg strings.
+//
+// +stateify savable
type Error struct {
msg string
@@ -112,6 +113,7 @@ var (
ErrNotPermitted = &Error{msg: "operation not permitted"}
ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"}
ErrMalformedHeader = &Error{msg: "header is malformed"}
+ ErrBadBuffer = &Error{msg: "bad buffer"}
)
var messageToError map[string]*Error
@@ -161,6 +163,7 @@ func StringToError(s string) *Error {
ErrNotPermitted,
ErrAddressFamilyNotSupported,
ErrMalformedHeader,
+ ErrBadBuffer,
}
messageToError = make(map[string]*Error)
@@ -257,6 +260,44 @@ func (a Address) Unspecified() bool {
return true
}
+// MatchingPrefix returns the number of leading bits that a and b share.
+//
+// Panics if a and b have different lengths.
+func (a Address) MatchingPrefix(b Address) uint8 {
+ const bitsInAByte = 8
+
+ if len(a) != len(b) {
+ panic(fmt.Sprintf("addresses %s and %s do not have the same length", a, b))
+ }
+
+ var prefix uint8
+ for i := range a {
+ aByte := a[i]
+ bByte := b[i]
+
+ if aByte == bByte {
+ prefix += bitsInAByte
+ continue
+ }
+
+ // Count the matching bits of the first differing byte from MSbit to LSbit; the bytes differ, so a mismatching bit is always found.
+ mask := uint8(1) << (bitsInAByte - 1)
+ for {
+ if aByte&mask == bByte&mask {
+ prefix++
+ mask >>= 1
+ continue
+ }
+
+ break
+ }
+
+ break
+ }
+
+ return prefix
+}
+
// AddressMask is a bitmask for an address.
type AddressMask string
@@ -457,6 +498,21 @@ func (s SlicePayload) Payload(size int) ([]byte, *Error) {
return s[:size], nil
}
+var _ io.Writer = (*SliceWriter)(nil)
+
+// SliceWriter implements io.Writer for slices; each Write consumes the front of the slice.
+type SliceWriter []byte
+
+// Write implements io.Writer.Write. It returns io.ErrShortWrite if b does not fit in s.
+func (s *SliceWriter) Write(b []byte) (int, error) {
+ n := copy(*s, b)
+ *s = (*s)[n:]
+ if n < len(b) {
+ return n, io.ErrShortWrite
+ }
+ return n, nil
+}
+
// A ControlMessages contains socket control messages for IP sockets.
//
// +stateify savable
@@ -491,6 +547,17 @@ type ControlMessages struct {
// PacketInfo holds interface and address data on an incoming packet.
PacketInfo IPPacketInfo
+
+ // HasOriginalDstAddress indicates whether OriginalDstAddress is
+ // set.
+ HasOriginalDstAddress bool
+
+ // OriginalDstAddress holds the original destination address
+ // and port of the incoming packet.
+ OriginalDstAddress FullAddress
+
+ // SockErr is the dequeued socket error on recvmsg(MSG_ERRQUEUE).
+ SockErr *SockError
}
// PacketOwner is used to get UID and GID of the packet.
@@ -502,6 +569,40 @@ type PacketOwner interface {
GID() uint32
}
+// ReadOptions contains options for Endpoint.Read.
+type ReadOptions struct {
+ // Peek indicates whether this read is a peek.
+ Peek bool
+
+ // NeedRemoteAddr indicates whether to return the remote address in
+ // ReadResult.RemoteAddr, if supported.
+ NeedRemoteAddr bool
+
+ // NeedLinkPacketInfo indicates whether to return the link-layer information
+ // in ReadResult.LinkPacketInfo, if supported.
+ NeedLinkPacketInfo bool
+}
+
+// ReadResult represents the result of a successful Endpoint.Read.
+type ReadResult struct {
+ // Count is the number of bytes received and written to the buffer.
+ Count int
+
+ // Total is the number of bytes of the received packet. This can be used to
+ // determine whether the read is truncated.
+ Total int
+
+ // ControlMessages holds the control messages received.
+ ControlMessages ControlMessages
+
+ // RemoteAddr is the remote address if ReadOptions.NeedRemoteAddr is true.
+ RemoteAddr FullAddress
+
+ // LinkPacketInfo is the link-layer information of the received packet if
+ // ReadOptions.NeedLinkPacketInfo is true.
+ LinkPacketInfo LinkPacketInfo
+}
+
// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
// that exposes functionality like read, write, connect, etc. to users of the
// networking stack.
@@ -516,11 +617,15 @@ type Endpoint interface {
// Abort is best effort; implementing Abort with Close is acceptable.
Abort()
- // Read reads data from the endpoint and optionally returns the sender.
+ // Read reads data from the endpoint and optionally writes to dst.
//
- // This method does not block if there is no data pending. It will also
- // either return an error or data, never both.
- Read(*FullAddress) (buffer.View, ControlMessages, *Error)
+ // This method does not block if there is no data pending; in this case,
+ // ErrWouldBlock is returned.
+ //
+ // If non-zero number of bytes are successfully read and written to dst, err
+ // must be nil. Otherwise, if dst failed to write anything, ErrBadBuffer
+ // should be returned.
+ Read(dst io.Writer, count int, opts ReadOptions) (res ReadResult, err *Error)
// Write writes data to the endpoint's peer. This method does not block if
// the data cannot be written.
@@ -542,11 +647,6 @@ type Endpoint interface {
// not). The channel is only non-nil in this case.
Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error)
- // Peek reads data without consuming it from the endpoint.
- //
- // This method does not block if there is no data pending.
- Peek([][]byte) (int64, ControlMessages, *Error)
-
// Connect connects the endpoint to its peer. Specifying a NIC is
// optional.
//
@@ -603,10 +703,6 @@ type Endpoint interface {
// SetSockOpt sets a socket option.
SetSockOpt(opt SettableSocketOption) *Error
- // SetSockOptBool sets a socket option, for simple cases where a value
- // has the bool type.
- SetSockOptBool(opt SockOptBool, v bool) *Error
-
// SetSockOptInt sets a socket option, for simple cases where a value
// has the int type.
SetSockOptInt(opt SockOptInt, v int) *Error
@@ -614,10 +710,6 @@ type Endpoint interface {
// GetSockOpt gets a socket option.
GetSockOpt(opt GettableSocketOption) *Error
- // GetSockOptBool gets a socket option for simple cases where a return
- // value has the bool type.
- GetSockOptBool(SockOptBool) (bool, *Error)
-
// GetSockOptInt gets a socket option for simple cases where a return
// value has the int type.
GetSockOptInt(SockOptInt) (int, *Error)
@@ -661,17 +753,6 @@ type LinkPacketInfo struct {
PktType PacketType
}
-// PacketEndpoint are additional methods that are only implemented by Packet
-// endpoints.
-type PacketEndpoint interface {
- // ReadPacket reads a datagram/packet from the endpoint and optionally
- // returns the sender and additional LinkPacketInfo.
- //
- // This method does not block if there is no data pending. It will also
- // either return an error or data, never both.
- ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error)
-}
-
// EndpointInfo is the interface implemented by each endpoint info struct.
type EndpointInfo interface {
// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
@@ -704,53 +785,6 @@ type WriteOptions struct {
Atomic bool
}
-// SockOptBool represents socket options which values have the bool type.
-type SockOptBool int
-
-const (
- // CorkOption is used by SetSockOptBool/GetSockOptBool to specify if
- // data should be held until segments are full by the TCP transport
- // protocol.
- CorkOption SockOptBool = iota
-
- // DelayOption is used by SetSockOptBool/GetSockOptBool to specify if
- // data should be sent out immediately by the transport protocol. For
- // TCP, it determines if the Nagle algorithm is on or off.
- DelayOption
-
- // MulticastLoopOption is used by SetSockOptBool/GetSockOptBool to
- // specify whether multicast packets sent over a non-loopback interface
- // will be looped back.
- MulticastLoopOption
-
- // QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool.
- QuickAckOption
-
- // ReceiveTClassOption is used by SetSockOptBool/GetSockOptBool to
- // specify if the IPV6_TCLASS ancillary message is passed with incoming
- // packets.
- ReceiveTClassOption
-
- // ReceiveTOSOption is used by SetSockOptBool/GetSockOptBool to specify
- // if the TOS ancillary message is passed with incoming packets.
- ReceiveTOSOption
-
- // ReceiveIPPacketInfoOption is used by SetSockOptBool/GetSockOptBool to
- // specify if more inforamtion is provided with incoming packets such as
- // interface index and address.
- ReceiveIPPacketInfoOption
-
- // V6OnlyOption is used by SetSockOptBool/GetSockOptBool to specify
- // whether an IPv6 socket is to be restricted to sending and receiving
- // IPv6 packets only.
- V6OnlyOption
-
- // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw
- // endpoint that all packets being written have an IP header and the
- // endpoint should not attach an IP header.
- IPHdrIncludedOption
-)
-
// SockOptInt represents socket options which values have the int type.
type SockOptInt int
@@ -960,14 +994,6 @@ type SettableSocketOption interface {
isSettableSocketOption()
}
-// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
-// should bind only on a specific NIC.
-type BindToDeviceOption NICID
-
-func (*BindToDeviceOption) isGettableSocketOption() {}
-
-func (*BindToDeviceOption) isSettableSocketOption() {}
-
// TCPInfoOption is used by GetSockOpt to expose TCP statistics.
//
// TODO(b/64800844): Add and populate stat fields.
@@ -1142,14 +1168,6 @@ type RemoveMembershipOption MembershipOption
func (*RemoveMembershipOption) isSettableSocketOption() {}
-// OutOfBandInlineOption is used by SetSockOpt/GetSockOpt to specify whether
-// TCP out-of-band data is delivered along with the normal in-band data.
-type OutOfBandInlineOption int
-
-func (*OutOfBandInlineOption) isGettableSocketOption() {}
-
-func (*OutOfBandInlineOption) isSettableSocketOption() {}
-
// SocketDetachFilterOption is used by SetSockOpt to detach a previously attached
// classic BPF filter on a given endpoint.
type SocketDetachFilterOption int
@@ -1199,10 +1217,6 @@ type LingerOption struct {
Timeout time.Duration
}
-func (*LingerOption) isGettableSocketOption() {}
-
-func (*LingerOption) isSettableSocketOption() {}
-
// IPPacketInfo is the message structure for IP_PKTINFO.
//
// +stateify savable
@@ -1373,6 +1387,18 @@ type ICMPv6PacketStats struct {
// RedirectMsg is the total number of ICMPv6 redirect message packets
// counted.
RedirectMsg *StatCounter
+
+ // MulticastListenerQuery is the total number of Multicast Listener Query
+ // messages counted.
+ MulticastListenerQuery *StatCounter
+
+ // MulticastListenerReport is the total number of Multicast Listener Report
+ // messages counted.
+ MulticastListenerReport *StatCounter
+
+ // MulticastListenerDone is the total number of Multicast Listener Done
+ // messages counted.
+ MulticastListenerDone *StatCounter
}
// ICMPv4SentPacketStats collects outbound ICMPv4-specific stats.
@@ -1414,6 +1440,10 @@ type ICMPv6SentPacketStats struct {
type ICMPv6ReceivedPacketStats struct {
ICMPv6PacketStats
+ // Unrecognized is the total number of ICMPv6 packets received that the
+ // transport layer does not know how to parse.
+ Unrecognized *StatCounter
+
// Invalid is the total number of ICMPv6 packets received that the
// transport layer could not parse.
Invalid *StatCounter
@@ -1423,25 +1453,37 @@ type ICMPv6ReceivedPacketStats struct {
RouterOnlyPacketsDroppedByHost *StatCounter
}
-// ICMPStats collects ICMP-specific stats (both v4 and v6).
-type ICMPStats struct {
+// ICMPv4Stats collects ICMPv4-specific stats.
+type ICMPv4Stats struct {
// ICMPv4SentPacketStats contains counts of sent packets by ICMPv4 packet type
// and a single count of packets which failed to write to the link
// layer.
- V4PacketsSent ICMPv4SentPacketStats
+ PacketsSent ICMPv4SentPacketStats
// ICMPv4ReceivedPacketStats contains counts of received packets by ICMPv4
// packet type and a single count of invalid packets received.
- V4PacketsReceived ICMPv4ReceivedPacketStats
+ PacketsReceived ICMPv4ReceivedPacketStats
+}
+// ICMPv6Stats collects ICMPv6-specific stats.
+type ICMPv6Stats struct {
// ICMPv6SentPacketStats contains counts of sent packets by ICMPv6 packet type
// and a single count of packets which failed to write to the link
// layer.
- V6PacketsSent ICMPv6SentPacketStats
+ PacketsSent ICMPv6SentPacketStats
// ICMPv6ReceivedPacketStats contains counts of received packets by ICMPv6
// packet type and a single count of invalid packets received.
- V6PacketsReceived ICMPv6ReceivedPacketStats
+ PacketsReceived ICMPv6ReceivedPacketStats
+}
+
+// ICMPStats collects ICMP-specific stats (both v4 and v6).
+type ICMPStats struct {
+ // V4 contains the ICMPv4-specific stats.
+ V4 ICMPv4Stats
+
+ // V6 contains the ICMPv6-specific stats.
+ V6 ICMPv6Stats
}
// IGMPPacketStats enumerates counts for all IGMP packet types.
@@ -1465,8 +1507,7 @@ type IGMPPacketStats struct {
type IGMPSentPacketStats struct {
IGMPPacketStats
- // Dropped is the total number of IGMP packets dropped due to link layer
- // errors.
+ // Dropped is the total number of IGMP packets dropped.
Dropped *StatCounter
}
diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go
index c461da137..9bd563c46 100644
--- a/pkg/tcpip/tcpip_test.go
+++ b/pkg/tcpip/tcpip_test.go
@@ -270,3 +270,43 @@ func TestAddressUnspecified(t *testing.T) {
})
}
}
+
+func TestAddressMatchingPrefix(t *testing.T) {
+ tests := []struct {
+ addrA Address
+ addrB Address
+ prefix uint8
+ }{
+ {
+ addrA: "\x01\x01",
+ addrB: "\x01\x01",
+ prefix: 16,
+ },
+ {
+ addrA: "\x01\x01",
+ addrB: "\x01\x00",
+ prefix: 15,
+ },
+ {
+ addrA: "\x01\x01",
+ addrB: "\x81\x00",
+ prefix: 0,
+ },
+ {
+ addrA: "\x01\x01",
+ addrB: "\x01\x80",
+ prefix: 8,
+ },
+ {
+ addrA: "\x01\x01",
+ addrB: "\x02\x80",
+ prefix: 6,
+ },
+ }
+
+ for _, test := range tests {
+ if got := test.addrA.MatchingPrefix(test.addrB); got != test.prefix {
+ t.Errorf("got (%s).MatchingPrefix(%s) = %d, want = %d", test.addrA, test.addrB, got, test.prefix)
+ }
+ }
+}
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 800025fb9..ca1e88e99 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -15,10 +15,12 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/ethernet",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/link/nested",
"//pkg/tcpip/link/pipe",
"//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index 39343b966..60054d6ef 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -15,13 +15,16 @@
package integration_test
import (
+ "bytes"
"net"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/ethernet"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -31,6 +34,33 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+var _ stack.NetworkDispatcher = (*endpointWithDestinationCheck)(nil)
+var _ stack.LinkEndpoint = (*endpointWithDestinationCheck)(nil)
+
+// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner
+// link endpoint and checks the destination link address before delivering
+// network packets to the network dispatcher.
+//
+// See ethernet.Endpoint for more details.
+func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck {
+ var e endpointWithDestinationCheck
+ e.Endpoint.Init(ethernet.New(ep), &e)
+ return &e
+}
+
+// endpointWithDestinationCheck is a link endpoint that checks the destination
+// link address before delivering network packets to the network dispatcher.
+type endpointWithDestinationCheck struct {
+ nested.Endpoint
+}
+
+// DeliverNetworkPacket implements stack.NetworkDispatcher.
+func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) {
+ e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt)
+ }
+}
+
func TestForwarding(t *testing.T) {
const (
host1NICID = 1
@@ -209,16 +239,16 @@ func TestForwarding(t *testing.T) {
host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2)
routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4)
- if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil {
+ if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil {
t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
}
- if err := routerStack.CreateNIC(routerNICID1, ethernet.New(routerNIC1)); err != nil {
+ if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil {
t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err)
}
- if err := routerStack.CreateNIC(routerNICID2, ethernet.New(routerNIC2)); err != nil {
+ if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil {
t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err)
}
- if err := host2Stack.CreateNIC(host2NICID, ethernet.New(host2NIC)); err != nil {
+ if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil {
t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
}
@@ -353,24 +383,33 @@ func TestForwarding(t *testing.T) {
// Wait for the endpoint to be readable.
<-ch
- var addr tcpip.FullAddress
- v, _, err := ep.Read(&addr)
+ var buf bytes.Buffer
+ opts := tcpip.ReadOptions{NeedRemoteAddr: true}
+ res, err := ep.Read(&buf, len(data), opts)
if err != nil {
- t.Fatalf("ep.Read(_): %s", err)
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
}
- if diff := cmp.Diff(v, buffer.View(data)); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: len(data),
+ Total: len(data),
+ RemoteAddr: tcpip.FullAddress{Addr: expectedFrom},
+ }, res, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
}
- if addr.Addr != expectedFrom {
- t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, expectedFrom)
+ if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
}
if t.Failed() {
t.FailNow()
}
- return addr
+ return res.RemoteAddr
}
addr := read(epsAndAddrs.serverReadableCH, epsAndAddrs.serverEP, data, epsAndAddrs.clientAddr)
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index bf8a1241f..209da3903 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -15,14 +15,14 @@
package integration_test
import (
+ "bytes"
"net"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/ethernet"
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -87,21 +87,21 @@ func TestPing(t *testing.T) {
transProto tcpip.TransportProtocolNumber
netProto tcpip.NetworkProtocolNumber
remoteAddr tcpip.Address
- icmpBuf func(*testing.T) buffer.View
+ icmpBuf func(*testing.T) []byte
}{
{
name: "IPv4 Ping",
transProto: icmp.ProtocolNumber4,
netProto: ipv4.ProtocolNumber,
remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
- icmpBuf: func(t *testing.T) buffer.View {
+ icmpBuf: func(t *testing.T) []byte {
data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
hdr.SetType(header.ICMPv4Echo)
if n := copy(hdr.Payload(), data[:]); n != len(data) {
t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
}
- return buffer.View(hdr)
+ return hdr
},
},
{
@@ -109,14 +109,14 @@ func TestPing(t *testing.T) {
transProto: icmp.ProtocolNumber6,
netProto: ipv6.ProtocolNumber,
remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
- icmpBuf: func(t *testing.T) buffer.View {
+ icmpBuf: func(t *testing.T) []byte {
data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
hdr.SetType(header.ICMPv6EchoRequest)
if n := copy(hdr.Payload(), data[:]); n != len(data) {
t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
}
- return buffer.View(hdr)
+ return hdr
},
},
}
@@ -133,10 +133,10 @@ func TestPing(t *testing.T) {
host1NIC, host2NIC := pipe.New(linkAddr1, linkAddr2)
- if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil {
+ if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil {
t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
}
- if err := host2Stack.CreateNIC(host2NICID, ethernet.New(host2NIC)); err != nil {
+ if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil {
t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
}
@@ -201,16 +201,25 @@ func TestPing(t *testing.T) {
// Wait for the endpoint to be readable.
<-waiterCH
- var addr tcpip.FullAddress
- v, _, err := ep.Read(&addr)
+ var buf bytes.Buffer
+ opts := tcpip.ReadOptions{NeedRemoteAddr: true}
+ res, err := ep.Read(&buf, len(icmpBuf), opts)
if err != nil {
- t.Fatalf("ep.Read(_): %s", err)
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(icmpBuf), opts, err)
}
- if diff := cmp.Diff(v[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ RemoteAddr: tcpip.FullAddress{Addr: test.remoteAddr},
+ }, res, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
}
- if addr.Addr != test.remoteAddr {
- t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.remoteAddr)
+ if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
}
})
}
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index baaa741cd..cf9e86c3c 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -15,12 +15,14 @@
package integration_test
import (
+ "bytes"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -238,21 +240,28 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
t.Fatalf("got sep.Write(_, _) = (%d, _, nil), want = (%d, _, nil)", n, want)
}
- var addr tcpip.FullAddress
- if gotPayload, _, err := rep.Read(&addr); test.expectRx {
+ var buf bytes.Buffer
+ opts := tcpip.ReadOptions{NeedRemoteAddr: true}
+ if res, err := rep.Read(&buf, len(data), opts); test.expectRx {
if err != nil {
- t.Fatalf("reep.Read(_): %s", err)
- }
- if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
- t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ t.Fatalf("rep.Read(_, %d, %#v): %s", len(data), opts, err)
}
- if addr.Addr != test.addAddress.AddressWithPrefix.Address {
- t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.addAddress.AddressWithPrefix.Address)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ RemoteAddr: tcpip.FullAddress{
+ Addr: test.addAddress.AddressWithPrefix.Address,
+ },
+ }, res,
+ checker.IgnoreCmpPath("ControlMessages", "RemoteAddr.NIC", "RemoteAddr.Port"),
+ ); diff != "" {
+ t.Errorf("rep.Read: unexpected result (-want +got):\n%s", diff)
}
- } else {
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got rep.Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
+ t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
}
+ } else if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock)
}
})
}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 8be791a00..fae6c256a 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -15,12 +15,14 @@
package integration_test
import (
+ "bytes"
"net"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
@@ -35,6 +37,9 @@ import (
const (
defaultMTU = 1280
ttl = 255
+
+ remotePort = 5555
+ localPort = 80
)
var (
@@ -96,11 +101,11 @@ func TestPingMulticastBroadcast(t *testing.T) {
pkt.SetChecksum(header.ICMPv6Checksum(pkt, remoteIPv6Addr, dst, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: ttl,
- SrcAddr: remoteIPv6Addr,
- DstAddr: dst,
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: ttl,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: dst,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -151,11 +156,11 @@ func TestPingMulticastBroadcast(t *testing.T) {
}
ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err)
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
}
ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr}
if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err)
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv6ProtoAddr, err)
}
// Default routes for IPv4 and IPv6 so ICMP can find a route to the remote
@@ -215,163 +220,219 @@ func TestPingMulticastBroadcast(t *testing.T) {
}
+func rxIPv4UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) {
+ payloadLen := header.UDPMinimumSize + len(data)
+ totalLen := header.IPv4MinimumSize + payloadLen
+ hdr := buffer.NewPrependable(totalLen)
+ u := header.UDP(hdr.Prepend(payloadLen))
+ u.Encode(&header.UDPFields{
+ SrcPort: remotePort,
+ DstPort: localPort,
+ Length: uint16(payloadLen),
+ })
+ copy(u.Payload(), data)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen))
+ sum = header.Checksum(data, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(udp.ProtocolNumber),
+ TTL: ttl,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+}
+
+func rxIPv6UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) {
+ payloadLen := header.UDPMinimumSize + len(data)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen)
+ u := header.UDP(hdr.Prepend(payloadLen))
+ u.Encode(&header.UDPFields{
+ SrcPort: remotePort,
+ DstPort: localPort,
+ Length: uint16(payloadLen),
+ })
+ copy(u.Payload(), data)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen))
+ sum = header.Checksum(data, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLen),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: ttl,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+}
+
// TestIncomingMulticastAndBroadcast tests receiving a packet destined to some
// multicast or broadcast address.
func TestIncomingMulticastAndBroadcast(t *testing.T) {
- const (
- nicID = 1
- remotePort = 5555
- localPort = 80
- )
+ const nicID = 1
data := []byte{1, 2, 3, 4}
- rxIPv4UDP := func(e *channel.Endpoint, dst tcpip.Address) {
- payloadLen := header.UDPMinimumSize + len(data)
- totalLen := header.IPv4MinimumSize + payloadLen
- hdr := buffer.NewPrependable(totalLen)
- u := header.UDP(hdr.Prepend(payloadLen))
- u.Encode(&header.UDPFields{
- SrcPort: remotePort,
- DstPort: localPort,
- Length: uint16(payloadLen),
- })
- copy(u.Payload(), data)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv4Addr, dst, uint16(payloadLen))
- sum = header.Checksum(data, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
-
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- Protocol: uint8(udp.ProtocolNumber),
- TTL: ttl,
- SrcAddr: remoteIPv4Addr,
- DstAddr: dst,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
- rxIPv6UDP := func(e *channel.Endpoint, dst tcpip.Address) {
- payloadLen := header.UDPMinimumSize + len(data)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen)
- u := header.UDP(hdr.Prepend(payloadLen))
- u.Encode(&header.UDPFields{
- SrcPort: remotePort,
- DstPort: localPort,
- Length: uint16(payloadLen),
- })
- copy(u.Payload(), data)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv6Addr, dst, uint16(payloadLen))
- sum = header.Checksum(data, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
-
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLen),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: ttl,
- SrcAddr: remoteIPv6Addr,
- DstAddr: dst,
- })
-
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
tests := []struct {
- name string
- bindAddr tcpip.Address
- dstAddr tcpip.Address
- expectRx bool
+ name string
+ proto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ localAddr tcpip.AddressWithPrefix
+ rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte)
+ bindAddr tcpip.Address
+ dstAddr tcpip.Address
+ expectRx bool
}{
{
- name: "IPv4 unicast binding to unicast",
- bindAddr: ipv4Addr.Address,
- dstAddr: ipv4Addr.Address,
- expectRx: true,
+ name: "IPv4 unicast binding to unicast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: ipv4Addr.Address,
+ dstAddr: ipv4Addr.Address,
+ expectRx: true,
},
{
- name: "IPv4 unicast binding to broadcast",
- bindAddr: header.IPv4Broadcast,
- dstAddr: ipv4Addr.Address,
- expectRx: false,
+ name: "IPv4 unicast binding to broadcast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: header.IPv4Broadcast,
+ dstAddr: ipv4Addr.Address,
+ expectRx: false,
},
{
- name: "IPv4 unicast binding to wildcard",
- dstAddr: ipv4Addr.Address,
- expectRx: true,
+ name: "IPv4 unicast binding to wildcard",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ dstAddr: ipv4Addr.Address,
+ expectRx: true,
},
{
- name: "IPv4 directed broadcast binding to subnet broadcast",
- bindAddr: ipv4SubnetBcast,
- dstAddr: ipv4SubnetBcast,
- expectRx: true,
+ name: "IPv4 directed broadcast binding to subnet broadcast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: ipv4SubnetBcast,
+ dstAddr: ipv4SubnetBcast,
+ expectRx: true,
},
{
- name: "IPv4 directed broadcast binding to broadcast",
- bindAddr: header.IPv4Broadcast,
- dstAddr: ipv4SubnetBcast,
- expectRx: false,
+ name: "IPv4 directed broadcast binding to broadcast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: header.IPv4Broadcast,
+ dstAddr: ipv4SubnetBcast,
+ expectRx: false,
},
{
- name: "IPv4 directed broadcast binding to wildcard",
- dstAddr: ipv4SubnetBcast,
- expectRx: true,
+ name: "IPv4 directed broadcast binding to wildcard",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ dstAddr: ipv4SubnetBcast,
+ expectRx: true,
},
{
- name: "IPv4 broadcast binding to broadcast",
- bindAddr: header.IPv4Broadcast,
- dstAddr: header.IPv4Broadcast,
- expectRx: true,
+ name: "IPv4 broadcast binding to broadcast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: header.IPv4Broadcast,
+ dstAddr: header.IPv4Broadcast,
+ expectRx: true,
},
{
- name: "IPv4 broadcast binding to subnet broadcast",
- bindAddr: ipv4SubnetBcast,
- dstAddr: header.IPv4Broadcast,
- expectRx: false,
+ name: "IPv4 broadcast binding to subnet broadcast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: ipv4SubnetBcast,
+ dstAddr: header.IPv4Broadcast,
+ expectRx: false,
},
{
- name: "IPv4 broadcast binding to wildcard",
- dstAddr: ipv4SubnetBcast,
- expectRx: true,
+ name: "IPv4 broadcast binding to wildcard",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ dstAddr: ipv4SubnetBcast,
+ expectRx: true,
},
{
- name: "IPv4 all-systems multicast binding to all-systems multicast",
- bindAddr: header.IPv4AllSystems,
- dstAddr: header.IPv4AllSystems,
- expectRx: true,
+ name: "IPv4 all-systems multicast binding to all-systems multicast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: header.IPv4AllSystems,
+ dstAddr: header.IPv4AllSystems,
+ expectRx: true,
},
{
- name: "IPv4 all-systems multicast binding to wildcard",
- dstAddr: header.IPv4AllSystems,
- expectRx: true,
+ name: "IPv4 all-systems multicast binding to wildcard",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ dstAddr: header.IPv4AllSystems,
+ expectRx: true,
},
{
- name: "IPv4 all-systems multicast binding to unicast",
- bindAddr: ipv4Addr.Address,
- dstAddr: header.IPv4AllSystems,
- expectRx: false,
+ name: "IPv4 all-systems multicast binding to unicast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: ipv4Addr.Address,
+ dstAddr: header.IPv4AllSystems,
+ expectRx: false,
},
// IPv6 has no notion of a broadcast.
{
- name: "IPv6 unicast binding to wildcard",
- dstAddr: ipv6Addr.Address,
- expectRx: true,
+ name: "IPv6 unicast binding to wildcard",
+ dstAddr: ipv6Addr.Address,
+ proto: header.IPv6ProtocolNumber,
+ remoteAddr: remoteIPv6Addr,
+ localAddr: ipv6Addr,
+ rxUDP: rxIPv6UDP,
+ expectRx: true,
},
{
- name: "IPv6 broadcast-like address binding to wildcard",
- dstAddr: ipv6SubnetBcast,
- expectRx: false,
+ name: "IPv6 broadcast-like address binding to wildcard",
+ dstAddr: ipv6SubnetBcast,
+ proto: header.IPv6ProtocolNumber,
+ remoteAddr: remoteIPv6Addr,
+ localAddr: ipv6Addr,
+ rxUDP: rxIPv6UDP,
+ expectRx: false,
},
}
@@ -385,52 +446,41 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
- if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err)
- }
- ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr}
- if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err)
- }
-
- var netproto tcpip.NetworkProtocolNumber
- var rxUDP func(*channel.Endpoint, tcpip.Address)
- switch l := len(test.dstAddr); l {
- case header.IPv4AddressSize:
- netproto = header.IPv4ProtocolNumber
- rxUDP = rxIPv4UDP
- case header.IPv6AddressSize:
- netproto = header.IPv6ProtocolNumber
- rxUDP = rxIPv6UDP
- default:
- t.Fatalf("got unexpected address length = %d bytes", l)
+ protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
+ if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
}
var wq waiter.Queue
- ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq)
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq)
if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err)
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err)
}
defer ep.Close()
bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort}
if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("ep.Bind(%+v): %s", bindAddr, err)
+ t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
}
- rxUDP(e, test.dstAddr)
- if gotPayload, _, err := ep.Read(nil); test.expectRx {
+ test.rxUDP(e, test.remoteAddr, test.dstAddr, data)
+ var buf bytes.Buffer
+ var opts tcpip.ReadOptions
+ if res, err := ep.Read(&buf, len(data), opts); test.expectRx {
if err != nil {
- t.Fatalf("Read(nil): %s", err)
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
}
- if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
- t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
}
- } else {
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
+ t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
}
+ } else if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock)
}
})
}
@@ -476,7 +526,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
},
}
if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, protoAddr, err)
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -516,12 +566,12 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
bindAddr := tcpip.FullAddress{Port: localPort}
if bindWildcard {
if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err)
+ t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err)
}
} else {
bindAddr.Addr = test.broadcastAddr
if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err)
+ t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err)
}
}
@@ -547,9 +597,19 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
// Wait for the endpoint to become readable.
<-rep.ch
- if gotPayload, _, err := rep.ep.Read(nil); err != nil {
- t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err)
- } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
+ var buf bytes.Buffer
+ result, err := rep.ep.Read(&buf, len(data), tcpip.ReadOptions{})
+ if err != nil {
+ t.Errorf("(eps[%d] write) eps[%d].Read: %s", i, j, err)
+ continue
+ }
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("(eps[%d] write) eps[%d].Read: unexpected result (-want +got):\n%s", i, j, diff)
+ }
+ if diff := cmp.Diff([]byte(data), buf.Bytes()); diff != "" {
t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff)
}
}
@@ -557,3 +617,153 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
})
}
}
+
+func TestUDPAddRemoveMembershipSocketOption(t *testing.T) {
+ const (
+ nicID = 1
+ )
+
+ data := []byte{1, 2, 3, 4}
+
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ localAddr tcpip.AddressWithPrefix
+ rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte)
+ multicastAddr tcpip.Address
+ }{
+ {
+ name: "IPv4 unicast binding to unicast",
+ multicastAddr: "\xe0\x01\x02\x03",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ },
+ {
+ name: "IPv6 broadcast-like address binding to wildcard",
+ multicastAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04",
+ proto: header.IPv6ProtocolNumber,
+ remoteAddr: remoteIPv6Addr,
+ localAddr: ipv6Addr,
+ rxUDP: rxIPv6UDP,
+ },
+ }
+
+ subTests := []struct {
+ name string
+ specifyNICID bool
+ specifyNICAddr bool
+ }{
+ {
+ name: "Specify NIC ID and NIC address",
+ specifyNICID: true,
+ specifyNICAddr: true,
+ },
+ {
+ name: "Don't specify NIC ID or NIC address",
+ specifyNICID: false,
+ specifyNICAddr: false,
+ },
+ {
+ name: "Specify NIC ID but don't specify NIC address",
+ specifyNICID: true,
+ specifyNICAddr: false,
+ },
+ {
+ name: "Don't specify NIC ID but specify NIC address",
+ specifyNICID: false,
+ specifyNICAddr: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+ e := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
+ if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ }
+
+ // Set the route table so that UDP can find a NIC that is
+ // routable to the multicast address when the NIC isn't specified.
+ if !subTest.specifyNICID && !subTest.specifyNICAddr {
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+ }
+
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err)
+ }
+ defer ep.Close()
+
+ bindAddr := tcpip.FullAddress{Port: localPort}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
+ }
+
+ memOpt := tcpip.MembershipOption{MulticastAddr: test.multicastAddr}
+ if subTest.specifyNICID {
+ memOpt.NIC = nicID
+ }
+ if subTest.specifyNICAddr {
+ memOpt.InterfaceAddr = test.localAddr.Address
+ }
+
+ // We should receive UDP packets to the group once we join the
+ // multicast group.
+ addOpt := tcpip.AddMembershipOption(memOpt)
+ if err := ep.SetSockOpt(&addOpt); err != nil {
+ t.Fatalf("ep.SetSockOpt(&%#v): %s", addOpt, err)
+ }
+ test.rxUDP(e, test.remoteAddr, test.multicastAddr, data)
+ var buf bytes.Buffer
+ result, err := ep.Read(&buf, len(data), tcpip.ReadOptions{})
+ if err != nil {
+ t.Fatalf("ep.Read: %s", err)
+ } else {
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
+ t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // We should not receive UDP packets to the group once we leave
+ // the multicast group.
+ removeOpt := tcpip.RemoveMembershipOption(memOpt)
+ if err := ep.SetSockOpt(&removeOpt); err != nil {
+ t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err)
+ }
+ if _, err := ep.Read(&buf, 1, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, tcpip.ErrWouldBlock)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index 02fc47015..52cf89b54 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -15,11 +15,14 @@
package integration_test
import (
+ "bytes"
+ "math"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
@@ -203,16 +206,25 @@ func TestLocalPing(t *testing.T) {
// Wait for the endpoint to become readable.
<-ch
- var addr tcpip.FullAddress
- v, _, err := ep.Read(&addr)
+ var buf bytes.Buffer
+ opts := tcpip.ReadOptions{NeedRemoteAddr: true}
+ res, err := ep.Read(&buf, math.MaxUint16, opts)
if err != nil {
- t.Fatalf("ep.Read(_): %s", err)
+ t.Fatalf("ep.Read(_, %d, %#v): %s", math.MaxUint16, opts, err)
}
- if diff := cmp.Diff(v[icmpDataOffset:], buffer.View(payload[icmpDataOffset:])); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ RemoteAddr: tcpip.FullAddress{Addr: test.localAddr},
+ }, res, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
}
- if addr.Addr != test.localAddr {
- t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.localAddr)
+ if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], []byte(payload[icmpDataOffset:])); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
}
test.checkLinkEndpoint(t, e)
@@ -338,14 +350,27 @@ func TestLocalUDP(t *testing.T) {
<-serverCH
var clientAddr tcpip.FullAddress
- if v, _, err := server.Read(&clientAddr); err != nil {
+ var readBuf bytes.Buffer
+ if read, err := server.Read(&readBuf, math.MaxUint16, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil {
t.Fatalf("server.Read(_): %s", err)
} else {
- if diff := cmp.Diff(buffer.View(clientPayload), v); diff != "" {
- t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff)
+ clientAddr = read.RemoteAddr
+
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: readBuf.Len(),
+ Total: readBuf.Len(),
+ RemoteAddr: tcpip.FullAddress{
+ Addr: test.canBePrimaryAddr.AddressWithPrefix.Address,
+ },
+ }, read, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ t.Errorf("server.Read: unexpected result (-want +got):\n%s", diff)
}
- if clientAddr.Addr != test.canBePrimaryAddr.AddressWithPrefix.Address {
- t.Errorf("got clientAddr.Addr = %s, want = %s", clientAddr.Addr, test.canBePrimaryAddr.AddressWithPrefix.Address)
+ if diff := cmp.Diff(buffer.View(clientPayload), buffer.View(readBuf.Bytes())); diff != "" {
+ t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff)
}
if t.Failed() {
t.FailNow()
@@ -367,15 +392,23 @@ func TestLocalUDP(t *testing.T) {
// Wait for the client endpoint to become readable.
<-clientCH
- var gotServerAddr tcpip.FullAddress
- if v, _, err := client.Read(&gotServerAddr); err != nil {
+ readBuf.Reset()
+ if read, err := client.Read(&readBuf, math.MaxUint16, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil {
t.Fatalf("client.Read(_): %s", err)
} else {
- if diff := cmp.Diff(buffer.View(serverPayload), v); diff != "" {
- t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: readBuf.Len(),
+ Total: readBuf.Len(),
+ RemoteAddr: tcpip.FullAddress{Addr: serverAddr.Addr},
+ }, read, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ t.Errorf("client.Read: unexpected result (-want +got):\n%s", diff)
}
- if gotServerAddr.Addr != serverAddr.Addr {
- t.Errorf("got gotServerAddr.Addr = %s, want = %s", gotServerAddr.Addr, serverAddr.Addr)
+ if diff := cmp.Diff(buffer.View(serverPayload), buffer.View(readBuf.Bytes())); diff != "" {
+ t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff)
}
if t.Failed() {
t.FailNow()
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 59ec54ca0..2eb4457df 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -15,6 +15,8 @@
package icmp
import (
+ "io"
+
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -72,11 +74,9 @@ type endpoint struct {
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
state endpointState
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
ttl uint8
stats tcpip.TransportEndpointStats `state:"nosave"`
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
@@ -132,7 +132,10 @@ func (e *endpoint) Close() {
}
e.rcvMu.Unlock()
- e.route.Release()
+ if e.route != nil {
+ e.route.Release()
+ e.route = nil
+ }
// Update the state.
e.state = stateClosed
@@ -145,13 +148,13 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
+// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// Read reads data from the endpoint. This method does not block if
-// there is no data pending.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
e.rcvMu.Lock()
if e.rcvList.Empty() {
@@ -161,20 +164,34 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
p := e.rcvList.Front()
- e.rcvList.Remove(p)
- e.rcvBufSize -= p.data.Size()
+ if !opts.Peek {
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
+ }
e.rcvMu.Unlock()
- if addr != nil {
- *addr = p.senderAddress
+ res := tcpip.ReadResult{
+ Total: p.data.Size(),
+ ControlMessages: tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: p.timestamp,
+ },
+ }
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = p.senderAddress
}
- return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
+ n, err := p.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -270,26 +287,8 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
}
- var route *stack.Route
- if to == nil {
- route = &e.route
-
- if route.IsResolutionRequired() {
- // Promote lock to exclusive if using a shared route,
- // given that it may need to change in Route.Resolve()
- // call below.
- e.mu.RUnlock()
- defer e.mu.RLock()
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- // Recheck state after lock was re-acquired.
- if e.state != stateConnected {
- return 0, nil, tcpip.ErrInvalidEndpointState
- }
- }
- } else {
+ route := e.route
+ if to != nil {
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicID := to.NIC
@@ -313,7 +312,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
defer r.Release()
- route = &r
+ route = r
}
if route.IsResolutionRequired() {
@@ -345,27 +344,8 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return int64(len(v)), nil, nil
}
-// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
-}
-
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
- case *tcpip.SocketDetachFilterOption:
- return nil
-
- case *tcpip.LingerOption:
- e.mu.Lock()
- e.linger = *v
- e.mu.Unlock()
- }
- return nil
-}
-
-// SetSockOptBool sets a socket option. Currently not supported.
-func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
return nil
}
@@ -381,11 +361,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrUnknownProtocolOption
-}
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
@@ -423,16 +398,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- e.mu.Lock()
- *o = e.linger
- e.mu.Unlock()
- return nil
-
- default:
- return tcpip.ErrUnknownProtocolOption
- }
+ return tcpip.ErrUnknownProtocolOption
}
func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error {
@@ -548,7 +514,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- defer r.Release()
id := stack.TransportEndpointID{
LocalAddress: r.LocalAddress,
@@ -563,11 +528,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
id, err = e.registerWithStack(nicID, netProtos, id)
if err != nil {
+ r.Release()
return err
}
e.ID = id
- e.route = r.Clone()
+ e.route = r
e.RegisterNICID = nicID
e.state = stateConnected
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index e2c7a0d62..3ab060751 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -26,6 +26,7 @@ package packet
import (
"fmt"
+ "io"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -85,8 +86,6 @@ type endpoint struct {
stats tcpip.TransportEndpointStats `state:"nosave"`
bound bool
boundNIC tcpip.NICID
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
// lastErrorMu protects lastError.
lastErrorMu sync.Mutex `state:"nosave"`
@@ -162,8 +161,8 @@ func (ep *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (ep *endpoint) ModerateRecvBuf(copied int) {}
-// Read implements tcpip.PacketEndpoint.ReadPacket.
-func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -175,29 +174,37 @@ func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketIn
err = tcpip.ErrClosedForReceive
}
ep.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
packet := ep.rcvList.Front()
- ep.rcvList.Remove(packet)
- ep.rcvBufSize -= packet.data.Size()
+ if !opts.Peek {
+ ep.rcvList.Remove(packet)
+ ep.rcvBufSize -= packet.data.Size()
+ }
ep.rcvMu.Unlock()
- if addr != nil {
- *addr = packet.senderAddr
+ res := tcpip.ReadResult{
+ Total: packet.data.Size(),
+ ControlMessages: tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: packet.timestampNS,
+ },
}
-
- if info != nil {
- *info = packet.packetInfo
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = packet.senderAddr
+ }
+ if opts.NeedLinkPacketInfo {
+ res.LinkPacketInfo = packet.packetInfo
}
- return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
-}
-
-// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- return ep.ReadPacket(addr, nil)
+ n, err := packet.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
@@ -205,11 +212,6 @@ func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-cha
return 0, nil, tcpip.ErrInvalidOptionValue
}
-// Peek implements tcpip.Endpoint.Peek.
-func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
-}
-
// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
// disconnected, and this function always returns tcpip.ErrNotSupported.
func (*endpoint) Disconnect() *tcpip.Error {
@@ -306,26 +308,15 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// used with SetSockOpt, and this function always returns
// tcpip.ErrNotSupported.
func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
+ switch opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
- case *tcpip.LingerOption:
- ep.mu.Lock()
- ep.linger = *v
- ep.mu.Unlock()
- return nil
-
default:
return tcpip.ErrUnknownProtocolOption
}
}
-// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
-func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
-}
-
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
switch opt {
@@ -379,23 +370,16 @@ func (ep *endpoint) LastError() *tcpip.Error {
return err
}
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- ep.mu.Lock()
- *o = ep.linger
- ep.mu.Unlock()
- return nil
-
- default:
- return tcpip.ErrNotSupported
- }
+// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
+func (ep *endpoint) UpdateLastError(err *tcpip.Error) {
+ ep.lastErrorMu.Lock()
+ ep.lastError = err
+ ep.lastErrorMu.Unlock()
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrNotSupported
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+ return tcpip.ErrNotSupported
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -549,8 +533,10 @@ func (ep *endpoint) Stats() tcpip.EndpointStats {
return &ep.stats
}
+// SetOwner implements tcpip.Endpoint.SetOwner.
func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {}
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
return &ep.ops
}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index b0b53b181..dd260535f 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -27,6 +27,7 @@ package raw
import (
"fmt"
+ "io"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -65,7 +66,6 @@ type endpoint struct {
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
associated bool
- hdrIncluded bool
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
@@ -84,10 +84,8 @@ type endpoint struct {
bound bool
// route is the route to a remote network endpoint. It is set via
// Connect(), and is valid only when connected is true.
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
stats tcpip.TransportEndpointStats `state:"nosave"`
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
@@ -116,9 +114,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
rcvBufSizeMax: 32 * 1024,
sndBufSizeMax: 32 * 1024,
associated: associated,
- hdrIncluded: !associated,
}
e.ops.InitHandler(e)
+ e.ops.SetHeaderIncluded(!associated)
// Override with stack defaults.
var ss stack.SendBufferSizeOption
@@ -173,9 +171,11 @@ func (e *endpoint) Close() {
e.rcvList.Remove(e.rcvList.Front())
}
- if e.connected {
+ e.connected = false
+
+ if e.route != nil {
e.route.Release()
- e.connected = false
+ e.route = nil
}
e.closed = true
@@ -191,7 +191,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -203,20 +203,34 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
pkt := e.rcvList.Front()
- e.rcvList.Remove(pkt)
- e.rcvBufSize -= pkt.data.Size()
+ if !opts.Peek {
+ e.rcvList.Remove(pkt)
+ e.rcvBufSize -= pkt.data.Size()
+ }
e.rcvMu.Unlock()
- if addr != nil {
- *addr = pkt.senderAddr
+ res := tcpip.ReadResult{
+ Total: pkt.data.Size(),
+ ControlMessages: tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: pkt.timestampNS,
+ },
+ }
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = pkt.senderAddr
}
- return pkt.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: pkt.timestampNS}, nil
+ n, err := pkt.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
// Write implements tcpip.Endpoint.Write.
@@ -226,6 +240,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrInvalidOptionValue
}
+ if opts.To != nil {
+ // Raw sockets do not support sending to an IPv4 address on an IPv6 endpoint.
+ if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+ }
+
n, ch, err := e.write(p, opts)
switch err {
case nil:
@@ -255,24 +276,22 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
e.mu.RLock()
+ defer e.mu.RUnlock()
if e.closed {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidEndpointState
}
payloadBytes, err := p.FullPayload()
if err != nil {
- e.mu.RUnlock()
return 0, nil, err
}
// If this is an unassociated socket and callee provided a nonzero
// destination address, route using that address.
- if e.hdrIncluded {
+ if e.ops.GetHeaderIncluded() {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidOptionValue
}
dstAddr := ip.DestinationAddress()
@@ -294,39 +313,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If the user doesn't specify a destination, they should have
// connected to another address.
if !e.connected {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrDestinationRequired
}
- if e.route.IsResolutionRequired() {
- savedRoute := &e.route
- // Promote lock to exclusive if using a shared route,
- // given that it may need to change in finishWrite.
- e.mu.RUnlock()
- e.mu.Lock()
-
- // Make sure that the route didn't change during the
- // time we didn't hold the lock.
- if !e.connected || savedRoute != &e.route {
- e.mu.Unlock()
- return 0, nil, tcpip.ErrInvalidEndpointState
- }
-
- n, ch, err := e.finishWrite(payloadBytes, savedRoute)
- e.mu.Unlock()
- return n, ch, err
- }
-
- n, ch, err := e.finishWrite(payloadBytes, &e.route)
- e.mu.RUnlock()
- return n, ch, err
+ return e.finishWrite(payloadBytes, e.route)
}
// The caller provided a destination. Reject destination address if it
// goes through a different NIC than the endpoint was bound to.
nic := opts.To.NIC
if e.bound && nic != 0 && nic != e.BindNICID {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrNoRoute
}
@@ -334,13 +330,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// FindRoute will choose an appropriate source address.
route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
if err != nil {
- e.mu.RUnlock()
return 0, nil, err
}
- n, ch, err := e.finishWrite(payloadBytes, &route)
+ n, ch, err := e.finishWrite(payloadBytes, route)
route.Release()
- e.mu.RUnlock()
return n, ch, err
}
@@ -359,7 +353,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
}
- if e.hdrIncluded {
+ if e.ops.GetHeaderIncluded() {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.View(payloadBytes).ToVectorisedView(),
})
@@ -384,11 +378,6 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
return int64(len(payloadBytes)), nil, nil
}
-// Peek implements tcpip.Endpoint.Peek.
-func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
-}
-
// Disconnect implements tcpip.Endpoint.Disconnect.
func (*endpoint) Disconnect() *tcpip.Error {
return tcpip.ErrNotSupported
@@ -396,6 +385,11 @@ func (*endpoint) Disconnect() *tcpip.Error {
// Connect implements tcpip.Endpoint.Connect.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ // Raw sockets do not support connecting to an IPv4 address on an IPv6 endpoint.
+ if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize {
+ return tcpip.ErrAddressFamilyNotSupported
+ }
+
e.mu.Lock()
defer e.mu.Unlock()
@@ -424,11 +418,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- defer route.Release()
if e.associated {
// Re-register the endpoint with the appropriate NIC.
if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil {
+ route.Release()
return err
}
e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
@@ -436,7 +430,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// Save the route we've connected via.
- e.route = route.Clone()
+ e.route = route
e.connected = true
return nil
@@ -519,33 +513,15 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
+ switch opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
- case *tcpip.LingerOption:
- e.mu.Lock()
- e.linger = *v
- e.mu.Unlock()
- return nil
-
default:
return tcpip.ErrUnknownProtocolOption
}
}
-// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
-func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- switch opt {
- case tcpip.IPHdrIncludedOption:
- e.mu.Lock()
- e.hdrIncluded = v
- e.mu.Unlock()
- return nil
- }
- return tcpip.ErrUnknownProtocolOption
-}
-
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
switch opt {
@@ -592,30 +568,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- e.mu.Lock()
- *o = e.linger
- e.mu.Unlock()
- return nil
-
- default:
- return tcpip.ErrUnknownProtocolOption
- }
-}
-
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.IPHdrIncludedOption:
- e.mu.Lock()
- v := e.hdrIncluded
- e.mu.Unlock()
- return v, nil
-
- default:
- return false, tcpip.ErrUnknownProtocolOption
- }
+ return tcpip.ErrUnknownProtocolOption
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -650,6 +603,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
+ e.mu.RLock()
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full or if this is an unassociated
@@ -662,6 +616,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// sockets.
if e.rcvClosed || !e.associated {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ClosedReceiver.Increment()
return
@@ -669,6 +624,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if e.rcvBufSize >= e.rcvBufSizeMax {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
return
@@ -680,11 +636,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// If bound to a NIC, only accept data for that NIC.
if e.BindNICID != 0 && e.BindNICID != pkt.NICID {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
return
}
// If bound to an address, only accept data for that address.
if e.BindAddr != "" && e.BindAddr != remoteAddr {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
return
}
}
@@ -693,6 +651,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// connected to.
if e.connected && e.route.RemoteAddress != remoteAddr {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
return
}
@@ -727,6 +686,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
e.rcvMu.Unlock()
+ e.mu.RUnlock()
e.stats.PacketsReceived.Increment()
// Notify waiters that there's data to be read.
if wasEmpty {
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 7d97cbdc7..4a7e1c039 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -73,7 +73,13 @@ func (e *endpoint) Resume(s *stack.Stack) {
// If the endpoint is connected, re-connect.
if e.connected {
var err *tcpip.Error
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.route.RemoteAddress, e.NetProto, false)
+ // TODO(gvisor.dev/issue/4906): Properly restore the route with the right
+ // remote address. We used to pass e.route.RemoteAddress which was
+ // effectively the empty address but since moving e.route to hold a pointer
+ // to a route instead of the route by value, we pass the empty address
+ // directly. Obviously this was always wrong since we should provide the
+ // remote address we were connected to, to properly restore the route.
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, "", e.NetProto, false)
if err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 3d8174a4f..7e81203ba 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test", "more_shards")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -93,7 +93,7 @@ go_test(
"tcp_test.go",
"tcp_timestamp_test.go",
],
- shard_count = 10,
+ shard_count = more_shards,
deps = [
":tcp",
"//pkg/rand",
@@ -112,6 +112,7 @@ go_test(
"//pkg/tcpip/transport/tcp/testing/context",
"//pkg/test/testutil",
"//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 6e5adc383..2d96a65bd 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -213,7 +213,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
route.ResolveWith(s.remoteLinkAddr)
n := newEndpoint(l.stack, netProto, queue)
- n.v6only = l.v6Only
+ n.ops.SetV6Only(l.v6Only)
n.ID = s.id
n.boundNICID = s.nicID
n.route = route
@@ -599,7 +599,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er
ack: s.sequenceNumber + 1,
rcvWnd: ctx.rcvWnd,
}
- if err := e.sendSynTCP(&route, fields, synOpts); err != nil {
+ if err := e.sendSynTCP(route, fields, synOpts); err != nil {
return err
}
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
@@ -752,7 +752,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er
// its own goroutine and is responsible for handling connection requests.
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.mu.Lock()
- v6Only := e.v6only
+ v6Only := e.ops.GetV6Only()
ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
defer func() {
@@ -778,7 +778,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}()
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
s.AddWaker(&e.notificationWaker, wakerForNotification)
s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
for {
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 88a632019..0dc710276 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -16,6 +16,7 @@ package tcp
import (
"encoding/binary"
+ "math"
"time"
"gvisor.dev/gvisor/pkg/rand"
@@ -133,7 +134,7 @@ func FindWndScale(wnd seqnum.Size) int {
return 0
}
- max := seqnum.Size(0xffff)
+ max := seqnum.Size(math.MaxUint16)
s := 0
for wnd > max && s < header.MaxWndScale {
s++
@@ -300,7 +301,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
if ttl == 0 {
ttl = h.ep.route.DefaultTTL()
}
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: ttl,
tos: h.ep.sendTOS,
@@ -361,7 +362,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -461,7 +462,7 @@ func (h *handshake) processSegments() *tcpip.Error {
func (h *handshake) resolveRoute() *tcpip.Error {
// Set up the wakers.
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
resolutionWaker := &sleep.Waker{}
s.AddWaker(resolutionWaker, wakerForResolution)
s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
@@ -469,24 +470,27 @@ func (h *handshake) resolveRoute() *tcpip.Error {
// Initial action is to resolve route.
index := wakerForResolution
+ attemptedResolution := false
for {
switch index {
case wakerForResolution:
- if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
- if err == tcpip.ErrNoLinkAddress {
- h.ep.stats.SendErrors.NoLinkAddr.Increment()
- } else if err != nil {
+ if _, err := h.ep.route.Resolve(resolutionWaker.Assert); err != tcpip.ErrWouldBlock {
+ if err != nil {
h.ep.stats.SendErrors.NoRoute.Increment()
}
// Either success (err == nil) or failure.
return err
}
+ if attemptedResolution {
+ h.ep.stats.SendErrors.NoLinkAddr.Increment()
+ return tcpip.ErrNoLinkAddress
+ }
+ attemptedResolution = true
// Resolution not completed. Keep trying...
case wakerForNotification:
n := h.ep.fetchNotifications()
if n&notifyClose != 0 {
- h.ep.route.RemoveWaker(resolutionWaker)
return tcpip.ErrAborted
}
if n&notifyDrain != 0 {
@@ -547,7 +551,7 @@ func (h *handshake) start() *tcpip.Error {
}
h.sendSYNOpts = synOpts
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -562,7 +566,7 @@ func (h *handshake) start() *tcpip.Error {
// complete completes the TCP 3-way handshake initiated by h.start().
func (h *handshake) complete() *tcpip.Error {
// Set up the wakers.
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
resendWaker := sleep.Waker{}
s.AddWaker(&resendWaker, wakerForResend)
s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
@@ -596,7 +600,7 @@ func (h *handshake) complete() *tcpip.Error {
// the connection with another ACK or data (as ACKs are never
// retransmitted on their own).
if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept {
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -818,8 +822,8 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
data = data.Clone(nil)
optLen := len(tf.opts)
- if tf.rcvWnd > 0xffff {
- tf.rcvWnd = 0xffff
+ if tf.rcvWnd > math.MaxUint16 {
+ tf.rcvWnd = math.MaxUint16
}
mss := int(gso.MSS)
@@ -863,8 +867,8 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
// network endpoint and under the provided identity.
func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error {
optLen := len(tf.opts)
- if tf.rcvWnd > 0xffff {
- tf.rcvWnd = 0xffff
+ if tf.rcvWnd > math.MaxUint16 {
+ tf.rcvWnd = math.MaxUint16
}
if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() {
@@ -939,7 +943,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
- err := e.sendTCP(&e.route, tcpFields{
+ err := e.sendTCP(e.route, tcpFields{
id: e.ID,
ttl: e.ttl,
tos: e.sendTOS,
@@ -1078,7 +1082,7 @@ func (e *endpoint) transitionToStateCloseLocked() {
// to any other listening endpoint. We reply with RST if we cannot find one.
func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID)
- if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
+ if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
// Dual-stack socket, try IPv4.
ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID)
}
@@ -1511,7 +1515,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
// Initialize the sleeper based on the wakers in funcs.
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
for i := range funcs {
s.AddWaker(funcs[i].w, i)
}
@@ -1635,7 +1639,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()
}
extTW, newSyn := e.rcv.handleTimeWaitSegment(s)
if newSyn {
- info := e.EndpointInfo.TransportEndpointInfo
+ info := e.TransportEndpointInfo
newID := info.ID
newID.RemoteAddress = ""
newID.RemotePort = 0
@@ -1698,7 +1702,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
const notification = 2
const timeWaitDone = 3
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
defer s.Done()
s.AddWaker(&e.newSegmentWaker, newSegment)
s.AddWaker(&e.notificationWaker, notification)
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index a6f25896b..1d1b01a6c 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -405,14 +405,6 @@ func testV4Accept(t *testing.T, c *context.Context) {
}
}
- // Make sure we get the same error when calling the original ep and the
- // new one. This validates that v4-mapped endpoints are still able to
- // query the V6Only flag, whereas pure v4 endpoints are not.
- _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption)
- if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected {
- t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected)
- }
-
// Check the peer address.
addr, err := nep.GetRemoteAddress()
if err != nil {
@@ -530,12 +522,12 @@ func TestV6AcceptOnV6(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
var addr tcpip.FullAddress
- nep, _, err := c.EP.Accept(&addr)
+ _, _, err := c.EP.Accept(&addr)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- nep, _, err = c.EP.Accept(&addr)
+ _, _, err = c.EP.Accept(&addr)
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
@@ -548,12 +540,6 @@ func TestV6AcceptOnV6(t *testing.T) {
if addr.Addr != context.TestV6Addr {
t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr)
}
-
- // Make sure we can still query the v6 only status of the new endpoint,
- // that is, that it is in fact a v6 socket.
- if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil {
- t.Errorf("GetSockOptBool(tcpip.V6OnlyOption) failed: %s", err)
- }
}
func TestV4AcceptOnV4(t *testing.T) {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 64563a8ba..8f3981075 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -17,6 +17,7 @@ package tcp
import (
"encoding/binary"
"fmt"
+ "io"
"math"
"runtime"
"strings"
@@ -27,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
@@ -310,7 +310,8 @@ type Stats struct {
func (*Stats) IsEndpointStats() {}
// EndpointInfo holds useful information about a transport endpoint which
-// can be queried by monitoring tools.
+// can be queried by monitoring tools. This exists to allow tcp-only state to
+// be exposed.
//
// +stateify savable
type EndpointInfo struct {
@@ -392,15 +393,28 @@ type endpoint struct {
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
- // The following fields are used to manage the receive queue. The
- // protocol goroutine adds ready-for-delivery segments to rcvList,
- // which are returned by Read() calls to users.
+ // rcvReadMu synchronizes calls to Read.
//
- // Once the peer has closed its send side, rcvClosed is set to true
- // to indicate to users that no more data is coming.
+ // mu and rcvListMu are temporarily released during data copying. rcvReadMu
+ // must be held during each read to ensure atomicity, so that multiple reads
+ // do not interleave.
+ //
+ // rcvReadMu should be held before holding mu.
+ rcvReadMu sync.Mutex `state:"nosave"`
+
+ // rcvListMu synchronizes access to rcvList.
//
// rcvListMu can be taken after the endpoint mu below.
- rcvListMu sync.Mutex `state:"nosave"`
+ rcvListMu sync.Mutex `state:"nosave"`
+
+ // rcvList is the queue for ready-for-delivery segments.
+ //
+ // rcvReadMu, mu and rcvListMu must be held, in the stated order, to read data
+ // and remove segments from the list. A range of segments can be determined,
+ // then mu and rcvListMu temporarily released while processing that range.
+ // This allows new segments to be appended to the list while processing.
+ //
+ // rcvListMu must be held to append segments to list.
rcvList segmentList `state:"wait"`
rcvClosed bool
// rcvBufSize is the total size of the receive buffer.
@@ -440,9 +454,8 @@ type endpoint struct {
isPortReserved bool `state:"manual"`
isRegistered bool `state:"manual"`
boundNICID tcpip.NICID
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
ttl uint8
- v6only bool
isConnectNotified bool
// h stores a reference to the current handshake state if the endpoint is in
@@ -502,32 +515,14 @@ type endpoint struct {
// sack holds TCP SACK related information for this endpoint.
sack SACKInfo
- // bindToDevice is set to the NIC on which to bind or disabled if 0.
- bindToDevice tcpip.NICID
-
// delay enables Nagle's algorithm.
//
// delay is a boolean (0 is false) and must be accessed atomically.
delay uint32
- // cork holds back segments until full.
- //
- // cork is a boolean (0 is false) and must be accessed atomically.
- cork uint32
-
// scoreboard holds TCP SACK Scoreboard information for this endpoint.
scoreboard *SACKScoreboard
- // The options below aren't implemented, but we remember the user
- // settings because applications expect to be able to set/query these
- // options.
-
- // slowAck holds the negated state of quick ack. It is stubbed out and
- // does nothing.
- //
- // slowAck is a boolean (0 is false) and must be accessed atomically.
- slowAck uint32
-
// segmentQueue is used to hand received segments to the protocol
// goroutine. Segments are queued as long as the queue is not full,
// and dropped when it is.
@@ -689,9 +684,6 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
-
// ops is used to get socket level options.
ops tcpip.SocketOptions
}
@@ -705,7 +697,7 @@ func (e *endpoint) UniqueID() uint64 {
//
// If userMSS is non-zero and is not greater than the maximum possible MSS for
// r, it will be used; otherwise, the maximum possible MSS will be used.
-func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 {
+func calculateAdvertisedMSS(userMSS uint16, r *stack.Route) uint16 {
// The maximum possible MSS is dependent on the route.
// TODO(b/143359391): Respect TCP Min and Max size.
maxMSS := uint16(r.MTU() - header.TCPMinimumSize)
@@ -888,6 +880,8 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
maxSynRetries: DefaultSynRetries,
}
e.ops.InitHandler(e)
+ e.ops.SetMulticastLoop(true)
+ e.ops.SetQuickAck(true)
var ss tcpip.TCPSendBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
@@ -911,7 +905,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
var de tcpip.TCPDelayEnabled
if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de {
- e.SetSockOptBool(tcpip.DelayOption, true)
+ e.ops.SetDelayOption(true)
}
var tcpLT tcpip.TCPLingerTimeoutOption
@@ -1053,7 +1047,8 @@ func (e *endpoint) Close() {
return
}
- if e.linger.Enabled && e.linger.Timeout == 0 {
+ linger := e.SocketOptions().GetLinger()
+ if linger.Enabled && linger.Timeout == 0 {
s := e.EndpointState()
isResetState := s == StateEstablished || s == StateCloseWait || s == StateFinWait1 || s == StateFinWait2 || s == StateSynRecv
if isResetState {
@@ -1173,7 +1168,11 @@ func (e *endpoint) cleanupLocked() {
e.boundPortFlags = ports.Flags{}
e.boundDest = tcpip.FullAddress{}
- e.route.Release()
+ if e.route != nil {
+ e.route.Release()
+ e.route = nil
+ }
+
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
@@ -1314,8 +1313,78 @@ func (e *endpoint) LastError() *tcpip.Error {
return e.lastErrorLocked()
}
-// Read reads data from the endpoint.
-func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
+func (e *endpoint) UpdateLastError(err *tcpip.Error) {
+ e.LockUser()
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+ e.UnlockUser()
+}
+
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+ e.rcvReadMu.Lock()
+ defer e.rcvReadMu.Unlock()
+
+ // N.B. Here we get a range of segments to be processed. It is safe to not
+ // hold rcvListMu when processing, since we hold rcvReadMu to ensure only we
+ // can remove segments from the list through commitRead().
+ first, last, serr := e.startRead()
+ if serr != nil {
+ if serr == tcpip.ErrClosedForReceive {
+ e.stats.ReadErrors.ReadClosed.Increment()
+ }
+ return tcpip.ReadResult{}, serr
+ }
+
+ var err error
+ done := 0
+ s := first
+ for s != nil && done < count {
+ var n int
+ n, err = s.data.ReadTo(dst, count-done, opts.Peek)
+ // Book keeping first then error handling.
+
+ done += n
+
+ if opts.Peek {
+ // For peek, we use the (first, last) range of segments returned from
+ // startRead. We don't consume the receive buffer, so commitRead should
+ // not be called.
+ //
+ // N.B. It is important to use `last` to determine the last segment, since
+ // appending can happen while we process, and would otherwise lead to a data race.
+ if s == last {
+ break
+ }
+ s = s.Next()
+ } else {
+ // N.B. commitRead() conveniently returns the next segment to read, after
+ // removing the data/segment that is read.
+ s = e.commitRead(n)
+ }
+
+ if err != nil {
+ break
+ }
+ }
+
+ // If something is read, we must report it. Report error when nothing is read.
+ if done == 0 && err != nil {
+ return tcpip.ReadResult{}, tcpip.ErrBadBuffer
+ }
+ return tcpip.ReadResult{
+ Count: done,
+ Total: done,
+ }, nil
+}
+
+// startRead checks that endpoint is in a readable state, and returns the
+// inclusive range of segments that can be read.
+//
+// Precondition: e.rcvReadMu must be held.
+func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
@@ -1324,7 +1393,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
// on a receive. It can expect to read any data after the handshake
// is complete. RFC793, section 3.9, p58.
if e.EndpointState() == StateSynSent {
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ return nil, nil, tcpip.ErrWouldBlock
}
// The endpoint can be read if it's connected, or if it's already closed
@@ -1332,61 +1401,69 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
// would cause the state to become StateError so we should allow the
// reads to proceed before returning a ECONNRESET.
e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
+
bufUsed := e.rcvBufUsed
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
- e.rcvListMu.Unlock()
if s == StateError {
if err := e.hardErrorLocked(); err != nil {
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return nil, nil, err
}
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
+ return nil, nil, tcpip.ErrClosedForReceive
}
e.stats.ReadErrors.NotConnected.Increment()
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected
+ return nil, nil, tcpip.ErrNotConnected
}
- v, err := e.readLocked()
- e.rcvListMu.Unlock()
-
- if err == tcpip.ErrClosedForReceive {
- e.stats.ReadErrors.ReadClosed.Increment()
- }
- return v, tcpip.ControlMessages{}, err
-}
-
-func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
if e.rcvBufUsed == 0 {
if e.rcvClosed || !e.EndpointState().connected() {
- return buffer.View{}, tcpip.ErrClosedForReceive
+ return nil, nil, tcpip.ErrClosedForReceive
}
- return buffer.View{}, tcpip.ErrWouldBlock
+ return nil, nil, tcpip.ErrWouldBlock
}
- s := e.rcvList.Front()
- views := s.data.Views()
- v := views[s.viewToDeliver]
- s.viewToDeliver++
+ return e.rcvList.Front(), e.rcvList.Back(), nil
+}
+
+// commitRead commits a read of done bytes and returns the next non-empty
+// segment to read. Data read from the segment must have also been removed from
+// the segment in order for this method to work correctly.
+//
+// It is performance critical to call commitRead frequently when servicing a big
+// Read request, so TCP can make timely progress. Right now, it is designed to
+// do this per segment read, hence this method conveniently returns the next
+// segment to read while holding the lock.
+//
+// Precondition: e.rcvReadMu must be held.
+func (e *endpoint) commitRead(done int) *segment {
+ e.LockUser()
+ defer e.UnlockUser()
+ e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
- var delta int
- if s.viewToDeliver >= len(views) {
+ memDelta := 0
+ s := e.rcvList.Front()
+ for s != nil && s.data.Size() == 0 {
e.rcvList.Remove(s)
- // We only free up receive buffer space when the segment is released as the
- // segment is still holding on to the views even though some views have been
- // read out to the user.
- delta = s.segMemSize()
+ // Memory is only considered released when the whole segment has been
+ // read.
+ memDelta += s.segMemSize()
s.decRef()
+ s = e.rcvList.Front()
}
+ e.rcvBufUsed -= done
- e.rcvBufUsed -= len(v)
- // If the window was small before this read and if the read freed up
- // enough buffer space, to either fit an aMSS or half a receive buffer
- // (whichever smaller), then notify the protocol goroutine to send a
- // window update.
- if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above {
- e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+ if memDelta > 0 {
+ // If the window was small before this read and if the read freed up
+ // enough buffer space, to either fit an aMSS or half a receive buffer
+ // (whichever smaller), then notify the protocol goroutine to send a
+ // window update.
+ if crossed, above := e.windowCrossedACKThresholdLocked(memDelta); crossed && above {
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+ }
}
- return v, nil
+ return e.rcvList.Front()
}
// isEndpointWritableLocked checks if a given endpoint is writable
@@ -1504,64 +1581,6 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return queueAndSend()
}
-// Peek reads data without consuming it from the endpoint.
-//
-// This method does not block if there is no data pending.
-func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- e.LockUser()
- defer e.UnlockUser()
-
- // The endpoint can be read if it's connected, or if it's already closed
- // but has some pending unread data.
- if s := e.EndpointState(); !s.connected() && s != StateClose {
- if s == StateError {
- return 0, tcpip.ControlMessages{}, e.hardErrorLocked()
- }
- e.stats.ReadErrors.InvalidEndpointState.Increment()
- return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
- }
-
- e.rcvListMu.Lock()
- defer e.rcvListMu.Unlock()
-
- if e.rcvBufUsed == 0 {
- if e.rcvClosed || !e.EndpointState().connected() {
- e.stats.ReadErrors.ReadClosed.Increment()
- return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
- }
- return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
- }
-
- // Make a copy of vec so we can modify the slide headers.
- vec = append([][]byte(nil), vec...)
-
- var num int64
- for s := e.rcvList.Front(); s != nil; s = s.Next() {
- views := s.data.Views()
-
- for i := s.viewToDeliver; i < len(views); i++ {
- v := views[i]
-
- for len(v) > 0 {
- if len(vec) == 0 {
- return num, tcpip.ControlMessages{}, nil
- }
- if len(vec[0]) == 0 {
- vec = vec[1:]
- continue
- }
-
- n := copy(vec[0], v)
- v = v[n:]
- vec[0] = vec[0][n:]
- num += int64(n)
- }
- }
- }
-
- return num, tcpip.ControlMessages{}, nil
-}
-
// selectWindowLocked returns the new window without checking for shrinking or scaling
// applied.
// Precondition: e.mu and e.rcvListMu must be held.
@@ -1650,56 +1669,20 @@ func (e *endpoint) OnKeepAliveSet(v bool) {
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
}
-// SetSockOptBool sets a socket option.
-func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- switch opt {
-
- case tcpip.CorkOption:
- e.LockUser()
- if !v {
- atomic.StoreUint32(&e.cork, 0)
-
- // Handle the corked data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.cork, 1)
- }
- e.UnlockUser()
-
- case tcpip.DelayOption:
- if v {
- atomic.StoreUint32(&e.delay, 1)
- } else {
- atomic.StoreUint32(&e.delay, 0)
-
- // Handle delayed data.
- e.sndWaker.Assert()
- }
-
- case tcpip.QuickAckOption:
- o := uint32(1)
- if v {
- o = 0
- }
- atomic.StoreUint32(&e.slowAck, o)
-
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrInvalidEndpointState
- }
-
- // We only allow this to be set when we're in the initial state.
- if e.EndpointState() != StateInitial {
- return tcpip.ErrInvalidEndpointState
- }
-
- e.LockUser()
- e.v6only = v
- e.UnlockUser()
+// OnDelayOptionSet implements tcpip.SocketOptionsHandler.OnDelayOptionSet.
+func (e *endpoint) OnDelayOptionSet(v bool) {
+ if !v {
+ // Handle delayed data.
+ e.sndWaker.Assert()
}
+}
- return nil
+// OnCorkOptionSet implements tcpip.SocketOptionsHandler.OnCorkOptionSet.
+func (e *endpoint) OnCorkOptionSet(v bool) {
+ if !v {
+ // Handle the corked data.
+ e.sndWaker.Assert()
+ }
}
// SetSockOptInt sets a socket option.
@@ -1859,18 +1842,13 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
+func (e *endpoint) HasNIC(id int32) bool {
+ return id == 0 || e.stack.HasNIC(tcpip.NICID(id))
+}
+
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
switch v := opt.(type) {
- case *tcpip.BindToDeviceOption:
- id := tcpip.NICID(*v)
- if id != 0 && !e.stack.HasNIC(id) {
- return tcpip.ErrUnknownDevice
- }
- e.LockUser()
- e.bindToDevice = id
- e.UnlockUser()
-
case *tcpip.KeepaliveIdleOption:
e.keepalive.Lock()
e.keepalive.idle = time.Duration(*v)
@@ -1883,9 +1861,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
e.keepalive.Unlock()
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- case *tcpip.OutOfBandInlineOption:
- // We don't currently support disabling this option.
-
case *tcpip.TCPUserTimeoutOption:
e.LockUser()
e.userTimeout = time.Duration(*v)
@@ -1954,11 +1929,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
case *tcpip.SocketDetachFilterOption:
return nil
- case *tcpip.LingerOption:
- e.LockUser()
- e.linger = *v
- e.UnlockUser()
-
default:
return nil
}
@@ -1981,47 +1951,6 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
return e.rcvBufUsed, nil
}
-// IsListening implements tcpip.SocketOptionsHandler.IsListening.
-func (e *endpoint) IsListening() bool {
- e.LockUser()
- defer e.UnlockUser()
- return e.EndpointState() == StateListen
-}
-
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
-
- case tcpip.CorkOption:
- return atomic.LoadUint32(&e.cork) != 0, nil
-
- case tcpip.DelayOption:
- return atomic.LoadUint32(&e.delay) != 0, nil
-
- case tcpip.QuickAckOption:
- v := atomic.LoadUint32(&e.slowAck) == 0
- return v, nil
-
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return false, tcpip.ErrUnknownProtocolOption
- }
-
- e.LockUser()
- v := e.v6only
- e.UnlockUser()
-
- return v, nil
-
- case tcpip.MulticastLoopOption:
- return true, nil
-
- default:
- return false, tcpip.ErrUnknownProtocolOption
- }
-}
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
@@ -2100,11 +2029,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
switch o := opt.(type) {
- case *tcpip.BindToDeviceOption:
- e.LockUser()
- *o = tcpip.BindToDeviceOption(e.bindToDevice)
- e.UnlockUser()
-
case *tcpip.TCPInfoOption:
*o = tcpip.TCPInfoOption{}
e.LockUser()
@@ -2132,10 +2056,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
*o = tcpip.TCPUserTimeoutOption(e.userTimeout)
e.UnlockUser()
- case *tcpip.OutOfBandInlineOption:
- // We don't currently support disabling this option.
- *o = 1
-
case *tcpip.CongestionControlOption:
e.LockUser()
*o = e.cc
@@ -2164,11 +2084,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
Port: port,
}
- case *tcpip.LingerOption:
- e.LockUser()
- *o = e.linger
- e.UnlockUser()
-
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -2178,7 +2093,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
// checkV4MappedLocked determines the effective network protocol and converts
// addr to its canonical form.
func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
}
@@ -2316,11 +2231,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
}
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
if sameAddr && p == e.ID.RemotePort {
return false, nil
}
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil {
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil {
if err != tcpip.ErrPortInUse || !reuse {
return false, nil
}
@@ -2358,15 +2274,15 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
tcpEP.notifyProtocolGoroutine(notifyAbort)
tcpEP.UnlockUser()
// Now try and Reserve again if it fails then we skip.
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil {
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil {
return false, nil
}
}
id := e.ID
id.LocalPort = p
- if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr)
+ if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr)
if err == tcpip.ErrPortInUse {
return false, nil
}
@@ -2377,7 +2293,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
// the selected port.
e.ID = id
e.isPortReserved = true
- e.boundBindToDevice = e.bindToDevice
+ e.boundBindToDevice = bindToDevice
e.boundPortFlags = e.portFlags
e.boundDest = addr
return true, nil
@@ -2388,7 +2304,8 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
e.isRegistered = true
e.setEndpointState(StateConnecting)
- e.route = r.Clone()
+ r.Acquire()
+ e.route = r
e.boundNICID = nicID
e.effectiveNetProtos = netProtos
e.connectingAddress = connectingAddr
@@ -2712,7 +2629,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// v6only set to false.
if netProto == header.IPv6ProtocolNumber {
stackHasV4 := e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber)
- alsoBindToV4 := !e.v6only && addr.Addr == "" && stackHasV4
+ alsoBindToV4 := !e.ops.GetV6Only() && addr.Addr == "" && stackHasV4
if alsoBindToV4 {
netProtos = append(netProtos, header.IPv4ProtocolNumber)
}
@@ -2729,7 +2646,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
e.ID.LocalAddress = addr.Addr
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, func(p uint16) bool {
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, bindToDevice, tcpip.FullAddress{}, func(p uint16) bool {
id := e.ID
id.LocalPort = p
// CheckRegisterTransportEndpoint should only return an error if there is a
@@ -2740,7 +2658,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// demuxer. Further connected endpoints always have a remote
// address/port. Hence this will only return an error if there is a matching
// listening endpoint.
- if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, e.bindToDevice); err != nil {
+ if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil {
return false
}
return true
@@ -2749,7 +2667,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
return err
}
- e.boundBindToDevice = e.bindToDevice
+ e.boundBindToDevice = bindToDevice
e.boundPortFlags = e.portFlags
// TODO(gvisor.dev/issue/3691): Add test to verify boundNICID is correct.
e.boundNICID = nic
@@ -2813,6 +2731,41 @@ func (e *endpoint) enqueueSegment(s *segment) bool {
return true
}
+func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) {
+ // Update last error first.
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ // Update the error queue if IP_RECVERR is enabled.
+ if e.SocketOptions().GetRecvError() {
+ e.SocketOptions().QueueErr(&tcpip.SockError{
+ Err: err,
+ ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber),
+ ErrType: errType,
+ ErrCode: errCode,
+ ErrInfo: extra,
+ // Linux passes the payload with the TCP header. We don't know if the TCP
+ // header even exists, it may not for fragmented packets.
+ Payload: pkt.Data.ToView(),
+ Dst: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.RemoteAddress,
+ Port: id.RemotePort,
+ },
+ Offender: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.LocalAddress,
+ Port: id.LocalPort,
+ },
+ NetProto: pkt.NetworkProtocolNumber,
+ })
+ }
+
+ // Notify of the error.
+ e.notifyProtocolGoroutine(notifyError)
+}
+
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
switch typ {
@@ -2827,16 +2780,10 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C
e.notifyProtocolGoroutine(notifyMTUChanged)
case stack.ControlNoRoute:
- e.lastErrorMu.Lock()
- e.lastError = tcpip.ErrNoRoute
- e.lastErrorMu.Unlock()
- e.notifyProtocolGoroutine(notifyError)
+ e.onICMPError(tcpip.ErrNoRoute, id, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt)
case stack.ControlNetworkUnreachable:
- e.lastErrorMu.Lock()
- e.lastError = tcpip.ErrNetworkUnreachable
- e.lastErrorMu.Unlock()
- e.notifyProtocolGoroutine(notifyError)
+ e.onICMPError(tcpip.ErrNetworkUnreachable, id, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt)
}
}
@@ -3094,6 +3041,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
Ssthresh: e.snd.sndSsthresh,
SndCAAckCount: e.snd.sndCAAckCount,
Outstanding: e.snd.outstanding,
+ SackedOut: e.snd.sackedOut,
SndWnd: e.snd.sndWnd,
SndUna: e.snd.sndUna,
SndNxt: e.snd.sndNxt,
@@ -3176,7 +3124,7 @@ func (e *endpoint) State() uint32 {
func (e *endpoint) Info() tcpip.EndpointInfo {
e.LockUser()
// Make a copy of the endpoint info.
- ret := e.EndpointInfo
+ ret := e.TransportEndpointInfo
e.UnlockUser()
return &ret
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 2329aca4b..672159eed 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -250,7 +250,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error
ttl = route.DefaultTTL()
}
- return sendTCP(&route, tcpFields{
+ return sendTCP(route, tcpFields{
id: s.id,
ttl: ttl,
tos: tos,
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 8e0b7c843..405a6dce7 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -16,6 +16,7 @@ package tcp
import (
"container/heap"
+ "math"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -48,6 +49,10 @@ type receiver struct {
rcvWndScale uint8
+ // prevBufUsed is the snapshot of endpoint rcvBufUsed taken when we
+ // advertise a receive window.
+ prevBufUsed int
+
closed bool
// pendingRcvdSegments is bounded by the receive buffer size of the
@@ -80,9 +85,9 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
// outgoing packets, we should use what we have advertised for acceptability
// test.
scaledWindowSize := r.rcvWnd >> r.rcvWndScale
- if scaledWindowSize > 0xffff {
+ if scaledWindowSize > math.MaxUint16 {
// This is what we actually put in the Window field.
- scaledWindowSize = 0xffff
+ scaledWindowSize = math.MaxUint16
}
advertisedWindowSize := scaledWindowSize << r.rcvWndScale
return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize))
@@ -106,6 +111,34 @@ func (r *receiver) currentWindow() (curWnd seqnum.Size) {
func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
newWnd := r.ep.selectWindow()
curWnd := r.currentWindow()
+ unackLen := int(r.ep.snd.maxSentAck.Size(r.rcvNxt))
+ bufUsed := r.ep.receiveBufferUsed()
+
+ // Grow the right edge of the window only for payloads larger than
+ // the segment overhead OR if the application is actively consuming data.
+ //
+ // Avoiding growing the right edge otherwise addresses the situation below:
+ // An application has been slow in reading data and we have a burst of
+ // incoming segments with lengths < segment overhead. Here, our available
+ // free memory would reduce drastically when compared to the advertised
+ // receive window.
+ //
+ // For example: With incoming 512 bytes segments, segment overhead of
+ // 552 bytes (at the time of writing this comment), with receive window
+ // starting from 1MB and with rcvAdvWndScale being 1, buffer would reach 0
+ // when the curWnd is still 19436 bytes, because for every incoming segment
+ // newWnd would reduce by (552+512) >> rcvAdvWndScale (current value 1),
+ // while curWnd would reduce by 512 bytes.
+ // Such a situation causes us to keep tail dropping the incoming segments
+ // and never advertise zero receive window to the peer.
+ //
+ // Linux does a similar check for minimal sk_buff size (128):
+ // https://github.com/torvalds/linux/blob/d5beb3140f91b1c8a3d41b14d729aefa4dcc58bc/net/ipv4/tcp_input.c#L783
+ //
+ // Also, if the application is reading the data, we keep growing the right
+ // edge, as we are still advertising a window that we think can be serviced.
+ toGrow := unackLen >= SegSize || bufUsed <= r.prevBufUsed
+
// Update rcvAcc only if new window is > previously advertised window. We
// should never shrink the acceptable sequence space once it has been
// advertised the peer. If we shrink the acceptable sequence space then we
@@ -115,7 +148,7 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// rcvWUP rcvNxt rcvAcc new rcvAcc
// <=====curWnd ===>
// <========= newWnd > curWnd ========= >
- if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) {
+ if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) && toGrow {
// If the new window moves the right edge, then update rcvAcc.
r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd))
} else {
@@ -130,11 +163,22 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// receiver's estimated RTT.
r.rcvWnd = newWnd
r.rcvWUP = r.rcvNxt
+ r.prevBufUsed = bufUsed
scaledWnd := r.rcvWnd >> r.rcvWndScale
if scaledWnd == 0 {
// Increment a metric if we are advertising an actual zero window.
r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
}
+
+ // If we started off with a window larger than what can be held in
+ // the 16-bit window field, we cap the value at the max value.
+ if scaledWnd > math.MaxUint16 {
+ scaledWnd = seqnum.Size(math.MaxUint16)
+
+ // Ensure that the stashed receive window always reflects what
+ // is being advertised.
+ r.rcvWnd = scaledWnd << r.rcvWndScale
+ }
return r.rcvNxt, scaledWnd
}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 2091989cc..c5a6d2fba 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -37,7 +37,7 @@ const (
// segment represents a TCP segment. It holds the payload and parsed TCP segment
// information, and can be added to intrusive lists.
-// segment is mostly immutable, the only field allowed to change is viewToDeliver.
+// segment is mostly immutable, the only field allowed to change is data.
//
// +stateify savable
type segment struct {
@@ -60,10 +60,7 @@ type segment struct {
hdr header.TCP
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
- views [8]buffer.View `state:"nosave"`
- // viewToDeliver keeps track of the next View that should be
- // delivered by the Read endpoint.
- viewToDeliver int
+ views [8]buffer.View `state:"nosave"`
sequenceNumber seqnum.Value
ackNumber seqnum.Value
flags uint8
@@ -84,6 +81,9 @@ type segment struct {
// acked indicates if the segment has already been SACKed.
acked bool
+
+ // dataMemSize is the memory used by data initially.
+ dataMemSize int
}
func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
@@ -100,6 +100,7 @@ func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *
s.data = pkt.Data.Clone(s.views[:])
s.hdr = header.TCP(pkt.TransportHeader().View())
s.rcvdTime = time.Now()
+ s.dataMemSize = s.data.Size()
return s
}
@@ -113,6 +114,7 @@ func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment {
s.views[0] = v
s.data = buffer.NewVectorisedView(len(v), s.views[:1])
}
+ s.dataMemSize = s.data.Size()
return s
}
@@ -127,12 +129,12 @@ func (s *segment) clone() *segment {
netProto: s.netProto,
nicID: s.nicID,
remoteLinkAddr: s.remoteLinkAddr,
- viewToDeliver: s.viewToDeliver,
rcvdTime: s.rcvdTime,
xmitTime: s.xmitTime,
xmitCount: s.xmitCount,
ep: s.ep,
qFlags: s.qFlags,
+ dataMemSize: s.dataMemSize,
}
t.data = s.data.Clone(t.views[:])
return t
@@ -204,7 +206,7 @@ func (s *segment) payloadSize() int {
// segMemSize is the amount of memory used to hold the segment data and
// the associated metadata.
func (s *segment) segMemSize() int {
- return segSize + s.data.Size()
+ return SegSize + s.dataMemSize
}
// parse populates the sequence & ack numbers, flags, and window fields of the
diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go
index 7dc2741a6..7422d8c02 100644
--- a/pkg/tcpip/transport/tcp/segment_state.go
+++ b/pkg/tcpip/transport/tcp/segment_state.go
@@ -24,16 +24,11 @@ import (
func (s *segment) saveData() buffer.VectorisedView {
// We cannot save s.data directly as s.data.views may alias to s.views,
// which is not allowed by state framework (in-struct pointer).
- v := make([]buffer.View, len(s.data.Views()))
- // For views already delivered, we cannot save them directly as they may
- // have already been sliced and saved elsewhere (e.g., readViews).
- for i := 0; i < s.viewToDeliver; i++ {
- v[i] = append([]byte(nil), s.data.Views()[i]...)
+ vs := make([]buffer.View, len(s.data.Views()))
+ for i, v := range s.data.Views() {
+ vs[i] = v
}
- for i := s.viewToDeliver; i < len(v); i++ {
- v[i] = s.data.Views()[i]
- }
- return buffer.NewVectorisedView(s.data.Size(), v)
+ return buffer.NewVectorisedView(s.data.Size(), vs)
}
// loadData is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/segment_unsafe.go b/pkg/tcpip/transport/tcp/segment_unsafe.go
index 0ab7b8f56..392ff0859 100644
--- a/pkg/tcpip/transport/tcp/segment_unsafe.go
+++ b/pkg/tcpip/transport/tcp/segment_unsafe.go
@@ -19,5 +19,6 @@ import (
)
const (
- segSize = int(unsafe.Sizeof(segment{}))
+ // SegSize is the minimal size of the segment overhead.
+ SegSize = int(unsafe.Sizeof(segment{}))
)
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 5ed9f7ace..cc991aba6 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -18,7 +18,6 @@ import (
"fmt"
"math"
"sort"
- "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sleep"
@@ -138,6 +137,9 @@ type sender struct {
// that have been sent but not yet acknowledged.
outstanding int
+ // sackedOut is the number of packets which are selectively acked.
+ sackedOut int
+
// sndWnd is the send window size.
sndWnd seqnum.Size
@@ -373,6 +375,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
m = 1
}
+ oldMSS := s.maxPayloadSize
s.maxPayloadSize = m
if s.gso {
s.ep.gso.MSS = uint16(m)
@@ -395,6 +398,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
// Rewind writeNext to the first segment exceeding the MTU. Do nothing
// if it is already before such a packet.
+ nextSeg := s.writeNext
for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
if seg == s.writeNext {
// We got to writeNext before we could find a segment
@@ -402,16 +406,22 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
break
}
- if seg.data.Size() > m {
+ if nextSeg == s.writeNext && seg.data.Size() > m {
// We found a segment exceeding the MTU. Rewind
// writeNext and try to retransmit it.
- s.writeNext = seg
- break
+ nextSeg = seg
+ }
+
+ if s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ // Update sackedOut for new maximum payload size.
+ s.sackedOut -= s.pCount(seg, oldMSS)
+ s.sackedOut += s.pCount(seg, s.maxPayloadSize)
}
}
// Since we likely reduced the number of outstanding packets, we may be
// ready to send some more.
+ s.writeNext = nextSeg
s.sendData()
}
@@ -630,13 +640,13 @@ func (s *sender) retransmitTimerExpired() bool {
// pCount returns the number of packets in the segment. Due to GSO, a segment
// can be composed of multiple packets.
-func (s *sender) pCount(seg *segment) int {
+func (s *sender) pCount(seg *segment, maxPayloadSize int) int {
size := seg.data.Size()
if size == 0 {
return 1
}
- return (size-1)/s.maxPayloadSize + 1
+ return (size-1)/maxPayloadSize + 1
}
// splitSeg splits a given segment at the size specified and inserts the
@@ -813,7 +823,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
}
if !nextTooBig && seg.data.Size() < available {
// Segment is not full.
- if s.outstanding > 0 && atomic.LoadUint32(&s.ep.delay) != 0 {
+ if s.outstanding > 0 && s.ep.ops.GetDelayOption() {
// Nagle's algorithm. From Wikipedia:
// Nagle's algorithm works by
// combining a number of small
@@ -832,7 +842,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// send space and MSS.
// TODO(gvisor.dev/issue/2833): Drain the held segments after a
// timeout.
- if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 {
+ if seg.data.Size() < s.maxPayloadSize && s.ep.ops.GetCorkOption() {
return false
}
}
@@ -1024,7 +1034,7 @@ func (s *sender) sendData() {
break
}
dataSent = true
- s.outstanding += s.pCount(seg)
+ s.outstanding += s.pCount(seg, s.maxPayloadSize)
s.writeNext = seg.Next()
}
@@ -1039,6 +1049,7 @@ func (s *sender) enterRecovery() {
// We inflate the cwnd by 3 to account for the 3 packets which triggered
// the 3 duplicate ACKs and are now not in flight.
s.sndCwnd = s.sndSsthresh + 3
+ s.sackedOut = 0
s.fr.first = s.sndUna
s.fr.last = s.sndNxt - 1
s.fr.maxCwnd = s.sndCwnd + s.outstanding
@@ -1208,6 +1219,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
s.rc.detectReorder(seg)
seg.acked = true
+ s.sackedOut += s.pCount(seg, s.maxPayloadSize)
}
seg = seg.Next()
}
@@ -1381,10 +1393,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
datalen := seg.logicalLen()
if datalen > ackLeft {
- prevCount := s.pCount(seg)
+ prevCount := s.pCount(seg, s.maxPayloadSize)
seg.data.TrimFront(int(ackLeft))
seg.sequenceNumber.UpdateForward(ackLeft)
- s.outstanding -= prevCount - s.pCount(seg)
+ s.outstanding -= prevCount - s.pCount(seg, s.maxPayloadSize)
break
}
@@ -1400,11 +1412,13 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
s.writeList.Remove(seg)
- // If SACK is enabled then Only reduce outstanding if
+ // If SACK is enabled then only reduce outstanding if
// the segment was not previously SACKED as these have
// already been accounted for in SetPipe().
if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
- s.outstanding -= s.pCount(seg)
+ s.outstanding -= s.pCount(seg, s.maxPayloadSize)
+ } else {
+ s.sackedOut -= s.pCount(seg, s.maxPayloadSize)
}
seg.decRef()
ackLeft -= datalen
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index ef7f5719f..faf0c0ad7 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -590,3 +590,45 @@ func TestSACKRecovery(t *testing.T) {
expected++
}
}
+
+// TestSACKUpdateSackedOut tests that the SackedOut field is updated when a
+// SACK is received.
+func TestSACKUpdateSackedOut(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan struct{})
+ ackNum := 0
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that the endpoint Sender.SackedOut is what we expect.
+ if state.Sender.SackedOut != 2 && ackNum == 0 {
+ t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut)
+ }
+
+ if state.Sender.SackedOut != 0 && ackNum == 1 {
+ t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut)
+ }
+ if ackNum > 0 {
+ close(probeDone)
+ }
+ ackNum++
+ })
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+
+ sendAndReceive(t, c, 8)
+
+ // ACK for [3-5] packets.
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload))
+ bytesRead := 2 * maxPayload
+ end := start.Add(seqnum.Size(bytesRead))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ bytesRead += 3 * maxPayload
+ c.SendAck(seq, bytesRead)
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ <-probeDone
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 7124a715d..9fa4672d7 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -17,10 +17,12 @@ package tcp_test
import (
"bytes"
"fmt"
+ "io/ioutil"
"math"
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -40,6 +42,64 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// endpointTester provides helper functions to test a tcpip.Endpoint.
+type endpointTester struct {
+ ep tcpip.Endpoint
+}
+
+// CheckReadError issues a read to the endpoint and checks for an error.
+func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) {
+ t.Helper()
+ res, got := e.ep.Read(ioutil.Discard, 1, tcpip.ReadOptions{})
+ if got != want {
+ t.Fatalf("ep.Read = %s, want %s", got, want)
+ }
+ if diff := cmp.Diff(tcpip.ReadResult{}, res); diff != "" {
+ t.Errorf("ep.Read: unexpected non-zero result (-want +got):\n%s", diff)
+ }
+}
+
+// CheckRead issues a read to the endpoint and checks for success, returning
+// the data read.
+func (e *endpointTester) CheckRead(t *testing.T, count int) []byte {
+ t.Helper()
+ var buf bytes.Buffer
+ res, err := e.ep.Read(&buf, count, tcpip.ReadOptions{})
+ if err != nil {
+ t.Fatalf("ep.Read = _, %s; want _, nil", err)
+ }
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
+ }
+ return buf.Bytes()
+}
+
+// CheckReadFull reads from the endpoint for exactly count bytes.
+func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-chan struct{}, timeout time.Duration) []byte {
+ t.Helper()
+ var buf bytes.Buffer
+ var done int
+ for done < count {
+ res, err := e.ep.Read(&buf, count-done, tcpip.ReadOptions{})
+ if err == tcpip.ErrWouldBlock {
+ // Wait for receive to be notified.
+ select {
+ case <-notifyRead:
+ case <-time.After(timeout):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+ continue
+ } else if err != nil {
+ t.Fatalf("ep.Read = _, %s; want _, nil", err)
+ }
+ done += res.Count
+ }
+ return buf.Bytes()
+}
+
const (
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
@@ -264,7 +324,7 @@ func TestTCPResetsSentNoICMP(t *testing.T) {
}
// Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded.
- sent := stats.ICMP.V4PacketsSent
+ sent := stats.ICMP.V4.PacketsSent
if got, want := sent.DstUnreachable.Value(), uint64(0); got != want {
t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want)
}
@@ -740,9 +800,7 @@ func TestSimpleReceive(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -762,11 +820,7 @@ func TestSimpleReceive(t *testing.T) {
}
// Receive data.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
-
+ v := ept.CheckRead(t, defaultMTU)
if !bytes.Equal(data, v) {
t.Fatalf("got data = %v, want = %v", v, data)
}
@@ -1380,9 +1434,8 @@ func TestConnectBindToDevice(t *testing.T) {
defer c.Cleanup()
c.Create(-1)
- bindToDevice := tcpip.BindToDeviceOption(test.device)
- if err := c.EP.SetSockOpt(&bindToDevice); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", bindToDevice, bindToDevice, err)
+ if err := c.EP.SocketOptions().SetBindToDevice(int32(test.device)); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", test.device, test.device, err)
}
// Start connection attempt.
waitEntry, _ := waiter.NewChannelEntry(nil)
@@ -1493,14 +1546,11 @@ func TestSynSent(t *testing.T) {
t.Fatal("timed out waiting for packet to arrive")
}
+ ept := endpointTester{c.EP}
if test.reset {
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused)
- }
+ ept.CheckReadError(t, tcpip.ErrConnectionRefused)
} else {
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrAborted {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrAborted)
- }
+ ept.CheckReadError(t, tcpip.ErrAborted)
}
if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
@@ -1525,9 +1575,8 @@ func TestOutOfOrderReceive(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Send second half of data first, with seqnum 3 ahead of expected.
data := []byte{1, 2, 3, 4, 5, 6}
@@ -1552,9 +1601,7 @@ func TestOutOfOrderReceive(t *testing.T) {
// Wait 200ms and check that no data has been received.
time.Sleep(200 * time.Millisecond)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Send the first 3 bytes now.
c.SendPacket(data[:3], &context.Headers{
@@ -1567,24 +1614,7 @@ func TestOutOfOrderReceive(t *testing.T) {
})
// Receive data.
- read := make([]byte, 0, 6)
- for len(read) < len(data) {
- v, _, err := c.EP.Read(nil)
- if err != nil {
- if err == tcpip.ErrWouldBlock {
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
- continue
- }
- t.Fatalf("Read failed: %s", err)
- }
-
- read = append(read, v...)
- }
+ read := ept.CheckReadFull(t, 6, ch, 5*time.Second)
// Check that we received the data in proper order.
if !bytes.Equal(data, read) {
@@ -1609,9 +1639,8 @@ func TestOutOfOrderFlood(t *testing.T) {
rcvBufSz := math.MaxUint16
c.CreateConnected(789, 30000, rcvBufSz)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Send 100 packets before the actual one that is expected.
data := []byte{1, 2, 3, 4, 5, 6}
@@ -1686,9 +1715,8 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -1755,9 +1783,8 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -1838,17 +1865,14 @@ func TestShutdownRead(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
t.Fatalf("Shutdown failed: %s", err)
}
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
- }
+ ept.CheckReadError(t, tcpip.ErrClosedForReceive)
var want uint64 = 1
if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want)
@@ -1866,10 +1890,8 @@ func TestFullWindowReceive(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- _, _, err := c.EP.Read(nil)
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %s", err)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies
// the provided buffer value by tcp.SegOverheadFactor to calculate the actual
@@ -1906,11 +1928,7 @@ func TestFullWindowReceive(t *testing.T) {
)
// Receive data and check it.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
-
+ v := ept.CheckRead(t, defaultMTU)
if !bytes.Equal(data, v) {
t.Fatalf("got data = %v, want = %v", v, data)
}
@@ -1932,6 +1950,85 @@ func TestFullWindowReceive(t *testing.T) {
)
}
+// Test the stack receive window advertisement on receiving segments smaller than
+// segment overhead. It tests for the right edge of the window to not grow when
+// the endpoint is not being read from.
+func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ opt := tcpip.TCPReceiveBufferSizeRangeOption{
+ Min: 1,
+ Default: tcp.DefaultReceiveBufferSize,
+ Max: tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)),
+ }
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
+
+ c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+ // Bump up the receive buffer size such that, when the receive window grows,
+ // the scaled window exceeds maxUint16.
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, opt.Max); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", opt.Max, err)
+ }
+
+ // Keep the payload size < segment overhead and such that it is a multiple
+ // of the window scaled value. This enables the test to perform equality
+ // checks on the incoming receive window.
+ payload := generateRandomPayload(t, (tcp.SegSize-1)&(1<<c.RcvdWindowScale))
+ payloadLen := seqnum.Size(len(payload))
+ iss := seqnum.Value(789)
+ seqNum := iss.Add(1)
+
+ // Send payload to the endpoint and return the advertised receive window
+ // from the endpoint.
+ getIncomingRcvWnd := func() uint32 {
+ c.SendPacket(payload, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ SeqNum: seqNum,
+ AckNum: c.IRS.Add(1),
+ Flags: header.TCPFlagAck,
+ RcvWnd: 30000,
+ })
+ seqNum = seqNum.Add(payloadLen)
+
+ pkt := c.GetPacket()
+ return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale
+ }
+
+ // Read the advertised receive window with the ACK for payload.
+ rcvWnd := getIncomingRcvWnd()
+
+ // Check if the subsequent ACK to our send has not grown the right edge of
+ // the window.
+ if got, want := getIncomingRcvWnd(), rcvWnd-uint32(len(payload)); got != want {
+ t.Fatalf("got incomingRcvwnd %d want %d", got, want)
+ }
+
+ // Read the data so that the subsequent ACK from the endpoint
+ // grows the right edge of the window.
+ var buf bytes.Buffer
+ if _, err := c.EP.Read(&buf, math.MaxUint16, tcpip.ReadOptions{}); err != nil {
+ t.Fatalf("c.EP.Read: %s", err)
+ }
+
+ // Check if we have received max uint16 as our advertised
+ // scaled window now after a read above.
+ maxRcv := uint32(math.MaxUint16 << c.RcvdWindowScale)
+ if got, want := getIncomingRcvWnd(), maxRcv; got != want {
+ t.Fatalf("got incomingRcvwnd %d want %d", got, want)
+ }
+
+ // Check if the subsequent ACK to our send has not grown the right edge of
+ // the window.
+ if got, want := getIncomingRcvWnd(), maxRcv-uint32(len(payload)); got != want {
+ t.Fatalf("got incomingRcvwnd %d want %d", got, want)
+ }
+}
+
func TestNoWindowShrinking(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -1950,9 +2047,9 @@ func TestNoWindowShrinking(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
+
// Send a 1 byte payload so that we can record the current receive window.
// Send a payload of half the size of rcvBufSize.
seqNum := iss.Add(1)
@@ -1974,11 +2071,7 @@ func TestNoWindowShrinking(t *testing.T) {
}
// Read the 1 byte payload we just sent.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
- if got, want := payload, v; !bytes.Equal(got, want) {
+ if got, want := payload, ept.CheckRead(t, 1); !bytes.Equal(got, want) {
t.Fatalf("got data: %v, want: %v", got, want)
}
@@ -2051,24 +2144,8 @@ func TestNoWindowShrinking(t *testing.T) {
),
)
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
// Receive data and check it.
- read := make([]byte, 0, rcvBufSize)
- for len(read) < len(data) {
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
-
- read = append(read, v...)
- }
-
+ read := ept.CheckReadFull(t, len(data), ch, 5*time.Second)
if !bytes.Equal(data, read) {
t.Fatalf("got data = %v, want = %v", read, data)
}
@@ -2492,11 +2569,11 @@ func TestZeroScaledWindowReceive(t *testing.T) {
// we need to read at 3 packets.
sz := 0
for sz < defaultMTU*2 {
- v, _, err := c.EP.Read(nil)
+ res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Read failed: %s", err)
}
- sz += len(v)
+ sz += res.Count
}
checker.IPv4(t, c.GetPacket(),
@@ -2529,10 +2606,10 @@ func TestSegmentMerging(t *testing.T) {
{
"cork",
func(ep tcpip.Endpoint) {
- ep.SetSockOptBool(tcpip.CorkOption, true)
+ ep.SocketOptions().SetCorkOption(true)
},
func(ep tcpip.Endpoint) {
- ep.SetSockOptBool(tcpip.CorkOption, false)
+ ep.SocketOptions().SetCorkOption(false)
},
},
}
@@ -2624,7 +2701,7 @@ func TestDelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOptBool(tcpip.DelayOption, true)
+ c.EP.SocketOptions().SetDelayOption(true)
var allData []byte
for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
@@ -2672,7 +2749,7 @@ func TestUndelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOptBool(tcpip.DelayOption, true)
+ c.EP.SocketOptions().SetDelayOption(true)
allData := [][]byte{{0}, {1, 2, 3}}
for i, data := range allData {
@@ -2705,7 +2782,7 @@ func TestUndelay(t *testing.T) {
// Check that we don't get the second packet yet.
c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond)
- c.EP.SetSockOptBool(tcpip.DelayOption, false)
+ c.EP.SocketOptions().SetDelayOption(false)
// Check that data is received.
second := c.GetPacket()
@@ -2742,8 +2819,8 @@ func TestMSSNotDelayed(t *testing.T) {
fn func(tcpip.Endpoint)
}{
{"no-op", func(tcpip.Endpoint) {}},
- {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.DelayOption, true) }},
- {"cork", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.CorkOption, true) }},
+ {"delay", func(ep tcpip.Endpoint) { ep.SocketOptions().SetDelayOption(true) }},
+ {"cork", func(ep tcpip.Endpoint) { ep.SocketOptions().SetCorkOption(true) }},
}
for _, test := range tests {
@@ -3191,13 +3268,13 @@ func TestReceiveOnResetConnection(t *testing.T) {
loop:
for {
- switch _, _, err := c.EP.Read(nil); err {
+ switch _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}); err {
case tcpip.ErrWouldBlock:
select {
case <-ch:
// Expect the state to be StateError and subsequent Reads to fail with HardError.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
+ if _, err := c.EP.Read(ioutil.Discard, math.MaxUint16, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset {
+ t.Fatalf("got c.EP.Read() = %s, want = %s", err, tcpip.ErrConnectionReset)
}
break loop
case <-time.After(1 * time.Second):
@@ -4087,9 +4164,8 @@ func TestReadAfterClosedState(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Shutdown immediately for write, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -4147,35 +4223,31 @@ func TestReadAfterClosedState(t *testing.T) {
}
// Check that peek works.
- peekBuf := make([]byte, 10)
- n, _, err := c.EP.Peek([][]byte{peekBuf})
+ var peekBuf bytes.Buffer
+ res, err := c.EP.Read(&peekBuf, 10, tcpip.ReadOptions{Peek: true})
if err != nil {
t.Fatalf("Peek failed: %s", err)
}
- peekBuf = peekBuf[:n]
- if !bytes.Equal(data, peekBuf) {
- t.Fatalf("got data = %v, want = %v", peekBuf, data)
+ if got, want := res.Count, len(data); got != want {
+ t.Fatalf("res.Count = %d, want %d", got, want)
}
-
- // Receive data.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
+ if !bytes.Equal(data, peekBuf.Bytes()) {
+ t.Fatalf("got data = %v, want = %v", peekBuf.Bytes(), data)
}
+ // Receive data.
+ v := ept.CheckRead(t, defaultMTU)
if !bytes.Equal(data, v) {
t.Fatalf("got data = %v, want = %v", v, data)
}
// Now that we drained the queue, check that functions fail with the
// right error code.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
- }
-
- if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive)
+ ept.CheckReadError(t, tcpip.ErrClosedForReceive)
+ var buf bytes.Buffer
+ if _, err := c.EP.Read(&buf, 1, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("c.EP.Read(_, _, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive)
}
}
@@ -4429,7 +4501,7 @@ func TestBindToDeviceOption(t *testing.T) {
name string
setBindToDevice *tcpip.NICID
setBindToDeviceError *tcpip.Error
- getBindToDevice tcpip.BindToDeviceOption
+ getBindToDevice int32
}{
{"GetDefaultValue", nil, nil, 0},
{"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
@@ -4439,15 +4511,13 @@ func TestBindToDeviceOption(t *testing.T) {
for _, testAction := range testActions {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
- bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ bindToDevice := int32(*testAction.setBindToDevice)
+ if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr)
}
}
- bindToDevice := tcpip.BindToDeviceOption(88888)
- if err := ep.GetSockOpt(&bindToDevice); err != nil {
- t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err)
- } else if bindToDevice != testAction.getBindToDevice {
+ bindToDevice := ep.SocketOptions().GetBindToDevice()
+ if bindToDevice != testAction.getBindToDevice {
t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice)
}
})
@@ -4544,17 +4614,8 @@ func TestSelfConnect(t *testing.T) {
// Read back what was written.
wq.EventUnregister(&waitEntry)
wq.EventRegister(&waitEntry, waiter.EventIn)
- rd, _, err := ep.Read(nil)
- if err != nil {
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %s", err)
- }
- <-notifyCh
- rd, _, err = ep.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
- }
+ ept := endpointTester{ep}
+ rd := ept.CheckReadFull(t, len(data), notifyCh, 5*time.Second)
if !bytes.Equal(data, rd) {
t.Fatalf("got data = %v, want = %v", rd, data)
@@ -4642,13 +4703,9 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
switch network {
case "ipv4":
case "ipv6":
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOptBool(V6OnlyOption(true)) failed: %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
case "dual":
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil {
- t.Fatalf("SetSockOptBool(V6OnlyOption(false)) failed: %s", err)
- }
+ ep.SocketOptions().SetV6Only(false)
default:
t.Fatalf("unknown network: '%s'", network)
}
@@ -5011,9 +5068,8 @@ func TestKeepalive(t *testing.T) {
}
// Check that the connection is still alive.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Send some data and wait before ACKing it. Keepalives should be disabled
// during this period.
@@ -5102,9 +5158,7 @@ func TestKeepalive(t *testing.T) {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
}
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
- }
+ ept.CheckReadError(t, tcpip.ErrTimeout)
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
@@ -5999,9 +6053,8 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
- if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected {
- t.Errorf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrNotConnected)
- }
+ ept := endpointTester{ep}
+ ept.CheckReadError(t, tcpip.ErrNotConnected)
if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1)
}
@@ -6102,10 +6155,13 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Introduce a 25ms latency by delaying the first byte.
latency := 25 * time.Millisecond
time.Sleep(latency)
- rawEP.SendPacketWithTS([]byte{1}, tsVal)
+ // Send an initial payload with at least segment overhead size. The receive
+ // window would not grow for smaller segments.
+ rawEP.SendPacketWithTS(make([]byte, tcp.SegSize), tsVal)
pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize()
+
time.Sleep(25 * time.Millisecond)
// Allocate a large enough payload for the test.
@@ -6153,7 +6209,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Now read all the data from the endpoint and verify that advertised
// window increases to the full available buffer size.
for {
- _, _, err := c.EP.Read(nil)
+ _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
if err == tcpip.ErrWouldBlock {
break
}
@@ -6277,11 +6333,11 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// to happen before we measure the new window.
totalCopied := 0
for {
- b, _, err := c.EP.Read(nil)
+ res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
if err == tcpip.ErrWouldBlock {
break
}
- totalCopied += len(b)
+ totalCopied += res.Count
}
// Invoke the moderation API. This is required for auto-tuning
@@ -6378,10 +6434,7 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.T
if err != nil {
t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err)
}
- gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption)
- if err != nil {
- t.Fatalf("ep.GetSockOptBool(tcpip.DelayOption) failed: %s", err)
- }
+ gotDelayOption := ep.SocketOptions().GetDelayOption()
if gotDelayOption != wantDelayOption {
t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption)
}
@@ -7201,9 +7254,8 @@ func TestTCPUserTimeout(t *testing.T) {
),
)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrTimeout)
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
@@ -7246,9 +7298,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
}
// Check that the connection is still alive.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Now receive 1 keepalives, but don't ACK it.
b := c.GetPacket()
@@ -7287,9 +7338,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
),
)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
- }
+ ept.CheckReadError(t, tcpip.ErrTimeout)
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
}
@@ -7346,11 +7395,11 @@ func TestIncreaseWindowOnRead(t *testing.T) {
// defaultMTU is a good enough estimate for the MSS used for this
// connection.
for read < defaultMTU*2 {
- v, _, err := c.EP.Read(nil)
+ res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Read failed: %s", err)
}
- read += len(v)
+ read += res.Count
}
// After reading > MSS worth of data, we surely crossed MSS. See the ack:
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 0f9ed06cd..9e02d467d 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -20,6 +20,7 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -105,11 +106,18 @@ func TestTimeStampEnabledConnect(t *testing.T) {
// There should be 5 views to read and each of them should
// contain the same data.
for i := 0; i < 5; i++ {
- got, _, err := c.EP.Read(nil)
+ var buf bytes.Buffer
+ result, err := c.EP.Read(&buf, len(data), tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Unexpected error from Read: %v", err)
}
- if want := data; bytes.Compare(got, want) != 0 {
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 {
t.Fatalf("Data is different: got: %v, want: %v", got, want)
}
}
@@ -286,11 +294,18 @@ func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) {
}
// Issue a read and we should data.
- got, _, err := c.EP.Read(nil)
+ var buf bytes.Buffer
+ result, err := c.EP.Read(&buf, defaultMTU, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Unexpected error from Read: %v", err)
}
- if want := data; bytes.Compare(got, want) != 0 {
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 {
t.Fatalf("Data is different: got: %v, want: %v", got, want)
}
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index e6aa4fc4b..ee55f030c 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -592,9 +592,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
c.t.Fatalf("NewEndpoint failed: %v", err)
}
- if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
+ c.EP.SocketOptions().SetV6Only(v6only)
}
// GetV6Packet reads a single packet from the link layer endpoint of the context
@@ -637,11 +635,11 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
- NextHeader: uint8(tcp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: src,
- DstAddr: dst,
+ PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
+ TransportProtocol: tcp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: src,
+ DstAddr: dst,
})
// Initialize the TCP header.
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 7ebae63d8..153e8c950 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -58,5 +58,6 @@ go_test(
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
"//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index a7a405dcb..075de1db0 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -16,8 +16,9 @@ package udp
import (
"fmt"
+ "io"
+ "sync/atomic"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -30,10 +31,11 @@ import (
// +stateify savable
type udpPacket struct {
udpPacketEntry
- senderAddress tcpip.FullAddress
- packetInfo tcpip.IPPacketInfo
- data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- timestamp int64
+ senderAddress tcpip.FullAddress
+ destinationAddress tcpip.FullAddress
+ packetInfo tcpip.IPPacketInfo
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ timestamp int64
// tos stores either the receiveTOS or receiveTClass value.
tos uint8
}
@@ -95,20 +97,19 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by the mu mutex.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- sndBufSizeMax int
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ // state must be read/set using the EndpointState()/setEndpointState()
+ // methods.
state EndpointState
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
dstPort uint16
- v6only bool
ttl uint8
multicastTTL uint8
multicastAddr tcpip.Address
multicastNICID tcpip.NICID
- multicastLoop bool
portFlags ports.Flags
- bindToDevice tcpip.NICID
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
@@ -122,17 +123,6 @@ type endpoint struct {
// applied while sending packets. Defaults to 0 as on Linux.
sendTOS uint8
- // receiveTOS determines if the incoming IPv4 TOS header field is passed
- // as ancillary data to ControlMessages on Read.
- receiveTOS bool
-
- // receiveTClass determines if the incoming IPv6 TClass header field is
- // passed as ancillary data to ControlMessages on Read.
- receiveTClass bool
-
- // receiveIPPacketInfo determines if the packet info is returned by Read.
- receiveIPPacketInfo bool
-
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -154,9 +144,6 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
-
// ops is used to get socket level options.
ops tcpip.SocketOptions
}
@@ -188,7 +175,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
//
// Linux defaults to TTL=1.
multicastTTL: 1,
- multicastLoop: true,
rcvBufSizeMax: 32 * 1024,
sndBufSizeMax: 32 * 1024,
multicastMemberships: make(map[multicastMembership]struct{}),
@@ -196,6 +182,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
uniqueID: s.UniqueID(),
}
e.ops.InitHandler(e)
+ e.ops.SetMulticastLoop(true)
// Override with stack defaults.
var ss stack.SendBufferSizeOption
@@ -211,6 +198,20 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
return e
}
+// setEndpointState updates the state of the endpoint to state atomically. This
+// method is unexported as the only place we should update the state is in this
+// package but we allow the state to be read freely without holding e.mu.
+//
+// Precondition: e.mu must be held to call this method.
+func (e *endpoint) setEndpointState(state EndpointState) {
+ atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+}
+
+// EndpointState returns the current state of the endpoint.
+func (e *endpoint) EndpointState() EndpointState {
+ return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+}
+
// UniqueID implements stack.TransportEndpoint.UniqueID.
func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
@@ -225,6 +226,13 @@ func (e *endpoint) LastError() *tcpip.Error {
return err
}
+// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
+func (e *endpoint) UpdateLastError(err *tcpip.Error) {
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+}
+
// Abort implements stack.TransportEndpoint.Abort.
func (e *endpoint) Abort() {
e.Close()
@@ -236,7 +244,7 @@ func (e *endpoint) Close() {
e.mu.Lock()
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
- switch e.state {
+ switch e.EndpointState() {
case StateBound, StateConnected:
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
@@ -259,10 +267,13 @@ func (e *endpoint) Close() {
}
e.rcvMu.Unlock()
- e.route.Release()
+ if e.route != nil {
+ e.route.Release()
+ e.route = nil
+ }
// Update the state.
- e.state = StateClosed
+ e.setEndpointState(StateClosed)
e.mu.Unlock()
@@ -272,11 +283,10 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
-// Read reads data from the endpoint. This method does not block if
-// there is no data pending.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
if err := e.LastError(); err != nil {
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
e.rcvMu.Lock()
@@ -288,41 +298,54 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
p := e.rcvList.Front()
- e.rcvList.Remove(p)
- e.rcvBufSize -= p.data.Size()
- e.rcvMu.Unlock()
-
- if addr != nil {
- *addr = p.senderAddress
+ if !opts.Peek {
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
}
+ e.rcvMu.Unlock()
+ // Control Messages
cm := tcpip.ControlMessages{
HasTimestamp: true,
Timestamp: p.timestamp,
}
- e.mu.RLock()
- receiveTOS := e.receiveTOS
- receiveTClass := e.receiveTClass
- receiveIPPacketInfo := e.receiveIPPacketInfo
- e.mu.RUnlock()
- if receiveTOS {
+ if e.ops.GetReceiveTOS() {
cm.HasTOS = true
cm.TOS = p.tos
}
- if receiveTClass {
+ if e.ops.GetReceiveTClass() {
cm.HasTClass = true
// Although TClass is an 8-bit value it's read in the CMsg as a uint32.
cm.TClass = uint32(p.tos)
}
- if receiveIPPacketInfo {
+ if e.ops.GetReceivePacketInfo() {
cm.HasIPPacketInfo = true
cm.PacketInfo = p.packetInfo
}
- return p.data.ToView(), cm, nil
+ if e.ops.GetReceiveOriginalDstAddress() {
+ cm.HasOriginalDstAddress = true
+ cm.OriginalDstAddress = p.destinationAddress
+ }
+
+ // Read Result
+ res := tcpip.ReadResult{
+ Total: p.data.Size(),
+ ControlMessages: cm,
+ }
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = p.senderAddress
+ }
+
+ n, err := p.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -331,7 +354,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
//
// Returns true for retry if preparation should be retried.
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
- switch e.state {
+ switch e.EndpointState() {
case StateInitial:
case StateConnected:
return false, nil
@@ -353,7 +376,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
- if e.state != StateInitial {
+ if e.EndpointState() != StateInitial {
return true, nil
}
@@ -368,7 +391,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// connectRoute establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
-func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
+func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, *tcpip.Error) {
localAddr := e.ID.LocalAddress
if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
@@ -385,9 +408,9 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop)
+ r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop())
if err != nil {
- return stack.Route{}, 0, err
+ return nil, 0, err
}
return r, nicID, nil
}
@@ -455,36 +478,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
}
- var route *stack.Route
- var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error)
- var dstPort uint16
- if to == nil {
- route = &e.route
- dstPort = e.dstPort
- resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) {
- // Promote lock to exclusive if using a shared route, given that it may
- // need to change in Route.Resolve() call below.
- e.mu.RUnlock()
- e.mu.Lock()
-
- // Recheck state after lock was re-acquired.
- if e.state != StateConnected {
- err = tcpip.ErrInvalidEndpointState
- }
- if err == nil && route.IsResolutionRequired() {
- ch, err = route.Resolve(waker)
- }
-
- e.mu.Unlock()
- e.mu.RLock()
-
- // Recheck state after lock was re-acquired.
- if e.state != StateConnected {
- err = tcpip.ErrInvalidEndpointState
- }
- return ch, err
- }
- } else {
+ route := e.route
+ dstPort := e.dstPort
+ if to != nil {
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicID := to.NIC
@@ -512,9 +508,8 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
defer r.Release()
- route = &r
+ route = r
dstPort = dst.Port
- resolve = route.Resolve
}
if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
@@ -522,7 +517,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
if route.IsResolutionRequired() {
- if ch, err := resolve(nil); err != nil {
+ if ch, err := route.Resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
return 0, ch, tcpip.ErrNoLinkAddress
}
@@ -536,6 +531,20 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
if len(v) > header.UDPMaximumPacketSize {
// Payload can't possibly fit in a packet.
+ so := e.SocketOptions()
+ if so.GetRecvError() {
+ so.QueueLocalErr(
+ tcpip.ErrMessageTooLong,
+ route.NetProto,
+ header.UDPMaximumPacketSize,
+ tcpip.FullAddress{
+ NIC: route.NICID(),
+ Addr: route.RemoteAddress,
+ Port: dstPort,
+ },
+ v,
+ )
+ }
return 0, nil, tcpip.ErrMessageTooLong
}
@@ -571,11 +580,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return int64(len(v)), nil, nil
}
-// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
-}
-
// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet.
func (e *endpoint) OnReuseAddressSet(v bool) {
e.mu.Lock()
@@ -590,53 +594,6 @@ func (e *endpoint) OnReusePortSet(v bool) {
e.mu.Unlock()
}
-// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
-func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- switch opt {
- case tcpip.MulticastLoopOption:
- e.mu.Lock()
- e.multicastLoop = v
- e.mu.Unlock()
-
- case tcpip.ReceiveTOSOption:
- e.mu.Lock()
- e.receiveTOS = v
- e.mu.Unlock()
-
- case tcpip.ReceiveTClassOption:
- // We only support this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrNotSupported
- }
-
- e.mu.Lock()
- e.receiveTClass = v
- e.mu.Unlock()
-
- case tcpip.ReceiveIPPacketInfoOption:
- e.mu.Lock()
- e.receiveIPPacketInfo = v
- e.mu.Unlock()
-
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrInvalidEndpointState
- }
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- // We only allow this to be set when we're in the initial state.
- if e.state != StateInitial {
- return tcpip.ErrInvalidEndpointState
- }
-
- e.v6only = v
- }
- return nil
-}
-
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
switch opt {
@@ -710,6 +667,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
+func (e *endpoint) HasNIC(id int32) bool {
+ return id == 0 || e.stack.HasNIC(tcpip.NICID(id))
+}
+
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
switch v := opt.(type) {
@@ -756,14 +717,9 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
nicID := v.NIC
- // The interface address is considered not-set if it is empty or contains
- // all-zeros. The former represent the zero-value in golang, the latter the
- // same in a setsockopt(IP_ADD_MEMBERSHIP, &ip_mreqn) syscall.
- allZeros := header.IPv4Any
- if len(v.InterfaceAddr) == 0 || v.InterfaceAddr == allZeros {
+ if v.InterfaceAddr.Unspecified() {
if nicID == 0 {
- r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
- if err == nil {
+ if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil {
nicID = r.NICID()
r.Release()
}
@@ -796,10 +752,9 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
}
nicID := v.NIC
- if v.InterfaceAddr == header.IPv4Any {
+ if v.InterfaceAddr.Unspecified() {
if nicID == 0 {
- r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
- if err == nil {
+ if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil {
nicID = r.NICID()
r.Release()
}
@@ -826,75 +781,12 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
delete(e.multicastMemberships, memToRemove)
- case *tcpip.BindToDeviceOption:
- id := tcpip.NICID(*v)
- if id != 0 && !e.stack.HasNIC(id) {
- return tcpip.ErrUnknownDevice
- }
- e.mu.Lock()
- e.bindToDevice = id
- e.mu.Unlock()
-
case *tcpip.SocketDetachFilterOption:
return nil
-
- case *tcpip.LingerOption:
- e.mu.Lock()
- e.linger = *v
- e.mu.Unlock()
}
return nil
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.MulticastLoopOption:
- e.mu.RLock()
- v := e.multicastLoop
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.ReceiveTOSOption:
- e.mu.RLock()
- v := e.receiveTOS
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.ReceiveTClassOption:
- // We only support this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return false, tcpip.ErrNotSupported
- }
-
- e.mu.RLock()
- v := e.receiveTClass
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.ReceiveIPPacketInfoOption:
- e.mu.RLock()
- v := e.receiveIPPacketInfo
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return false, tcpip.ErrUnknownProtocolOption
- }
-
- e.mu.RLock()
- v := e.v6only
- e.mu.RUnlock()
-
- return v, nil
-
- default:
- return false, tcpip.ErrUnknownProtocolOption
- }
-}
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
@@ -964,16 +856,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
}
e.mu.Unlock()
- case *tcpip.BindToDeviceOption:
- e.mu.RLock()
- *o = tcpip.BindToDeviceOption(e.bindToDevice)
- e.mu.RUnlock()
-
- case *tcpip.LingerOption:
- e.mu.RLock()
- *o = e.linger
- e.mu.RUnlock()
-
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -1033,7 +915,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// checkV4MappedLocked determines the effective network protocol and converts
// addr to its canonical form.
func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
}
@@ -1045,7 +927,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if e.state != StateConnected {
+ if e.EndpointState() != StateConnected {
return nil
}
var (
@@ -1068,7 +950,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
if err != nil {
return err
}
- e.state = StateBound
+ e.setEndpointState(StateBound)
boundPortFlags = e.boundPortFlags
} else {
if e.ID.LocalPort != 0 {
@@ -1076,14 +958,14 @@ func (e *endpoint) Disconnect() *tcpip.Error {
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
}
- e.state = StateInitial
+ e.setEndpointState(StateInitial)
}
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
e.ID = id
e.boundBindToDevice = btd
e.route.Release()
- e.route = stack.Route{}
+ e.route = nil
e.dstPort = 0
return nil
@@ -1101,7 +983,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
nicID := addr.NIC
var localPort uint16
- switch e.state {
+ switch e.EndpointState() {
case StateInitial:
case StateBound, StateConnected:
localPort = e.ID.LocalPort
@@ -1127,7 +1009,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- defer r.Release()
id := stack.TransportEndpointID{
LocalAddress: e.ID.LocalAddress,
@@ -1136,7 +1017,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
RemoteAddress: r.RemoteAddress,
}
- if e.state == StateInitial {
+ if e.EndpointState() == StateInitial {
id.LocalAddress = r.LocalAddress
}
@@ -1144,7 +1025,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// packets on a different network protocol, so we register both even if
// v6only is set to false and this is an ipv6 endpoint.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.v6only {
+ if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() {
netProtos = []tcpip.NetworkProtocolNumber{
header.IPv4ProtocolNumber,
header.IPv6ProtocolNumber,
@@ -1155,6 +1036,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
id, btd, err := e.registerWithStack(nicID, netProtos, id)
if err != nil {
+ r.Release()
return err
}
@@ -1165,12 +1047,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.ID = id
e.boundBindToDevice = btd
- e.route = r.Clone()
+ e.route = r
e.dstPort = addr.Port
e.RegisterNICID = nicID
e.effectiveNetProtos = netProtos
- e.state = StateConnected
+ e.setEndpointState(StateConnected)
e.rcvMu.Lock()
e.rcvReady = true
@@ -1192,7 +1074,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// A socket in the bound state can still receive multicast messages,
// so we need to notify waiters on shutdown.
- if e.state != StateBound && e.state != StateConnected {
+ if state := e.EndpointState(); state != StateBound && state != StateConnected {
return tcpip.ErrNotConnected
}
@@ -1223,27 +1105,28 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcp
}
func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if e.ID.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, nil /* testPort */)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */)
if err != nil {
- return id, e.bindToDevice, err
+ return id, bindToDevice, err
}
id.LocalPort = port
}
e.boundPortFlags = e.portFlags
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{})
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
}
- return id, e.bindToDevice, err
+ return id, bindToDevice, err
}
func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.state != StateInitial {
+ if e.EndpointState() != StateInitial {
return tcpip.ErrInvalidEndpointState
}
@@ -1256,7 +1139,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// wildcard (empty) address, and this is an IPv6 endpoint with v6only
// set to false.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
+ if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && addr.Addr == "" {
netProtos = []tcpip.NetworkProtocolNumber{
header.IPv6ProtocolNumber,
header.IPv4ProtocolNumber,
@@ -1287,7 +1170,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
e.effectiveNetProtos = netProtos
// Mark endpoint as bound.
- e.state = StateBound
+ e.setEndpointState(StateBound)
e.rcvMu.Lock()
e.rcvReady = true
@@ -1319,7 +1202,7 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
defer e.mu.RUnlock()
addr := e.ID.LocalAddress
- if e.state == StateConnected {
+ if e.EndpointState() == StateConnected {
addr = e.route.LocalAddress
}
@@ -1335,7 +1218,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.state != StateConnected {
+ if e.EndpointState() != StateConnected {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -1434,6 +1317,11 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
Addr: id.RemoteAddress,
Port: header.UDP(hdr).SourcePort(),
},
+ destinationAddress: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.LocalAddress,
+ Port: header.UDP(hdr).DestinationPort(),
+ },
}
packet.data = pkt.Data
e.rcvList.PushBack(packet)
@@ -1464,28 +1352,71 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
}
}
+func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) {
+ // Update last error first.
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ // Update the error queue if IP_RECVERR is enabled.
+ if e.SocketOptions().GetRecvError() {
+ // Linux passes the payload without the UDP header.
+ var payload []byte
+ udp := header.UDP(pkt.Data.ToView())
+ if len(udp) >= header.UDPMinimumSize {
+ payload = udp.Payload()
+ }
+
+ e.SocketOptions().QueueErr(&tcpip.SockError{
+ Err: err,
+ ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber),
+ ErrType: errType,
+ ErrCode: errCode,
+ ErrInfo: extra,
+ Payload: payload,
+ Dst: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.RemoteAddress,
+ Port: id.RemotePort,
+ },
+ Offender: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.LocalAddress,
+ Port: id.LocalPort,
+ },
+ NetProto: pkt.NetworkProtocolNumber,
+ })
+ }
+
+ // Notify of the error.
+ e.waiterQueue.Notify(waiter.EventErr)
+}
+
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
if typ == stack.ControlPortUnreachable {
- e.mu.RLock()
- if e.state == StateConnected {
- e.lastErrorMu.Lock()
- e.lastError = tcpip.ErrConnectionRefused
- e.lastErrorMu.Unlock()
- e.mu.RUnlock()
-
- e.waiterQueue.Notify(waiter.EventErr)
+ if e.EndpointState() == StateConnected {
+ var errType byte
+ var errCode byte
+ switch pkt.NetworkProtocolNumber {
+ case header.IPv4ProtocolNumber:
+ errType = byte(header.ICMPv4DstUnreachable)
+ errCode = byte(header.ICMPv4PortUnreachable)
+ case header.IPv6ProtocolNumber:
+ errType = byte(header.ICMPv6DstUnreachable)
+ errCode = byte(header.ICMPv6PortUnreachable)
+ default:
+ panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber))
+ }
+ e.onICMPError(tcpip.ErrConnectionRefused, id, errType, errCode, extra, pkt)
return
}
- e.mu.RUnlock()
}
}
// State implements tcpip.Endpoint.State.
func (e *endpoint) State() uint32 {
- e.mu.Lock()
- defer e.mu.Unlock()
- return uint32(e.state)
+ return uint32(e.EndpointState())
}
// Info returns a copy of the endpoint info.
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 99f3fc37f..13b72dc88 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -98,7 +98,8 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
}
- if e.state != StateBound && e.state != StateConnected {
+ state := e.EndpointState()
+ if state != StateBound && state != StateConnected {
return
}
@@ -113,8 +114,8 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
var err *tcpip.Error
- if e.state == StateConnected {
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop)
+ if state == StateConnected {
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop())
if err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 14e4648cd..d7fc21f11 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -78,7 +78,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
route.ResolveWith(r.pkt.SourceLinkAddress())
ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
- if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil {
+ if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil {
ep.Close()
route.Release()
return nil, err
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 1233bab14..455b8c2aa 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -18,10 +18,12 @@ import (
"bytes"
"context"
"fmt"
+ "io/ioutil"
"math/rand"
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -363,9 +365,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
c.createEndpoint(flow.sockProto())
if flow.isV6Only() {
- if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- c.t.Fatalf("SetSockOptBool failed: %s", err)
- }
+ c.ep.SocketOptions().SetV6Only(true)
} else if flow.isBroadcast() {
c.ep.SocketOptions().SetBroadcast(true)
}
@@ -454,12 +454,12 @@ func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- TrafficClass: testTOS,
- PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
+ TrafficClass: testTOS,
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
})
// Initialize the UDP header.
@@ -556,7 +556,7 @@ func TestBindToDeviceOption(t *testing.T) {
name string
setBindToDevice *tcpip.NICID
setBindToDeviceError *tcpip.Error
- getBindToDevice tcpip.BindToDeviceOption
+ getBindToDevice int32
}{
{"GetDefaultValue", nil, nil, 0},
{"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
@@ -566,15 +566,13 @@ func TestBindToDeviceOption(t *testing.T) {
for _, testAction := range testActions {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
- bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ bindToDevice := int32(*testAction.setBindToDevice)
+ if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr)
}
}
- bindToDevice := tcpip.BindToDeviceOption(88888)
- if err := ep.GetSockOpt(&bindToDevice); err != nil {
- t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err)
- } else if bindToDevice != testAction.getBindToDevice {
+ bindToDevice := ep.SocketOptions().GetBindToDevice()
+ if bindToDevice != testAction.getBindToDevice {
t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice)
}
})
@@ -599,13 +597,13 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
// Take a snapshot of the stats to validate them at the end of the test.
epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- var addr tcpip.FullAddress
- v, cm, err := c.ep.Read(&addr)
+ var buf bytes.Buffer
+ res, err := c.ep.Read(&buf, defaultMTU, tcpip.ReadOptions{NeedRemoteAddr: true})
if err == tcpip.ErrWouldBlock {
// Wait for data to become available.
select {
case <-ch:
- v, cm, err = c.ep.Read(&addr)
+ res, err = c.ep.Read(&buf, defaultMTU, tcpip.ReadOptions{NeedRemoteAddr: true})
case <-time.After(300 * time.Millisecond):
if packetShouldBeDropped {
@@ -625,23 +623,32 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
}
if packetShouldBeDropped {
- c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr)
+ c.t.Fatalf("Read unexpectedly received data from %s", res.RemoteAddr.Addr)
}
- // Check the peer address.
+ // Check the read result.
h := flow.header4Tuple(incoming)
- if addr.Addr != h.srcAddr.Addr {
- c.t.Fatalf("got address = %s, want = %s", addr.Addr, h.srcAddr.Addr)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ RemoteAddr: tcpip.FullAddress{Addr: h.srcAddr.Addr},
+ }, res, checker.IgnoreCmpPath(
+ "ControlMessages", // ControlMessages will be checked later.
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ c.t.Fatalf("Read: unexpected result (-want +got):\n%s", diff)
}
// Check the payload.
+ v := buf.Bytes()
if !bytes.Equal(payload, v) {
c.t.Fatalf("got payload = %x, want = %x", v, payload)
}
// Run any checkers against the ControlMessages.
for _, f := range checkers {
- f(c.t, cm)
+ f(c.t, res.ControlMessages)
}
c.checkEndpointReadStats(1, epstats, err)
@@ -832,8 +839,8 @@ func TestV4ReadSelfSource(t *testing.T) {
t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
}
- if _, _, err := c.ep.Read(nil); err != tt.wantErr {
- t.Errorf("got c.ep.Read(nil) = %s, want = %s", err, tt.wantErr)
+ if _, err := c.ep.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}); err != tt.wantErr {
+ t.Errorf("got c.ep.Read = %s, want = %s", err, tt.wantErr)
}
})
}
@@ -1414,9 +1421,7 @@ func TestReadIPPacketInfo(t *testing.T) {
}
}
- if err := c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true); err != nil {
- t.Fatalf("c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true): %s", err)
- }
+ c.ep.SocketOptions().SetReceivePacketInfo(true)
testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
NIC: 1,
@@ -1431,6 +1436,93 @@ func TestReadIPPacketInfo(t *testing.T) {
}
}
+func TestReadRecvOriginalDstAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ flow testFlow
+ expectedOriginalDstAddr tcpip.FullAddress
+ }{
+ {
+ name: "IPv4 unicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: unicastV4,
+ expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort},
+ },
+ {
+ name: "IPv4 multicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: multicastV4,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort},
+ },
+ {
+ name: "IPv4 broadcast",
+ proto: header.IPv4ProtocolNumber,
+ flow: broadcast,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort},
+ },
+ {
+ name: "IPv6 unicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: unicastV6,
+ expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort},
+ },
+ {
+ name: "IPv6 multicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: multicastV6,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(test.proto)
+
+ bindAddr := tcpip.FullAddress{Port: stackPort}
+ if err := c.ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%#v): %s", bindAddr, err)
+ }
+
+ if test.flow.isMulticast() {
+ ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()}
+ if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err)
+ }
+ }
+
+ c.ep.SocketOptions().SetReceiveOriginalDstAddress(true)
+
+ testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr))
+
+ if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
+ t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
+ }
+ })
+ }
+}
+
func TestWriteIncrementsPacketsSent(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1629,13 +1721,15 @@ func TestSetTClass(t *testing.T) {
}
func TestReceiveTosTClass(t *testing.T) {
+ const RcvTOSOpt = "ReceiveTosOption"
+ const RcvTClassOpt = "ReceiveTClassOption"
+
testCases := []struct {
- name string
- getReceiveOption tcpip.SockOptBool
- tests []testFlow
+ name string
+ tests []testFlow
}{
- {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}},
- {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
+ {RcvTOSOpt, []testFlow{unicastV4, broadcast}},
+ {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
}
for _, testCase := range testCases {
for _, flow := range testCase.tests {
@@ -1644,29 +1738,32 @@ func TestReceiveTosTClass(t *testing.T) {
defer c.cleanup()
c.createEndpointForFlow(flow)
- option := testCase.getReceiveOption
name := testCase.name
- // Verify that setting and reading the option works.
- v, err := c.ep.GetSockOptBool(option)
- if err != nil {
- c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
+ var optionGetter func() bool
+ var optionSetter func(bool)
+ switch name {
+ case RcvTOSOpt:
+ optionGetter = c.ep.SocketOptions().GetReceiveTOS
+ optionSetter = c.ep.SocketOptions().SetReceiveTOS
+ case RcvTClassOpt:
+ optionGetter = c.ep.SocketOptions().GetReceiveTClass
+ optionSetter = c.ep.SocketOptions().SetReceiveTClass
+ default:
+ t.Fatalf("unkown test variant: %s", name)
}
+
+ // Verify that setting and reading the option works.
+ v := optionGetter()
// Test for expected default value.
if v != false {
c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false)
}
want := true
- if err := c.ep.SetSockOptBool(option, want); err != nil {
- c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err)
- }
-
- got, err := c.ep.GetSockOptBool(option)
- if err != nil {
- c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
- }
+ optionSetter(want)
+ got := optionGetter()
if got != want {
c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want)
}
@@ -1676,10 +1773,10 @@ func TestReceiveTosTClass(t *testing.T) {
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
c.t.Fatalf("Bind failed: %s", err)
}
- switch option {
- case tcpip.ReceiveTClassOption:
+ switch name {
+ case RcvTClassOpt:
testRead(c, flow, checker.ReceiveTClass(testTOS))
- case tcpip.ReceiveTOSOption:
+ case RcvTOSOpt:
testRead(c, flow, checker.ReceiveTOS(testTOS))
default:
t.Fatalf("unknown test variant: %s", name)
@@ -1993,12 +2090,12 @@ func TestShortHeader(t *testing.T) {
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- TrafficClass: testTOS,
- PayloadLength: uint16(udpSize),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
+ TrafficClass: testTOS,
+ PayloadLength: uint16(udpSize),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
})
// Initialize the UDP header.
diff --git a/pkg/test/criutil/criutil.go b/pkg/test/criutil/criutil.go
index 70945f234..e41769017 100644
--- a/pkg/test/criutil/criutil.go
+++ b/pkg/test/criutil/criutil.go
@@ -54,14 +54,20 @@ func ResolvePath(executable string) string {
}
}
+ // Favor /usr/local/bin, if it exists.
+ localBin := fmt.Sprintf("/usr/local/bin/%s", executable)
+ if _, err := os.Stat(localBin); err == nil {
+ return localBin
+ }
+
// Try to find via the path.
- guess, err := exec.LookPath(executable)
+ guess, _ := exec.LookPath(executable)
if err == nil {
return guess
}
- // Return a default path.
- return fmt.Sprintf("/usr/local/bin/%s", executable)
+ // Return a bare path; this generates a suitable error.
+ return executable
}
// NewCrictl returns a Crictl configured with a timeout and an endpoint over
diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go
index 2bf0a22ff..7bacb70d3 100644
--- a/pkg/test/dockerutil/container.go
+++ b/pkg/test/dockerutil/container.go
@@ -55,11 +55,8 @@ type Container struct {
copyErr error
cleanups []func()
- // Profiles are profiles added to this container. They contain methods
- // that are run after Creation, Start, and Cleanup of this Container, along
- // a handle to restart the profile. Generally, tests/benchmarks using
- // profiles need to run as root.
- profiles []Profile
+ // profile is the profiling hook associated with this container.
+ profile *profile
}
// RunOpts are options for running a container.
@@ -105,22 +102,7 @@ type RunOpts struct {
Links []string
}
-// MakeContainer sets up the struct for a Docker container.
-//
-// Names of containers will be unique.
-// Containers will check flags for profiling requests.
-func MakeContainer(ctx context.Context, logger testutil.Logger) *Container {
- c := MakeNativeContainer(ctx, logger)
- c.runtime = *runtime
- if p := MakePprofFromFlags(c); p != nil {
- c.AddProfile(p)
- }
- return c
-}
-
-// MakeNativeContainer sets up the struct for a DockerContainer using runc. Native
-// containers aren't profiled.
-func MakeNativeContainer(ctx context.Context, logger testutil.Logger) *Container {
+func makeContainer(ctx context.Context, logger testutil.Logger, runtime string) *Container {
// Slashes are not allowed in container names.
name := testutil.RandomID(logger.Name())
name = strings.ReplaceAll(name, "/", "-")
@@ -132,29 +114,32 @@ func MakeNativeContainer(ctx context.Context, logger testutil.Logger) *Container
return &Container{
logger: logger,
Name: name,
- runtime: "",
+ runtime: runtime,
client: client,
}
}
-// AddProfile adds a profile to this container.
-func (c *Container) AddProfile(p Profile) {
- c.profiles = append(c.profiles, p)
+// MakeContainer constructs a suitable Container object.
+//
+// The runtime used is determined by the runtime flag.
+//
+// Containers will check flags for profiling requests.
+func MakeContainer(ctx context.Context, logger testutil.Logger) *Container {
+ return makeContainer(ctx, logger, *runtime)
}
-// RestartProfiles calls Restart on all profiles for this container.
-func (c *Container) RestartProfiles() error {
- for _, profile := range c.profiles {
- if err := profile.Restart(c); err != nil {
- return err
- }
- }
- return nil
+// MakeNativeContainer constructs a suitable Container object.
+//
+// The runtime used will be the system default.
+//
+// Native containers aren't profiled.
+func MakeNativeContainer(ctx context.Context, logger testutil.Logger) *Container {
+ return makeContainer(ctx, logger, "" /*runtime*/)
}
// Spawn is analogous to 'docker run -d'.
func (c *Container) Spawn(ctx context.Context, r RunOpts, args ...string) error {
- if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil {
+ if err := c.create(ctx, r.Image, c.config(r, args), c.hostConfig(r), nil); err != nil {
return err
}
return c.Start(ctx)
@@ -167,7 +152,7 @@ func (c *Container) SpawnProcess(ctx context.Context, r RunOpts, args ...string)
config.Tty = true
config.OpenStdin = true
- if err := c.CreateFrom(ctx, config, hostconf, netconf); err != nil {
+ if err := c.CreateFrom(ctx, r.Image, config, hostconf, netconf); err != nil {
return Process{}, err
}
@@ -194,7 +179,7 @@ func (c *Container) SpawnProcess(ctx context.Context, r RunOpts, args ...string)
// Run is analogous to 'docker run'.
func (c *Container) Run(ctx context.Context, r RunOpts, args ...string) (string, error) {
- if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil {
+ if err := c.create(ctx, r.Image, c.config(r, args), c.hostConfig(r), nil); err != nil {
return "", err
}
@@ -221,26 +206,26 @@ func (c *Container) MakeLink(target string) string {
}
// CreateFrom creates a container from the given configs.
-func (c *Container) CreateFrom(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error {
- return c.create(ctx, conf, hostconf, netconf)
+func (c *Container) CreateFrom(ctx context.Context, profileImage string, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error {
+ return c.create(ctx, profileImage, conf, hostconf, netconf)
}
// Create is analogous to 'docker create'.
func (c *Container) Create(ctx context.Context, r RunOpts, args ...string) error {
- return c.create(ctx, c.config(r, args), c.hostConfig(r), nil)
+ return c.create(ctx, r.Image, c.config(r, args), c.hostConfig(r), nil)
}
-func (c *Container) create(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error {
+func (c *Container) create(ctx context.Context, profileImage string, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error {
+ if c.runtime != "" {
+ // Use the image name as provided here; which normally represents the
+ // unmodified "basic/alpine" image name. This should be easy to grok.
+ c.profileInit(profileImage)
+ }
cont, err := c.client.ContainerCreate(ctx, conf, hostconf, nil, c.Name)
if err != nil {
return err
}
c.id = cont.ID
- for _, profile := range c.profiles {
- if err := profile.OnCreate(c); err != nil {
- return fmt.Errorf("OnCreate method failed with: %v", err)
- }
- }
return nil
}
@@ -286,11 +271,13 @@ func (c *Container) Start(ctx context.Context) error {
if err := c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{}); err != nil {
return fmt.Errorf("ContainerStart failed: %v", err)
}
- for _, profile := range c.profiles {
- if err := profile.OnStart(c); err != nil {
- return fmt.Errorf("OnStart method failed: %v", err)
+
+ if c.profile != nil {
+ if err := c.profile.Start(c); err != nil {
+ c.logger.Logf("profile.Start failed: %v", err)
}
}
+
return nil
}
@@ -442,6 +429,7 @@ func (c *Container) Status(ctx context.Context) (types.ContainerState, error) {
// Wait waits for the container to exit.
func (c *Container) Wait(ctx context.Context) error {
+ defer c.stopProfiling()
statusChan, errChan := c.client.ContainerWait(ctx, c.id, container.WaitConditionNotRunning)
select {
case err := <-errChan:
@@ -499,8 +487,20 @@ func (c *Container) WaitForOutputSubmatch(ctx context.Context, pattern string, t
}
}
+// stopProfiling stops profiling.
+func (c *Container) stopProfiling() {
+ if c.profile != nil {
+ if err := c.profile.Stop(c); err != nil {
+ // This most likely means that the runtime for the container
+ // was too short to connect and actually get a profile.
+ c.logger.Logf("warning: profile.Stop failed: %v", err)
+ }
+ }
+}
+
// Kill kills the container.
func (c *Container) Kill(ctx context.Context) error {
+ defer c.stopProfiling()
return c.client.ContainerKill(ctx, c.id, "")
}
@@ -517,14 +517,6 @@ func (c *Container) Remove(ctx context.Context) error {
// CleanUp kills and deletes the container (best effort).
func (c *Container) CleanUp(ctx context.Context) {
- // Execute profile cleanups before the container goes down.
- for _, profile := range c.profiles {
- profile.OnCleanUp(c)
- }
-
- // Forget profiles.
- c.profiles = nil
-
// Execute all cleanups. We execute cleanups here to close any
// open connections to the container before closing. Open connections
// can cause Kill and Remove to hang.
@@ -538,10 +530,12 @@ func (c *Container) CleanUp(ctx context.Context) {
// Just log; can't do anything here.
c.logger.Logf("error killing container %q: %v", c.Name, err)
}
+
// Remove the image.
if err := c.Remove(ctx); err != nil {
c.logger.Logf("error removing container %q: %v", c.Name, err)
}
+
// Forget all mounts.
c.mounts = nil
}
diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go
index 7027df1a5..a40005799 100644
--- a/pkg/test/dockerutil/dockerutil.go
+++ b/pkg/test/dockerutil/dockerutil.go
@@ -49,15 +49,11 @@ var (
// pprofBaseDir allows the user to change the directory to which profiles are
// written. By default, profiles will appear under:
// /tmp/profile/RUNTIME/CONTAINER_NAME/*.pprof.
- pprofBaseDir = flag.String("pprof-dir", "/tmp/profile", "base directory in: BASEDIR/RUNTIME/CONTINER_NAME/FILENAME (e.g. /tmp/profile/runtime/mycontainer/cpu.pprof)")
-
- // duration is the max duration `runsc debug` will run and capture profiles.
- // If the container's clean up method is called prior to duration, the
- // profiling process will be killed.
- duration = flag.Duration("pprof-duration", 10*time.Second, "duration to run the profile in seconds")
+ pprofBaseDir = flag.String("pprof-dir", "/tmp/profile", "base directory in: BASEDIR/RUNTIME/CONTINER_NAME/FILENAME (e.g. /tmp/profile/runtime/mycontainer/cpu.pprof)")
+ pprofDuration = flag.Duration("pprof-duration", time.Hour, "profiling duration (automatically stopped at container exit)")
// The below flags enable each type of profile. Multiple profiles can be
- // enabled for each run.
+ // enabled for each run. The profile will be collected from the start.
pprofBlock = flag.Bool("pprof-block", false, "enables block profiling with runsc debug")
pprofCPU = flag.Bool("pprof-cpu", false, "enables CPU profiling with runsc debug")
pprofHeap = flag.Bool("pprof-heap", false, "enables heap profiling with runsc debug")
diff --git a/pkg/test/dockerutil/exec.go b/pkg/test/dockerutil/exec.go
index 4c739c9e9..bf968acec 100644
--- a/pkg/test/dockerutil/exec.go
+++ b/pkg/test/dockerutil/exec.go
@@ -77,11 +77,6 @@ func (c *Container) doExec(ctx context.Context, r ExecOpts, args []string) (Proc
return Process{}, fmt.Errorf("exec attach failed with err: %v", err)
}
- if err := c.client.ContainerExecStart(ctx, resp.ID, types.ExecStartCheck{}); err != nil {
- hijack.Close()
- return Process{}, fmt.Errorf("exec start failed with err: %v", err)
- }
-
return Process{
container: c,
execid: resp.ID,
diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go
index 55f9496cd..5cad3e959 100644
--- a/pkg/test/dockerutil/profile.go
+++ b/pkg/test/dockerutil/profile.go
@@ -17,72 +17,64 @@ package dockerutil
import (
"context"
"fmt"
- "io"
"os"
"os/exec"
"path/filepath"
+ "syscall"
"time"
)
-// Profile represents profile-like operations on a container,
-// such as running perf or pprof. It is meant to be added to containers
-// such that the container type calls the Profile during its lifecycle.
-type Profile interface {
- // OnCreate is called just after the container is created when the container
- // has a valid ID (e.g. c.ID()).
- OnCreate(c *Container) error
-
- // OnStart is called just after the container is started when the container
- // has a valid Pid (e.g. c.SandboxPid()).
- OnStart(c *Container) error
-
- // Restart restarts the Profile on request.
- Restart(c *Container) error
-
- // OnCleanUp is called during the container's cleanup method.
- // Cleanups should just log errors if they have them.
- OnCleanUp(c *Container) error
-}
-
-// Pprof is for running profiles with 'runsc debug'. Pprof workloads
-// should be run as root and ONLY against runsc sandboxes. The runtime
-// should have --profile set as an option in /etc/docker/daemon.json in
-// order for profiling to work with Pprof.
-type Pprof struct {
- BasePath string // path to put profiles
- BlockProfile bool
- CPUProfile bool
- HeapProfile bool
- MutexProfile bool
- Duration time.Duration // duration to run profiler e.g. '10s' or '1m'.
- shouldRun bool
- cmd *exec.Cmd
- stdout io.ReadCloser
- stderr io.ReadCloser
+// profile represents profile-like operations on a container.
+//
+// It is meant to be added to containers such that the container type calls
+// the profile during its lifecycle. Standard implementations are below.
+
+// profile is for running profiles with 'runsc debug'.
+type profile struct {
+ BasePath string
+ Types []string
+ Duration time.Duration
+ cmd *exec.Cmd
}
-// MakePprofFromFlags makes a Pprof profile from flags.
-func MakePprofFromFlags(c *Container) *Pprof {
- if !(*pprofBlock || *pprofCPU || *pprofHeap || *pprofMutex) {
- return nil
+// profileInit initializes a profile object, if required.
+//
+// N.B. The profiling filename initialized here will use the *image*
+// name, and not the unique container name. This is intentional. Most
+// of the time, profiling will be used for benchmarks. Benchmarks will
+// be run iteratively until a sufficiently large N is reached. It is
+// useful in this context to overwrite previous runs, and generate a
+// single profile result for the final test.
+func (c *Container) profileInit(image string) {
+ if !*pprofBlock && !*pprofCPU && !*pprofMutex && !*pprofHeap {
+ return // Nothing to do.
+ }
+ c.profile = &profile{
+ BasePath: filepath.Join(*pprofBaseDir, c.runtime, c.logger.Name(), image),
+ Duration: *pprofDuration,
+ }
+ if *pprofCPU {
+ c.profile.Types = append(c.profile.Types, "cpu")
}
- return &Pprof{
- BasePath: filepath.Join(*pprofBaseDir, c.runtime, c.Name),
- BlockProfile: *pprofBlock,
- CPUProfile: *pprofCPU,
- HeapProfile: *pprofHeap,
- MutexProfile: *pprofMutex,
- Duration: *duration,
+ if *pprofHeap {
+ c.profile.Types = append(c.profile.Types, "heap")
+ }
+ if *pprofMutex {
+ c.profile.Types = append(c.profile.Types, "mutex")
+ }
+ if *pprofBlock {
+ c.profile.Types = append(c.profile.Types, "block")
}
}
-// OnCreate implements Profile.OnCreate.
-func (p *Pprof) OnCreate(c *Container) error {
- return os.MkdirAll(p.BasePath, 0755)
-}
+// createProcess creates the collection process.
+func (p *profile) createProcess(c *Container) error {
+ // Ensure our directory exists.
+ if err := os.MkdirAll(p.BasePath, 0755); err != nil {
+ return err
+ }
-// OnStart implements Profile.OnStart.
-func (p *Pprof) OnStart(c *Container) error {
+ // Find the runtime to invoke.
path, err := RuntimePath()
if err != nil {
return fmt.Errorf("failed to get runtime path: %v", err)
@@ -90,58 +82,63 @@ func (p *Pprof) OnStart(c *Container) error {
// The root directory of this container's runtime.
root := fmt.Sprintf("--root=/var/run/docker/runtime-%s/moby", c.runtime)
- // Format is `runsc --root=rootdir debug --profile-*=file --duration=* containerID`.
+
+ // Format is `runsc --root=rootdir debug --profile-*=file --duration=* --delay=* containerID`.
args := []string{root, "debug"}
- args = append(args, p.makeProfileArgs(c)...)
+ for _, profileArg := range p.Types {
+ outputPath := filepath.Join(p.BasePath, fmt.Sprintf("%s.pprof", profileArg))
+ args = append(args, fmt.Sprintf("--profile-%s=%s", profileArg, outputPath))
+ }
+ args = append(args, fmt.Sprintf("--duration=%s", p.Duration)) // Or until container exits.
+ args = append(args, fmt.Sprintf("--delay=%s", p.Duration)) // Ditto.
args = append(args, c.ID())
// Best effort wait until container is running.
for now := time.Now(); time.Since(now) < 5*time.Second; {
if status, err := c.Status(context.Background()); err != nil {
return fmt.Errorf("failed to get status with: %v", err)
-
} else if status.Running {
break
}
- time.Sleep(500 * time.Millisecond)
+ time.Sleep(100 * time.Millisecond)
}
p.cmd = exec.Command(path, args...)
+ p.cmd.Stderr = os.Stderr // Pass through errors.
if err := p.cmd.Start(); err != nil {
- return fmt.Errorf("process failed: %v", err)
+ return fmt.Errorf("start process failed: %v", err)
}
+
return nil
}
-// Restart implements Profile.Restart.
-func (p *Pprof) Restart(c *Container) error {
- p.OnCleanUp(c)
- return p.OnStart(c)
+// killProcess kills the process, if running.
+func (p *profile) killProcess() error {
+ if p.cmd != nil && p.cmd.Process != nil {
+ return p.cmd.Process.Signal(syscall.SIGTERM)
+ }
+ return nil
}
-// OnCleanUp implements Profile.OnCleanup
-func (p *Pprof) OnCleanUp(c *Container) error {
+// waitProcess waits for the process, if running.
+func (p *profile) waitProcess() error {
defer func() { p.cmd = nil }()
- if p.cmd != nil && p.cmd.Process != nil && p.cmd.ProcessState != nil && !p.cmd.ProcessState.Exited() {
- return p.cmd.Process.Kill()
+ if p.cmd != nil {
+ return p.cmd.Wait()
}
return nil
}
-// makeProfileArgs turns Pprof fields into runsc debug flags.
-func (p *Pprof) makeProfileArgs(c *Container) []string {
- var ret []string
- if p.BlockProfile {
- ret = append(ret, fmt.Sprintf("--profile-block=%s", filepath.Join(p.BasePath, "block.pprof")))
- }
- if p.CPUProfile {
- ret = append(ret, fmt.Sprintf("--profile-cpu=%s", filepath.Join(p.BasePath, "cpu.pprof")))
- }
- if p.HeapProfile {
- ret = append(ret, fmt.Sprintf("--profile-heap=%s", filepath.Join(p.BasePath, "heap.pprof")))
- }
- if p.MutexProfile {
- ret = append(ret, fmt.Sprintf("--profile-mutex=%s", filepath.Join(p.BasePath, "mutex.pprof")))
+// Start is called when profiling is started.
+func (p *profile) Start(c *Container) error {
+ return p.createProcess(c)
+}
+
+// Stop is called when profiling is stopped.
+func (p *profile) Stop(c *Container) error {
+ killErr := p.killProcess()
+ waitErr := p.waitProcess()
+ if waitErr != nil && killErr != nil {
+ return killErr
}
- ret = append(ret, fmt.Sprintf("--duration=%s", p.Duration))
- return ret
+ return waitErr // Ignore okay wait, err kill.
}
diff --git a/pkg/test/dockerutil/profile_test.go b/pkg/test/dockerutil/profile_test.go
index 8c4ffe483..4fe9ce15c 100644
--- a/pkg/test/dockerutil/profile_test.go
+++ b/pkg/test/dockerutil/profile_test.go
@@ -17,6 +17,7 @@ package dockerutil
import (
"context"
"fmt"
+ "io/ioutil"
"os"
"path/filepath"
"testing"
@@ -25,52 +26,60 @@ import (
type testCase struct {
name string
- pprof Pprof
+ profile profile
expectedFiles []string
}
-func TestPprof(t *testing.T) {
+func TestProfile(t *testing.T) {
// Basepath and expected file names for each type of profile.
- basePath := "/tmp/test/profile"
+ tmpDir, err := ioutil.TempDir("", "")
+ if err != nil {
+ t.Fatalf("unable to create temporary directory: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ // All expected names.
+ basePath := tmpDir
block := "block.pprof"
cpu := "cpu.pprof"
- goprofle := "go.pprof"
heap := "heap.pprof"
mutex := "mutex.pprof"
testCases := []testCase{
{
- name: "Cpu",
- pprof: Pprof{
- BasePath: basePath,
- CPUProfile: true,
- Duration: 2 * time.Second,
+ name: "One",
+ profile: profile{
+ BasePath: basePath,
+ Types: []string{"cpu"},
+ Duration: 2 * time.Second,
},
expectedFiles: []string{cpu},
},
{
name: "All",
- pprof: Pprof{
- BasePath: basePath,
- BlockProfile: true,
- CPUProfile: true,
- HeapProfile: true,
- MutexProfile: true,
- Duration: 2 * time.Second,
+ profile: profile{
+ BasePath: basePath,
+ Types: []string{"block", "cpu", "heap", "mutex"},
+ Duration: 2 * time.Second,
},
- expectedFiles: []string{block, cpu, goprofle, heap, mutex},
+ expectedFiles: []string{block, cpu, heap, mutex},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ctx := context.Background()
c := MakeContainer(ctx, t)
+
// Set basepath to include the container name so there are no conflicts.
- tc.pprof.BasePath = filepath.Join(tc.pprof.BasePath, c.Name)
- c.AddProfile(&tc.pprof)
+ localProfile := tc.profile // Copy it.
+ localProfile.BasePath = filepath.Join(localProfile.BasePath, tc.name)
+
+ // Set directly on the container, to avoid flags.
+ c.profile = &localProfile
func() {
defer c.CleanUp(ctx)
+
// Start a container.
if err := c.Spawn(ctx, RunOpts{
Image: "basic/alpine",
@@ -83,24 +92,24 @@ func TestPprof(t *testing.T) {
}
// End early if the expected files exist and have data.
- for start := time.Now(); time.Since(start) < tc.pprof.Duration; time.Sleep(500 * time.Millisecond) {
- if err := checkFiles(tc); err == nil {
+ for start := time.Now(); time.Since(start) < localProfile.Duration; time.Sleep(100 * time.Millisecond) {
+ if err := checkFiles(localProfile.BasePath, tc.expectedFiles); err == nil {
break
}
}
}()
// Check all expected files exist and have data.
- if err := checkFiles(tc); err != nil {
+ if err := checkFiles(localProfile.BasePath, tc.expectedFiles); err != nil {
t.Fatalf(err.Error())
}
})
}
}
-func checkFiles(tc testCase) error {
- for _, file := range tc.expectedFiles {
- stat, err := os.Stat(filepath.Join(tc.pprof.BasePath, file))
+func checkFiles(basePath string, expectedFiles []string) error {
+ for _, file := range expectedFiles {
+ stat, err := os.Stat(filepath.Join(basePath, file))
if err != nil {
return fmt.Errorf("stat failed with: %v", err)
} else if stat.Size() < 1 {
diff --git a/pkg/test/testutil/BUILD b/pkg/test/testutil/BUILD
index c4b131896..00600a2ad 100644
--- a/pkg/test/testutil/BUILD
+++ b/pkg/test/testutil/BUILD
@@ -6,6 +6,7 @@ go_library(
name = "testutil",
testonly = 1,
srcs = [
+ "sh.go",
"testutil.go",
"testutil_runfiles.go",
],
@@ -15,6 +16,7 @@ go_library(
"//runsc/config",
"//runsc/specutils",
"@com_github_cenkalti_backoff//:go_default_library",
+ "@com_github_kr_pty//:go_default_library",
"@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
],
)
diff --git a/pkg/test/testutil/sh.go b/pkg/test/testutil/sh.go
new file mode 100644
index 000000000..1c77562be
--- /dev/null
+++ b/pkg/test/testutil/sh.go
@@ -0,0 +1,515 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutil
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "strings"
+ "syscall"
+ "time"
+
+ "github.com/kr/pty"
+)
+
+// Prompt is used as shell prompt.
+// It is meant to be unique enough to not be seen in command outputs.
+const Prompt = "PROMPT> "
+
+// Simplistic shell string escape.
+func shellEscape(s string) string {
+ // specialChars is used to determine whether s needs quoting at all.
+ const specialChars = "\\'\"`${[|&;<>()*?! \t\n"
+ // If s needs quoting, escapedChars is the set of characters that are
+ // escaped with a backslash.
+ const escapedChars = "\\\"$`"
+ if len(s) == 0 {
+ return "''"
+ }
+ if !strings.ContainsAny(s, specialChars) {
+ return s
+ }
+ var b bytes.Buffer
+ b.WriteString("\"")
+ for _, c := range s {
+ if strings.ContainsAny(string(c), escapedChars) {
+ b.WriteString("\\")
+ }
+ b.WriteRune(c)
+ }
+ b.WriteString("\"")
+ return b.String()
+}
+
+type byteOrError struct {
+ b byte
+ err error
+}
+
+// Shell manages a /bin/sh invocation with convenience functions to handle I/O.
+// The shell is run in its own interactive TTY and should present its prompt.
+type Shell struct {
+ // cmd is a reference to the underlying sh process.
+ cmd *exec.Cmd
+ // cmdFinished is closed when cmd exits.
+ cmdFinished chan struct{}
+
+ // echo is whether the shell will echo input back to us.
+ // This helps setting expectations of getting feedback of written bytes.
+ echo bool
+ // Control characters we expect to see in the shell.
+ controlCharIntr string
+ controlCharEOF string
+
+ // ptyMaster and ptyReplica are the TTY pair associated with the shell.
+ ptyMaster *os.File
+ ptyReplica *os.File
+ // readCh is a channel where everything read from ptyMaster is written.
+ readCh chan byteOrError
+
+ // logger is used for logging. It may be nil.
+ logger Logger
+}
+
+// cleanup kills the shell process and closes the TTY.
+// Users of this library get a reference to this function with NewShell.
+func (s *Shell) cleanup() {
+ s.logf("cleanup", "Shell cleanup started.")
+ if s.cmd.ProcessState == nil {
+ if err := s.cmd.Process.Kill(); err != nil {
+ s.logf("cleanup", "cannot kill shell process: %v", err)
+ }
+ // We don't log the error returned by Wait because the monitorExit
+ // goroutine will already do so.
+ s.cmd.Wait()
+ }
+ s.ptyReplica.Close()
+ s.ptyMaster.Close()
+ // Wait for monitorExit goroutine to write exit status to the debug log.
+ <-s.cmdFinished
+ // Empty out everything in the readCh, but don't wait too long for it.
+ var extraBytes bytes.Buffer
+ unreadTimeout := time.After(100 * time.Millisecond)
+unreadLoop:
+ for {
+ select {
+ case r, ok := <-s.readCh:
+ if !ok {
+ break unreadLoop
+ } else if r.err == nil {
+ extraBytes.WriteByte(r.b)
+ }
+ case <-unreadTimeout:
+ break unreadLoop
+ }
+ }
+ if extraBytes.Len() > 0 {
+ s.logIO("unread", extraBytes.Bytes(), nil)
+ }
+ s.logf("cleanup", "Shell cleanup complete.")
+}
+
+// logIO logs byte I/O to both standard logging and the test log, if provided.
+func (s *Shell) logIO(prefix string, b []byte, err error) {
+ var sb strings.Builder
+ if len(b) > 0 {
+ sb.WriteString(fmt.Sprintf("%q", b))
+ } else {
+ sb.WriteString("(nothing)")
+ }
+ if err != nil {
+ sb.WriteString(fmt.Sprintf(" [error: %v]", err))
+ }
+ s.logf(prefix, "%s", sb.String())
+}
+
+// logf logs something to both standard logging and the test log, if provided.
+func (s *Shell) logf(prefix, format string, values ...interface{}) {
+ if s.logger != nil {
+ s.logger.Logf("[%s] %s", prefix, fmt.Sprintf(format, values...))
+ }
+}
+
+// monitorExit waits for the shell process to exit and logs the exit result.
+func (s *Shell) monitorExit() {
+ if err := s.cmd.Wait(); err != nil {
+ s.logf("cmd", "shell process terminated: %v", err)
+ } else {
+ s.logf("cmd", "shell process terminated successfully")
+ }
+ close(s.cmdFinished)
+}
+
+// reader continuously reads the shell output and populates readCh.
+func (s *Shell) reader(ctx context.Context) {
+ b := make([]byte, 4096)
+ defer close(s.readCh)
+ for {
+ select {
+ case <-s.cmdFinished:
+ // Shell process terminated; stop trying to read.
+ return
+ case <-ctx.Done():
+ // Shell process will also have terminated in this case;
+ // stop trying to read.
+ // We don't print an error here because doing so would print this in the
+ // normal case where the context passed to NewShell is canceled at the
+ // end of a successful test.
+ return
+ default:
+ // Shell still running, try reading.
+ }
+ if got, err := s.ptyMaster.Read(b); err != nil {
+ s.readCh <- byteOrError{err: err}
+ if err == io.EOF {
+ return
+ }
+ } else {
+ for i := 0; i < got; i++ {
+ s.readCh <- byteOrError{b: b[i]}
+ }
+ }
+ }
+}
+
+// readByte reads a single byte, respecting the context.
+func (s *Shell) readByte(ctx context.Context) (byte, error) {
+ select {
+ case <-ctx.Done():
+ return 0, ctx.Err()
+ case r := <-s.readCh:
+ return r.b, r.err
+ }
+}
+
+// readLoop reads as many bytes as possible until the context expires, b is
+// full, or a short time passes. It returns how many bytes it has successfully
+// read.
+func (s *Shell) readLoop(ctx context.Context, b []byte) (int, error) {
+ soonCtx, soonCancel := context.WithTimeout(ctx, 5*time.Second)
+ defer soonCancel()
+ var i int
+ for i = 0; i < len(b) && soonCtx.Err() == nil; i++ {
+ next, err := s.readByte(soonCtx)
+ if err != nil {
+ if i > 0 {
+ s.logIO("read", b[:i-1], err)
+ } else {
+ s.logIO("read", nil, err)
+ }
+ return i, err
+ }
+ b[i] = next
+ }
+ s.logIO("read", b[:i], soonCtx.Err())
+ return i, soonCtx.Err()
+}
+
+// readLine reads a single line. Strips out all \r characters for convenience.
+// Upon error, it will still return what it has read so far.
+// It will also exit quickly if the line content it has read so far (without a
+// line break) matches `prompt`.
+func (s *Shell) readLine(ctx context.Context, prompt string) ([]byte, error) {
+ soonCtx, soonCancel := context.WithTimeout(ctx, 5*time.Second)
+ defer soonCancel()
+ var lineData bytes.Buffer
+ var b byte
+ var err error
+ for soonCtx.Err() == nil && b != '\n' {
+ b, err = s.readByte(soonCtx)
+ if err != nil {
+ data := lineData.Bytes()
+ s.logIO("read", data, err)
+ return data, err
+ }
+ if b != '\r' {
+ lineData.WriteByte(b)
+ }
+ if bytes.Equal(lineData.Bytes(), []byte(prompt)) {
+ // Assume that there will not be any further output if we get the prompt.
+ // This avoids waiting for the read deadline just to read the prompt.
+ break
+ }
+ }
+ data := lineData.Bytes()
+ s.logIO("read", data, soonCtx.Err())
+ return data, soonCtx.Err()
+}
+
+// Expect verifies that the next `len(want)` bytes we read match `want`.
+func (s *Shell) Expect(ctx context.Context, want []byte) error {
+ errPrefix := fmt.Sprintf("want(%q)", want)
+ b := make([]byte, len(want))
+ got, err := s.readLoop(ctx, b)
+ if err != nil {
+ if ctx.Err() != nil {
+ return fmt.Errorf("%s: context done (%w), got: %q", errPrefix, err, b[:got])
+ }
+ return fmt.Errorf("%s: %w", errPrefix, err)
+ }
+ if got < len(want) {
+ return fmt.Errorf("%s: short read (read %d bytes, expected %d): %q", errPrefix, got, len(want), b[:got])
+ }
+ if !bytes.Equal(b, want) {
+ return fmt.Errorf("got %q want %q", b, want)
+ }
+ return nil
+}
+
+// ExpectString verifies that the next `len(want)` bytes we read match `want`.
+func (s *Shell) ExpectString(ctx context.Context, want string) error {
+ return s.Expect(ctx, []byte(want))
+}
+
+// ExpectPrompt verifies that the next few bytes we read are the shell prompt.
+func (s *Shell) ExpectPrompt(ctx context.Context) error {
+ return s.ExpectString(ctx, Prompt)
+}
+
+// ExpectEmptyLine verifies that the next few bytes we read are an empty line,
+// as defined by any number of carriage or line break characters.
+func (s *Shell) ExpectEmptyLine(ctx context.Context) error {
+ line, err := s.readLine(ctx, Prompt)
+ if err != nil {
+ return fmt.Errorf("cannot read line: %w", err)
+ }
+ if strings.Trim(string(line), "\r\n") != "" {
+ return fmt.Errorf("line was not empty: %q", line)
+ }
+ return nil
+}
+
+// ExpectLine verifies that the next `len(want)` bytes we read match `want`,
+// followed by carriage returns or newline characters.
+func (s *Shell) ExpectLine(ctx context.Context, want string) error {
+ if err := s.ExpectString(ctx, want); err != nil {
+ return err
+ }
+ if err := s.ExpectEmptyLine(ctx); err != nil {
+ return fmt.Errorf("ExpectLine(%q): no line break: %w", want, err)
+ }
+ return nil
+}
+
+// Write writes `b` to the shell and verifies that all of them get written.
+func (s *Shell) Write(b []byte) error {
+ written, err := s.ptyMaster.Write(b)
+ s.logIO("write", b[:written], err)
+ if err != nil {
+ return fmt.Errorf("write(%q): %w", b, err)
+ }
+ if written != len(b) {
+ return fmt.Errorf("write(%q): wrote %d of %d bytes (%q)", b, written, len(b), b[:written])
+ }
+ return nil
+}
+
+// WriteLine writes `line` (to which \n will be appended) to the shell.
+// If the shell is in `echo` mode, it will also check that we got these bytes
+// back to read.
+func (s *Shell) WriteLine(ctx context.Context, line string) error {
+ if err := s.Write([]byte(line + "\n")); err != nil {
+ return err
+ }
+ if s.echo {
+ // We expect to see everything we've typed.
+ if err := s.ExpectLine(ctx, line); err != nil {
+ return fmt.Errorf("echo: %w", err)
+ }
+ }
+ return nil
+}
+
+// StartCommand is a convenience wrapper for WriteLine that mimics entering a
+// command line and pressing Enter. It does some basic shell argument escaping.
+func (s *Shell) StartCommand(ctx context.Context, cmd ...string) error {
+ escaped := make([]string, len(cmd))
+ for i, arg := range cmd {
+ escaped[i] = shellEscape(arg)
+ }
+ return s.WriteLine(ctx, strings.Join(escaped, " "))
+}
+
+// GetCommandOutput gets all following bytes until the prompt is encountered.
+// This is useful for matching the output of a command.
+// All \r are removed for ease of matching.
+func (s *Shell) GetCommandOutput(ctx context.Context) ([]byte, error) {
+ return s.ReadUntil(ctx, Prompt)
+}
+
+// ReadUntil gets all following bytes until a certain line is encountered.
+// This final line is not returned as part of the output, but everything before
+// it (including the \n) is included.
+// This is useful for matching the output of a command.
+// All \r are removed for ease of matching.
+func (s *Shell) ReadUntil(ctx context.Context, finalLine string) ([]byte, error) {
+ var output bytes.Buffer
+ for ctx.Err() == nil {
+ line, err := s.readLine(ctx, finalLine)
+ if err != nil {
+ return nil, err
+ }
+ if bytes.Equal(line, []byte(finalLine)) {
+ break
+ }
+ // readLine ensures that `line` either matches `finalLine` or contains \n.
+ // Thus we can be confident that `line` has a \n here.
+ output.Write(line)
+ }
+ return output.Bytes(), ctx.Err()
+}
+
+// RunCommand is a convenience wrapper for StartCommand + GetCommandOutput.
+func (s *Shell) RunCommand(ctx context.Context, cmd ...string) ([]byte, error) {
+ if err := s.StartCommand(ctx, cmd...); err != nil {
+ return nil, err
+ }
+ return s.GetCommandOutput(ctx)
+}
+
+// RefreshSTTY interprets output from `stty -a` to check whether we are in echo
+// mode and other settings.
+// It will assume that any line matching `expectPrompt` means the end of
+// the `stty -a` output.
+// Why do this rather than using `tcgets`? Because this function can be used in
+// conjunction with sub-shell processes that can allocate their own TTYs.
+func (s *Shell) RefreshSTTY(ctx context.Context, expectPrompt string) error {
+ // Temporarily assume we will not get any output.
+ // If echo is actually on, we'll get the "stty -a" line as if it was command
+ // output. This is OK because we parse the output generously.
+ s.echo = false
+ if err := s.WriteLine(ctx, "stty -a"); err != nil {
+ return fmt.Errorf("could not run `stty -a`: %w", err)
+ }
+ sttyOutput, err := s.ReadUntil(ctx, expectPrompt)
+ if err != nil {
+ return fmt.Errorf("cannot get `stty -a` output: %w", err)
+ }
+
+ // Set default control characters in case we can't see them in the output.
+ s.controlCharIntr = "^C"
+ s.controlCharEOF = "^D"
+ // stty output has two general notations:
+ // `a = b;` (for control characters), and `option` vs `-option` (for boolean
+ // options). We parse both kinds here.
+ // For `a = b;`, `controlChar` contains `a`, and `previousToken` is used to
+ // set `controlChar` to `previousToken` when we see an "=" token.
+ var previousToken, controlChar string
+ for _, token := range strings.Fields(string(sttyOutput)) {
+ if controlChar != "" {
+ value := strings.TrimSuffix(token, ";")
+ switch controlChar {
+ case "intr":
+ s.controlCharIntr = value
+ case "eof":
+ s.controlCharEOF = value
+ }
+ controlChar = ""
+ } else {
+ switch token {
+ case "=":
+ controlChar = previousToken
+ case "-echo":
+ s.echo = false
+ case "echo":
+ s.echo = true
+ }
+ }
+ previousToken = token
+ }
+ s.logf("stty", "refreshed settings: echo=%v, intr=%q, eof=%q", s.echo, s.controlCharIntr, s.controlCharEOF)
+ return nil
+}
+
+// sendControlCode sends `code` to the shell and expects to see `repr`.
+// If `expectLinebreak` is true, it also expects to see a linebreak.
+func (s *Shell) sendControlCode(ctx context.Context, code byte, repr string, expectLinebreak bool) error {
+ if err := s.Write([]byte{code}); err != nil {
+ return fmt.Errorf("cannot send %q: %w", code, err)
+ }
+ if err := s.ExpectString(ctx, repr); err != nil {
+ return fmt.Errorf("did not see %s: %w", repr, err)
+ }
+ if expectLinebreak {
+ if err := s.ExpectEmptyLine(ctx); err != nil {
+ return fmt.Errorf("linebreak after %s: %v", repr, err)
+ }
+ }
+ return nil
+}
+
+// SendInterrupt sends the \x03 (Ctrl+C) control character to the shell.
+func (s *Shell) SendInterrupt(ctx context.Context, expectLinebreak bool) error {
+ return s.sendControlCode(ctx, 0x03, s.controlCharIntr, expectLinebreak)
+}
+
+// SendEOF sends the \x04 (Ctrl+D) control character to the shell.
+func (s *Shell) SendEOF(ctx context.Context, expectLinebreak bool) error {
+ return s.sendControlCode(ctx, 0x04, s.controlCharEOF, expectLinebreak)
+}
+
+// NewShell returns a new managed sh process along with a cleanup function.
+// The caller is expected to call the returned cleanup function once it no
+// longer needs the shell.
+// The optional passed-in logger will be used for logging.
+func NewShell(ctx context.Context, logger Logger) (*Shell, func(), error) {
+ ptyMaster, ptyReplica, err := pty.Open()
+ if err != nil {
+ return nil, nil, fmt.Errorf("cannot create PTY: %w", err)
+ }
+ cmd := exec.CommandContext(ctx, "/bin/sh", "--noprofile", "--norc", "-i")
+ cmd.Stdin = ptyReplica
+ cmd.Stdout = ptyReplica
+ cmd.Stderr = ptyReplica
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ Setsid: true,
+ Setctty: true,
+ Ctty: 0,
+ }
+ cmd.Env = append(cmd.Env, fmt.Sprintf("PS1=%s", Prompt))
+ if err := cmd.Start(); err != nil {
+ return nil, nil, fmt.Errorf("cannot start shell: %w", err)
+ }
+ s := &Shell{
+ cmd: cmd,
+ cmdFinished: make(chan struct{}),
+ ptyMaster: ptyMaster,
+ ptyReplica: ptyReplica,
+ readCh: make(chan byteOrError, 1<<20),
+ logger: logger,
+ }
+ s.logf("creation", "Shell spawned.")
+ go s.monitorExit()
+ go s.reader(ctx)
+ setupCtx, setupCancel := context.WithTimeout(ctx, 5*time.Second)
+ defer setupCancel()
+ // We expect to see the prompt immediately on startup,
+ // since the shell is started in interactive mode.
+ if err := s.ExpectPrompt(setupCtx); err != nil {
+ s.cleanup()
+ return nil, nil, fmt.Errorf("did not get initial prompt: %w", err)
+ }
+ s.logf("creation", "Initial prompt observed.")
+ // Get initial TTY settings.
+ if err := s.RefreshSTTY(setupCtx, Prompt); err != nil {
+ s.cleanup()
+ return nil, nil, fmt.Errorf("cannot get initial STTY settings: %w", err)
+ }
+ return s, s.cleanup, nil
+}
diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go
index 976331230..fdd416b5e 100644
--- a/pkg/test/testutil/testutil.go
+++ b/pkg/test/testutil/testutil.go
@@ -48,7 +48,10 @@ import (
)
var (
- checkpoint = flag.Bool("checkpoint", true, "control checkpoint/restore support")
+ checkpoint = flag.Bool("checkpoint", true, "control checkpoint/restore support")
+ partition = flag.Int("partition", 1, "partition number, this is 1-indexed")
+ totalPartitions = flag.Int("total_partitions", 1, "total number of partitions")
+ isRunningWithHostNet = flag.Bool("hostnet", false, "whether test is running with hostnet")
)
// IsCheckpointSupported returns the relevant command line flag.
@@ -56,6 +59,11 @@ func IsCheckpointSupported() bool {
return *checkpoint
}
+// IsRunningWithHostNet returns the relevant command line flag.
+func IsRunningWithHostNet() bool {
+ return *isRunningWithHostNet
+}
+
// ImageByName mangles the image name used locally. This depends on the image
// build infrastructure in images/ and tools/vm.
func ImageByName(name string) string {
@@ -248,14 +256,25 @@ func writeSpec(dir string, spec *specs.Spec) error {
// idRandomSrc is a pseudo random generator used in RandomID.
var idRandomSrc = rand.New(rand.NewSource(time.Now().UnixNano()))
+// idRandomSrcMtx is the mutex protecting idRandomSrc.Read from being used
+// concurrently in different goroutines.
+var idRandomSrcMtx sync.Mutex
+
// RandomID returns 20 random bytes following the given prefix.
func RandomID(prefix string) string {
// Read 20 random bytes.
b := make([]byte, 20)
+ // Rand.Read is not safe for concurrent use. Packetimpact tests can be run in
+ // parallel now, so we have to protect the Read with a mutex. Otherwise we'll
+ // run into name conflicts.
+ // https://golang.org/pkg/math/rand/#Rand.Read
+ idRandomSrcMtx.Lock()
// "[Read] always returns len(p) and a nil error." --godoc
if _, err := idRandomSrc.Read(b); err != nil {
+ idRandomSrcMtx.Unlock()
panic("rand.Read failed: " + err.Error())
}
+ idRandomSrcMtx.Unlock()
if prefix != "" {
prefix = prefix + "-"
}
@@ -510,7 +529,8 @@ func TouchShardStatusFile() error {
}
// TestIndicesForShard returns indices for this test shard based on the
-// TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars.
+// TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars, as well as
+// the passed partition flags.
//
// If either of the env vars are not present, then the function will return all
// tests. If there are more shards than there are tests, then the returned list
@@ -535,6 +555,11 @@ func TestIndicesForShard(numTests int) ([]int, error) {
}
}
+ // Combine with the partitions.
+ partitionSize := shardTotal
+ shardTotal = (*totalPartitions) * shardTotal
+ shardIndex = partitionSize*(*partition-1) + shardIndex
+
// Calculate!
var indices []int
numBlocks := int(math.Ceil(float64(numTests) / float64(shardTotal)))
diff --git a/pkg/urpc/urpc.go b/pkg/urpc/urpc.go
index 13b2ea314..0e9a829f6 100644
--- a/pkg/urpc/urpc.go
+++ b/pkg/urpc/urpc.go
@@ -170,6 +170,9 @@ type Server struct {
// methods is the set of server methods.
methods map[string]registeredMethod
+ // stoppers are all registered stoppers.
+ stoppers []Stopper
+
// clients is a map of clients.
clients map[*unet.Socket]clientState
@@ -195,6 +198,12 @@ func NewServerWithCallback(afterRPCCallback func()) *Server {
}
}
+// Stopper is an optional interface, that when implemented, allows an object
+// to have a callback executed when the server is shutting down.
+type Stopper interface {
+ Stop()
+}
+
// Register registers the given object as an RPC receiver.
//
// This functions is the same way as the built-in RPC package, but it does not
@@ -206,6 +215,7 @@ func (s *Server) Register(obj interface{}) {
defer s.mu.Unlock()
typ := reflect.TypeOf(obj)
+ stopper, hasStop := obj.(Stopper)
// If we got a pointer, deref it to the underlying object. We need this to
// obtain the name of the underlying type.
@@ -221,6 +231,10 @@ func (s *Server) Register(obj interface{}) {
// Can't be anonymous.
panic("type not named.")
}
+ if hasStop && method.Name == "Stop" {
+ s.stoppers = append(s.stoppers, stopper)
+ continue // Legal stop method.
+ }
prettyName := typDeref.Name() + "." + method.Name
if _, ok := s.methods[prettyName]; ok {
@@ -283,12 +297,10 @@ func (s *Server) handleOne(client *unet.Socket) error {
// Client is dead.
return err
}
+ if s.afterRPCCallback != nil {
+ defer s.afterRPCCallback()
+ }
- defer func() {
- if s.afterRPCCallback != nil {
- s.afterRPCCallback()
- }
- }()
// Explicitly close all these files after the call.
//
// This is also explicitly a reference to the files after the call,
@@ -450,6 +462,11 @@ func (s *Server) Stop() {
// Wait for all outstanding requests.
defer s.wg.Wait()
+ // Call any Stop callbacks.
+ for _, stopper := range s.stoppers {
+ stopper.Stop()
+ }
+
// Close all known clients.
s.mu.Lock()
defer s.mu.Unlock()
diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go
index 9b1e7a085..79db8895b 100644
--- a/pkg/usermem/usermem.go
+++ b/pkg/usermem/usermem.go
@@ -167,7 +167,7 @@ func (rw *IOReadWriter) Read(dst []byte) (int, error) {
return n, err
}
-// Writer implements io.Writer.Write.
+// Write implements io.Writer.Write.
func (rw *IOReadWriter) Write(src []byte) (int, error) {
n, err := rw.IO.CopyOut(rw.Ctx, rw.Addr, src, rw.Opts)
end, ok := rw.Addr.AddLength(uint64(n))
diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go
index 08519d986..83d4f893a 100644
--- a/pkg/waiter/waiter.go
+++ b/pkg/waiter/waiter.go
@@ -119,7 +119,10 @@ type EntryCallback interface {
// The callback is supposed to perform minimal work, and cannot call
// any method on the queue itself because it will be locked while the
// callback is running.
- Callback(e *Entry)
+ //
+ // The mask indicates the events that occurred and that the entry is
+ // interested in.
+ Callback(e *Entry, mask EventMask)
}
// Entry represents a waiter that can be add to the a wait queue. It can
@@ -140,7 +143,7 @@ type channelCallback struct {
}
// Callback implements EntryCallback.Callback.
-func (c *channelCallback) Callback(*Entry) {
+func (c *channelCallback) Callback(*Entry, EventMask) {
select {
case c.ch <- struct{}{}:
default:
@@ -193,8 +196,8 @@ func (q *Queue) EventUnregister(e *Entry) {
func (q *Queue) Notify(mask EventMask) {
q.mu.RLock()
for e := q.list.Front(); e != nil; e = e.Next() {
- if mask&e.mask != 0 {
- e.Callback.Callback(e)
+ if m := mask & e.mask; m != 0 {
+ e.Callback.Callback(e, m)
}
}
q.mu.RUnlock()
diff --git a/pkg/waiter/waiter_test.go b/pkg/waiter/waiter_test.go
index c1b94a4f3..6928f28b4 100644
--- a/pkg/waiter/waiter_test.go
+++ b/pkg/waiter/waiter_test.go
@@ -20,12 +20,12 @@ import (
)
type callbackStub struct {
- f func(e *Entry)
+ f func(e *Entry, m EventMask)
}
// Callback implements EntryCallback.Callback.
-func (c *callbackStub) Callback(e *Entry) {
- c.f(e)
+func (c *callbackStub) Callback(e *Entry, m EventMask) {
+ c.f(e, m)
}
func TestEmptyQueue(t *testing.T) {
@@ -36,7 +36,7 @@ func TestEmptyQueue(t *testing.T) {
// Register then unregister a waiter, then notify the queue.
cnt := 0
- e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}}
+ e := Entry{Callback: &callbackStub{func(*Entry, EventMask) { cnt++ }}}
q.EventRegister(&e, EventIn)
q.EventUnregister(&e)
q.Notify(EventIn)
@@ -49,7 +49,7 @@ func TestMask(t *testing.T) {
// Register a waiter.
var q Queue
var cnt int
- e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}}
+ e := Entry{Callback: &callbackStub{func(*Entry, EventMask) { cnt++ }}}
q.EventRegister(&e, EventIn|EventErr)
// Notify with an overlapping mask.
@@ -101,11 +101,14 @@ func TestConcurrentRegistration(t *testing.T) {
for i := 0; i < concurrency; i++ {
go func() {
var e Entry
- e.Callback = &callbackStub{func(entry *Entry) {
+ e.Callback = &callbackStub{func(entry *Entry, mask EventMask) {
cnt++
if entry != &e {
t.Errorf("entry = %p, want %p", entry, &e)
}
+ if mask != EventIn {
+ t.Errorf("mask = %#x want %#x", mask, EventIn)
+ }
}}
// Wait for notification, then register.
@@ -158,11 +161,14 @@ func TestConcurrentNotification(t *testing.T) {
// Register waiters.
for i := 0; i < waiterCount; i++ {
var e Entry
- e.Callback = &callbackStub{func(entry *Entry) {
+ e.Callback = &callbackStub{func(entry *Entry, mask EventMask) {
atomic.AddInt32(&cnt, 1)
if entry != &e {
t.Errorf("entry = %p, want %p", entry, &e)
}
+ if mask != EventIn {
+ t.Errorf("mask = %#x want %#x", mask, EventIn)
+ }
}}
q.EventRegister(&e, EventIn|EventErr)
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index 8c73dc5dc..67307ab3c 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -33,6 +33,7 @@ go_library(
"//pkg/cpuid",
"//pkg/eventchannel",
"//pkg/fd",
+ "//pkg/flipcall",
"//pkg/fspath",
"//pkg/log",
"//pkg/memutil",
diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go
index 7076ae2e2..a3a76b609 100644
--- a/runsc/boot/compat.go
+++ b/runsc/boot/compat.go
@@ -53,7 +53,7 @@ type compatEmitter struct {
func newCompatEmitter(logFD int) (*compatEmitter, error) {
nameMap, ok := getSyscallNameMap()
if !ok {
- return nil, fmt.Errorf("Linux syscall table not found")
+ return nil, fmt.Errorf("syscall table not found")
}
c := &compatEmitter{
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
index 865126ac5..cb5d8ea31 100644
--- a/runsc/boot/controller.go
+++ b/runsc/boot/controller.go
@@ -104,13 +104,11 @@ const (
// Profiling related commands (see pprof.go for more details).
const (
- StartCPUProfile = "Profile.StartCPUProfile"
- StopCPUProfile = "Profile.StopCPUProfile"
- HeapProfile = "Profile.HeapProfile"
- BlockProfile = "Profile.BlockProfile"
- MutexProfile = "Profile.MutexProfile"
- StartTrace = "Profile.StartTrace"
- StopTrace = "Profile.StopTrace"
+ CPUProfile = "Profile.CPU"
+ HeapProfile = "Profile.Heap"
+ BlockProfile = "Profile.Block"
+ MutexProfile = "Profile.Mutex"
+ Trace = "Profile.Trace"
)
// Logging related commands (see logging.go for more details).
@@ -131,9 +129,6 @@ type controller struct {
// manager holds the containerManager methods.
manager *containerManager
-
- // pprop holds the profile instance if enabled. It may be nil.
- pprof *control.Profile
}
// newController creates a new controller. The caller must call
@@ -164,19 +159,14 @@ func newController(fd int, l *Loader) (*controller, error) {
ctrl.srv.Register(&control.Logging{})
if l.root.conf.ProfileEnable {
- ctrl.pprof = &control.Profile{Kernel: l.k}
- ctrl.srv.Register(ctrl.pprof)
+ ctrl.srv.Register(control.NewProfile(l.k))
}
return ctrl, nil
}
func (c *controller) stop() {
- if c.pprof != nil {
- // These are noop if there is nothing being profiled.
- _ = c.pprof.StopCPUProfile(nil, nil)
- _ = c.pprof.StopTrace(nil, nil)
- }
+ c.srv.Stop()
}
// containerManager manages sandbox containers.
diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go
index a7c4ebb0c..eacd73531 100644
--- a/runsc/boot/filter/config.go
+++ b/runsc/boot/filter/config.go
@@ -343,6 +343,21 @@ func hostInetFilters() seccomp.SyscallRules {
},
{
seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IP),
+ seccomp.EqualTo(syscall.IP_PKTINFO),
+ },
+ {
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IP),
+ seccomp.EqualTo(syscall.IP_RECVORIGDSTADDR),
+ },
+ {
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IP),
+ seccomp.EqualTo(syscall.IP_RECVERR),
+ },
+ {
+ seccomp.MatchAny{},
seccomp.EqualTo(syscall.SOL_IPV6),
seccomp.EqualTo(syscall.IPV6_TCLASS),
},
@@ -354,10 +369,20 @@ func hostInetFilters() seccomp.SyscallRules {
{
seccomp.MatchAny{},
seccomp.EqualTo(syscall.SOL_IPV6),
+ seccomp.EqualTo(syscall.IPV6_RECVERR),
+ },
+ {
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IPV6),
seccomp.EqualTo(syscall.IPV6_V6ONLY),
},
{
seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IPV6),
+ seccomp.EqualTo(linux.IPV6_RECVORIGDSTADDR),
+ },
+ {
+ seccomp.MatchAny{},
seccomp.EqualTo(syscall.SOL_SOCKET),
seccomp.EqualTo(syscall.SO_ERROR),
},
@@ -393,6 +418,11 @@ func hostInetFilters() seccomp.SyscallRules {
},
{
seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_SOCKET),
+ seccomp.EqualTo(syscall.SO_TIMESTAMP),
+ },
+ {
+ seccomp.MatchAny{},
seccomp.EqualTo(syscall.SOL_TCP),
seccomp.EqualTo(syscall.TCP_NODELAY),
},
@@ -401,6 +431,11 @@ func hostInetFilters() seccomp.SyscallRules {
seccomp.EqualTo(syscall.SOL_TCP),
seccomp.EqualTo(syscall.TCP_INFO),
},
+ {
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_TCP),
+ seccomp.EqualTo(linux.TCP_INQ),
+ },
},
syscall.SYS_IOCTL: []seccomp.Rule{
{
@@ -421,29 +456,29 @@ func hostInetFilters() seccomp.SyscallRules {
syscall.SYS_SETSOCKOPT: []seccomp.Rule{
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IPV6),
- seccomp.EqualTo(syscall.IPV6_V6ONLY),
+ seccomp.EqualTo(syscall.SOL_SOCKET),
+ seccomp.EqualTo(syscall.SO_SNDBUF),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
{
seccomp.MatchAny{},
seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_SNDBUF),
+ seccomp.EqualTo(syscall.SO_RCVBUF),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
{
seccomp.MatchAny{},
seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_RCVBUF),
+ seccomp.EqualTo(syscall.SO_REUSEADDR),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
{
seccomp.MatchAny{},
seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_REUSEADDR),
+ seccomp.EqualTo(syscall.SO_TIMESTAMP),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
@@ -456,6 +491,13 @@ func hostInetFilters() seccomp.SyscallRules {
},
{
seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_TCP),
+ seccomp.EqualTo(linux.TCP_INQ),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
+ },
+ {
+ seccomp.MatchAny{},
seccomp.EqualTo(syscall.SOL_IP),
seccomp.EqualTo(syscall.IP_TOS),
seccomp.MatchAny{},
@@ -470,6 +512,27 @@ func hostInetFilters() seccomp.SyscallRules {
},
{
seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IP),
+ seccomp.EqualTo(syscall.IP_PKTINFO),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
+ },
+ {
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IP),
+ seccomp.EqualTo(syscall.IP_RECVORIGDSTADDR),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
+ },
+ {
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IP),
+ seccomp.EqualTo(syscall.IP_RECVERR),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
+ },
+ {
+ seccomp.MatchAny{},
seccomp.EqualTo(syscall.SOL_IPV6),
seccomp.EqualTo(syscall.IPV6_TCLASS),
seccomp.MatchAny{},
@@ -482,6 +545,27 @@ func hostInetFilters() seccomp.SyscallRules {
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
+ {
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IPV6),
+ seccomp.EqualTo(linux.IPV6_RECVORIGDSTADDR),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
+ },
+ {
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IPV6),
+ seccomp.EqualTo(syscall.IPV6_RECVERR),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
+ },
+ {
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IPV6),
+ seccomp.EqualTo(syscall.IPV6_V6ONLY),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
+ },
},
syscall.SYS_SHUTDOWN: []seccomp.Rule{
{
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index 86bdc6ae3..f41d6c665 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -294,7 +294,7 @@ func New(args Args) (*Loader, error) {
if args.TotalMem > 0 {
// Adjust the total memory returned by the Sentry so that applications that
// use /proc/meminfo can make allocations based on this limit.
- usage.MinimumTotalMemoryBytes = args.TotalMem
+ usage.MaximumTotalMemoryBytes = args.TotalMem
log.Infof("Setting total memory to %.2f GB", float64(args.TotalMem)/(1<<30))
}
@@ -598,7 +598,6 @@ func (l *Loader) run() error {
if err != nil {
return err
}
-
}
ep.tg = l.k.GlobalInit()
@@ -1045,9 +1044,10 @@ func (l *Loader) WaitExit() kernel.ExitStatus {
// Wait for container.
l.k.WaitExited()
- // Cleanup
+ // Stop the control server.
l.ctrl.stop()
+ // Check all references.
refs.OnExit()
return l.k.GlobalInit().ExitStatus()
@@ -1082,7 +1082,12 @@ func newRootNetworkNamespace(conf *config.Config, clock tcpip.Clock, uniqueID st
func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) {
netProtos := []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol, arp.NewProtocol}
- transProtos := []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4}
+ transProtos := []stack.TransportProtocolFactory{
+ tcp.NewProtocol,
+ udp.NewProtocol,
+ icmp.NewProtocol4,
+ icmp.NewProtocol6,
+ }
s := netstack.Stack{stack.New(stack.Options{
NetworkProtocols: netProtos,
TransportProtocols: transProtos,
diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go
index b157387ef..3fd28e516 100644
--- a/runsc/boot/vfs.go
+++ b/runsc/boot/vfs.go
@@ -250,36 +250,76 @@ func (c *containerMounter) configureOverlay(ctx context.Context, creds *auth.Cre
overlayOpts := *lowerOpts
overlayOpts.GetFilesystemOptions = vfs.GetFilesystemOptions{}
- // Next mount upper and lower. Upper is a tmpfs mount to keep all
- // modifications inside the sandbox.
- upper, err := c.k.VFS().MountDisconnected(ctx, creds, "" /* source */, tmpfs.Name, &upperOpts)
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create upper layer for overlay, opts: %+v: %v", upperOpts, err)
- }
- cu := cleanup.Make(func() { upper.DecRef(ctx) })
- defer cu.Clean()
-
// All writes go to the upper layer, be paranoid and make lower readonly.
lowerOpts.ReadOnly = true
lower, err := c.k.VFS().MountDisconnected(ctx, creds, "" /* source */, lowerFSName, lowerOpts)
if err != nil {
return nil, nil, err
}
- cu.Add(func() { lower.DecRef(ctx) })
+ cu := cleanup.Make(func() { lower.DecRef(ctx) })
+ defer cu.Clean()
- // Propagate the lower layer's root's owner, group, and mode to the upper
- // layer's root for consistency with VFS1.
- upperRootVD := vfs.MakeVirtualDentry(upper, upper.Root())
+ // Determine the lower layer's root's type.
lowerRootVD := vfs.MakeVirtualDentry(lower, lower.Root())
stat, err := c.k.VFS().StatAt(ctx, creds, &vfs.PathOperation{
Root: lowerRootVD,
Start: lowerRootVD,
}, &vfs.StatOptions{
- Mask: linux.STATX_UID | linux.STATX_GID | linux.STATX_MODE,
+ Mask: linux.STATX_UID | linux.STATX_GID | linux.STATX_MODE | linux.STATX_TYPE,
})
if err != nil {
- return nil, nil, err
+ return nil, nil, fmt.Errorf("failed to stat lower layer's root: %v", err)
+ }
+ if stat.Mask&linux.STATX_TYPE == 0 {
+ return nil, nil, fmt.Errorf("failed to get file type of lower layer's root")
+ }
+ rootType := stat.Mode & linux.S_IFMT
+ if rootType != linux.S_IFDIR && rootType != linux.S_IFREG {
+ return nil, nil, fmt.Errorf("lower layer's root has unsupported file type %v", rootType)
+ }
+
+ // Upper is a tmpfs mount to keep all modifications inside the sandbox.
+ upperOpts.GetFilesystemOptions.InternalData = tmpfs.FilesystemOpts{
+ RootFileType: uint16(rootType),
+ }
+ upper, err := c.k.VFS().MountDisconnected(ctx, creds, "" /* source */, tmpfs.Name, &upperOpts)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to create upper layer for overlay, opts: %+v: %v", upperOpts, err)
+ }
+ cu.Add(func() { upper.DecRef(ctx) })
+
+ // If the overlay mount consists of a regular file, copy up its contents
+ // from the lower layer, since in the overlay the otherwise-empty upper
+ // layer file will take precedence.
+ upperRootVD := vfs.MakeVirtualDentry(upper, upper.Root())
+ if rootType == linux.S_IFREG {
+ lowerFD, err := c.k.VFS().OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: lowerRootVD,
+ Start: lowerRootVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to open lower layer root for copying: %v", err)
+ }
+ defer lowerFD.DecRef(ctx)
+ upperFD, err := c.k.VFS().OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: upperRootVD,
+ Start: upperRootVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_WRONLY,
+ })
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to open upper layer root for copying: %v", err)
+ }
+ defer upperFD.DecRef(ctx)
+ if _, err := vfs.CopyRegularFileData(ctx, upperFD, lowerFD); err != nil {
+ return nil, nil, fmt.Errorf("failed to copy up overlay file: %v", err)
+ }
}
+
+ // Propagate the lower layer's root's owner, group, and mode to the upper
+ // layer's root for consistency with VFS1.
err = c.k.VFS().SetStatAt(ctx, creds, &vfs.PathOperation{
Root: upperRootVD,
Start: upperRootVD,
diff --git a/runsc/cli/main.go b/runsc/cli/main.go
index bca015db5..6c3bf4d21 100644
--- a/runsc/cli/main.go
+++ b/runsc/cli/main.go
@@ -22,6 +22,7 @@ import (
"io/ioutil"
"os"
"os/signal"
+ "runtime"
"syscall"
"time"
@@ -82,6 +83,7 @@ func Main(version string) {
subcommands.Register(new(cmd.Spec), "")
subcommands.Register(new(cmd.State), "")
subcommands.Register(new(cmd.Start), "")
+ subcommands.Register(new(cmd.Symbolize), "")
subcommands.Register(new(cmd.Wait), "")
// Register internal commands with the internal group name. This causes
@@ -207,6 +209,8 @@ func Main(version string) {
log.Infof("***************************")
log.Infof("Args: %s", os.Args)
log.Infof("Version %s", version)
+ log.Infof("GOOS: %s", runtime.GOOS)
+ log.Infof("GOARCH: %s", runtime.GOARCH)
log.Infof("PID: %d", os.Getpid())
log.Infof("UID: %d, GID: %d", os.Getuid(), os.Getgid())
log.Infof("Configuration:")
diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD
index 2556f6d9e..19520d7ab 100644
--- a/runsc/cmd/BUILD
+++ b/runsc/cmd/BUILD
@@ -32,6 +32,7 @@ go_library(
"start.go",
"state.go",
"statefile.go",
+ "symbolize.go",
"syscalls.go",
"wait.go",
],
@@ -39,6 +40,7 @@ go_library(
"//runsc:__subpackages__",
],
deps = [
+ "//pkg/coverage",
"//pkg/log",
"//pkg/p9",
"//pkg/sentry/control",
diff --git a/runsc/cmd/checkpoint.go b/runsc/cmd/checkpoint.go
index c0bc8f064..124198239 100644
--- a/runsc/cmd/checkpoint.go
+++ b/runsc/cmd/checkpoint.go
@@ -75,7 +75,7 @@ func (c *Checkpoint) Execute(_ context.Context, f *flag.FlagSet, args ...interfa
conf := args[0].(*config.Config)
waitStatus := args[1].(*syscall.WaitStatus)
- cont, err := container.LoadAndCheck(conf.RootDir, id)
+ cont, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{})
if err != nil {
Fatalf("loading container: %v", err)
}
diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go
index 609e8231c..b84142b0d 100644
--- a/runsc/cmd/debug.go
+++ b/runsc/cmd/debug.go
@@ -17,8 +17,10 @@ package cmd
import (
"context"
"os"
+ "os/signal"
"strconv"
"strings"
+ "sync"
"syscall"
"time"
@@ -43,6 +45,7 @@ type Debug struct {
strace string
logLevel string
logPackets string
+ delay time.Duration
duration time.Duration
ps bool
}
@@ -70,10 +73,11 @@ func (d *Debug) SetFlags(f *flag.FlagSet) {
f.StringVar(&d.profileCPU, "profile-cpu", "", "writes CPU profile to the given file.")
f.StringVar(&d.profileBlock, "profile-block", "", "writes block profile to the given file.")
f.StringVar(&d.profileMutex, "profile-mutex", "", "writes mutex profile to the given file.")
- f.DurationVar(&d.duration, "duration", time.Second, "amount of time to wait for CPU and trace profiles")
+ f.DurationVar(&d.delay, "delay", time.Hour, "amount of time to delay for collecting heap and goroutine profiles.")
+ f.DurationVar(&d.duration, "duration", time.Hour, "amount of time to wait for CPU and trace profiles.")
f.StringVar(&d.trace, "trace", "", "writes an execution trace to the given file.")
f.IntVar(&d.signal, "signal", -1, "sends signal to the sandbox")
- f.StringVar(&d.strace, "strace", "", `A comma separated list of syscalls to trace. "all" enables all traces, "off" disables all`)
+ f.StringVar(&d.strace, "strace", "", `A comma separated list of syscalls to trace. "all" enables all traces, "off" disables all.`)
f.StringVar(&d.logLevel, "log-level", "", "The log level to set: warning (0), info (1), or debug (2).")
f.StringVar(&d.logPackets, "log-packets", "", "A boolean value to enable or disable packet logging: true or false.")
f.BoolVar(&d.ps, "ps", false, "lists processes")
@@ -90,8 +94,10 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
f.Usage()
return subcommands.ExitUsageError
}
+ id := f.Arg(0)
+
var err error
- c, err = container.LoadAndCheck(conf.RootDir, f.Arg(0))
+ c, err = container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{})
if err != nil {
return Errorf("loading container %q: %v", f.Arg(0), err)
}
@@ -106,9 +112,10 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
return Errorf("listing containers: %v", err)
}
for _, id := range ids {
- candidate, err := container.LoadAndCheck(conf.RootDir, id)
+ candidate, err := container.Load(conf.RootDir, id, container.LoadOpts{Exact: true, SkipCheck: true})
if err != nil {
- return Errorf("loading container %q: %v", id, err)
+ log.Warningf("Skipping container %q: %v", id, err)
+ continue
}
if candidate.SandboxPid() == d.pid {
c = candidate
@@ -120,11 +127,12 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
}
}
- if c.Sandbox == nil || !c.Sandbox.IsRunning() {
+ if !c.IsSandboxRunning() {
return Errorf("container sandbox is not running")
}
log.Infof("Found sandbox %q, PID: %d", c.Sandbox.ID, c.Sandbox.Pid)
+ // Perform synchronous actions.
if d.signal > 0 {
log.Infof("Sending signal %d to process: %d", d.signal, c.Sandbox.Pid)
if err := syscall.Kill(c.Sandbox.Pid, syscall.Signal(d.signal)); err != nil {
@@ -139,81 +147,6 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
}
log.Infof(" *** Stack dump ***\n%s", stacks)
}
- if d.profileHeap != "" {
- f, err := os.Create(d.profileHeap)
- if err != nil {
- return Errorf(err.Error())
- }
- defer f.Close()
-
- if err := c.Sandbox.HeapProfile(f); err != nil {
- return Errorf(err.Error())
- }
- log.Infof("Heap profile written to %q", d.profileHeap)
- }
- if d.profileBlock != "" {
- f, err := os.Create(d.profileBlock)
- if err != nil {
- return Errorf(err.Error())
- }
- defer f.Close()
-
- if err := c.Sandbox.BlockProfile(f); err != nil {
- return Errorf(err.Error())
- }
- log.Infof("Block profile written to %q", d.profileBlock)
- }
- if d.profileMutex != "" {
- f, err := os.Create(d.profileMutex)
- if err != nil {
- return Errorf(err.Error())
- }
- defer f.Close()
-
- if err := c.Sandbox.MutexProfile(f); err != nil {
- return Errorf(err.Error())
- }
- log.Infof("Mutex profile written to %q", d.profileMutex)
- }
-
- delay := false
- if d.profileCPU != "" {
- delay = true
- f, err := os.Create(d.profileCPU)
- if err != nil {
- return Errorf(err.Error())
- }
- defer func() {
- f.Close()
- if err := c.Sandbox.StopCPUProfile(); err != nil {
- Fatalf(err.Error())
- }
- log.Infof("CPU profile written to %q", d.profileCPU)
- }()
- if err := c.Sandbox.StartCPUProfile(f); err != nil {
- return Errorf(err.Error())
- }
- log.Infof("CPU profile started for %v, writing to %q", d.duration, d.profileCPU)
- }
- if d.trace != "" {
- delay = true
- f, err := os.Create(d.trace)
- if err != nil {
- return Errorf(err.Error())
- }
- defer func() {
- f.Close()
- if err := c.Sandbox.StopTrace(); err != nil {
- Fatalf(err.Error())
- }
- log.Infof("Trace written to %q", d.trace)
- }()
- if err := c.Sandbox.StartTrace(f); err != nil {
- return Errorf(err.Error())
- }
- log.Infof("Tracing started for %v, writing to %q", d.duration, d.trace)
- }
-
if d.strace != "" || len(d.logLevel) != 0 || len(d.logPackets) != 0 {
args := control.LoggingArgs{}
switch strings.ToLower(d.strace) {
@@ -282,8 +215,156 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
log.Infof(o)
}
- if delay {
- time.Sleep(d.duration)
+ // Open profiling files.
+ var (
+ heapFile *os.File
+ cpuFile *os.File
+ traceFile *os.File
+ blockFile *os.File
+ mutexFile *os.File
+ )
+ if d.profileHeap != "" {
+ f, err := os.OpenFile(d.profileHeap, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+ if err != nil {
+ return Errorf("error opening heap profile output: %v", err)
+ }
+ defer f.Close()
+ heapFile = f
+ }
+ if d.profileCPU != "" {
+ f, err := os.OpenFile(d.profileCPU, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+ if err != nil {
+ return Errorf("error opening cpu profile output: %v", err)
+ }
+ defer f.Close()
+ cpuFile = f
+ }
+ if d.trace != "" {
+ f, err := os.OpenFile(d.trace, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+ if err != nil {
+ return Errorf("error opening trace profile output: %v", err)
+ }
+ traceFile = f
+ }
+ if d.profileBlock != "" {
+ f, err := os.OpenFile(d.profileBlock, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+ if err != nil {
+ return Errorf("error opening blocking profile output: %v", err)
+ }
+ defer f.Close()
+ blockFile = f
+ }
+ if d.profileMutex != "" {
+ f, err := os.OpenFile(d.profileMutex, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+ if err != nil {
+ return Errorf("error opening mutex profile output: %v", err)
+ }
+ defer f.Close()
+ mutexFile = f
+ }
+
+ // Collect profiles.
+ var (
+ wg sync.WaitGroup
+ heapErr error
+ cpuErr error
+ traceErr error
+ blockErr error
+ mutexErr error
+ )
+ if heapFile != nil {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ heapErr = c.Sandbox.HeapProfile(heapFile, d.delay)
+ }()
+ }
+ if cpuFile != nil {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ cpuErr = c.Sandbox.CPUProfile(cpuFile, d.duration)
+ }()
+ }
+ if traceFile != nil {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ traceErr = c.Sandbox.Trace(traceFile, d.duration)
+ }()
+ }
+ if blockFile != nil {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ blockErr = c.Sandbox.BlockProfile(blockFile, d.duration)
+ }()
+ }
+ if mutexFile != nil {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ mutexErr = c.Sandbox.MutexProfile(mutexFile, d.duration)
+ }()
+ }
+
+ // Before sleeping, allow us to catch signals and try to exit
+ // gracefully before just exiting. If we can't wait for wg, then
+ // we will not be able to read the errors below safely.
+ readyChan := make(chan struct{})
+ go func() {
+ defer close(readyChan)
+ wg.Wait()
+ }()
+ signals := make(chan os.Signal, 1)
+ signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
+ select {
+ case <-readyChan:
+ break // Safe to proceed.
+ case <-signals:
+ log.Infof("caught signal, waiting at most one more second.")
+ select {
+ case <-signals:
+ log.Infof("caught second signal, exiting immediately.")
+ os.Exit(1) // Not finished.
+ case <-time.After(time.Second):
+ log.Infof("timeout, exiting.")
+ os.Exit(1) // Not finished.
+ case <-readyChan:
+ break // Safe to proceed.
+ }
+ }
+
+ // Collect all errors.
+ errorCount := 0
+ if heapErr != nil {
+ errorCount++
+ log.Infof("error collecting heap profile: %v", heapErr)
+ os.Remove(heapFile.Name())
+ }
+ if cpuErr != nil {
+ errorCount++
+ log.Infof("error collecting cpu profile: %v", cpuErr)
+ os.Remove(cpuFile.Name())
+ }
+ if traceErr != nil {
+ errorCount++
+ log.Infof("error collecting trace profile: %v", traceErr)
+ os.Remove(traceFile.Name())
+ }
+ if blockErr != nil {
+ errorCount++
+ log.Infof("error collecting block profile: %v", blockErr)
+ os.Remove(blockFile.Name())
+ }
+ if mutexErr != nil {
+ errorCount++
+ log.Infof("error collecting mutex profile: %v", mutexErr)
+ os.Remove(mutexFile.Name())
+ }
+
+ if errorCount > 0 {
+ return subcommands.ExitFailure
}
return subcommands.ExitSuccess
diff --git a/runsc/cmd/delete.go b/runsc/cmd/delete.go
index a25637265..a750be131 100644
--- a/runsc/cmd/delete.go
+++ b/runsc/cmd/delete.go
@@ -68,7 +68,7 @@ func (d *Delete) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}
func (d *Delete) execute(ids []string, conf *config.Config) error {
for _, id := range ids {
- c, err := container.LoadAndCheck(conf.RootDir, id)
+ c, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{})
if err != nil {
if os.IsNotExist(err) && d.force {
log.Warningf("couldn't find container %q: %v", id, err)
diff --git a/runsc/cmd/do.go b/runsc/cmd/do.go
index 640de4c47..8a8d9f752 100644
--- a/runsc/cmd/do.go
+++ b/runsc/cmd/do.go
@@ -81,7 +81,7 @@ func (c *Do) SetFlags(f *flag.FlagSet) {
// Execute implements subcommands.Command.Execute.
func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
if len(f.Args()) == 0 {
- c.Usage()
+ f.Usage()
return subcommands.ExitUsageError
}
diff --git a/runsc/cmd/events.go b/runsc/cmd/events.go
index 3836b7b4e..75b0aac8d 100644
--- a/runsc/cmd/events.go
+++ b/runsc/cmd/events.go
@@ -74,7 +74,7 @@ func (evs *Events) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa
id := f.Arg(0)
conf := args[0].(*config.Config)
- c, err := container.LoadAndCheck(conf.RootDir, id)
+ c, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{})
if err != nil {
Fatalf("loading sandbox: %v", err)
}
diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go
index eafd6285c..8558d34ae 100644
--- a/runsc/cmd/exec.go
+++ b/runsc/cmd/exec.go
@@ -112,7 +112,7 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
}
waitStatus := args[1].(*syscall.WaitStatus)
- c, err := container.LoadAndCheck(conf.RootDir, id)
+ c, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{})
if err != nil {
Fatalf("loading sandbox: %v", err)
}
diff --git a/runsc/cmd/kill.go b/runsc/cmd/kill.go
index fe69e2a08..aecf0b7ab 100644
--- a/runsc/cmd/kill.go
+++ b/runsc/cmd/kill.go
@@ -69,7 +69,7 @@ func (k *Kill) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
Fatalf("it is invalid to specify both --all and --pid")
}
- c, err := container.LoadAndCheck(conf.RootDir, id)
+ c, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{})
if err != nil {
Fatalf("loading container: %v", err)
}
diff --git a/runsc/cmd/list.go b/runsc/cmd/list.go
index 6907eb16a..9f9a47bd8 100644
--- a/runsc/cmd/list.go
+++ b/runsc/cmd/list.go
@@ -24,6 +24,7 @@ import (
"github.com/google/subcommands"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/config"
"gvisor.dev/gvisor/runsc/container"
"gvisor.dev/gvisor/runsc/flag"
@@ -71,7 +72,7 @@ func (l *List) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
if l.quiet {
for _, id := range ids {
- fmt.Println(id)
+ fmt.Println(id.ContainerID)
}
return subcommands.ExitSuccess
}
@@ -79,9 +80,10 @@ func (l *List) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
// Collect the containers.
var containers []*container.Container
for _, id := range ids {
- c, err := container.LoadAndCheck(conf.RootDir, id)
+ c, err := container.Load(conf.RootDir, id, container.LoadOpts{Exact: true})
if err != nil {
- Fatalf("loading container %q: %v", id, err)
+ log.Warningf("Skipping container %q: %v", id, err)
+ continue
}
containers = append(containers, c)
}
diff --git a/runsc/cmd/pause.go b/runsc/cmd/pause.go
index fe7d4e257..15ef7b577 100644
--- a/runsc/cmd/pause.go
+++ b/runsc/cmd/pause.go
@@ -55,7 +55,7 @@ func (*Pause) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s
id := f.Arg(0)
conf := args[0].(*config.Config)
- cont, err := container.LoadAndCheck(conf.RootDir, id)
+ cont, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{})
if err != nil {
Fatalf("loading container: %v", err)
}
diff --git a/runsc/cmd/ps.go b/runsc/cmd/ps.go
index 18d7a1436..04e3e0bdd 100644
--- a/runsc/cmd/ps.go
+++ b/runsc/cmd/ps.go
@@ -60,7 +60,7 @@ func (ps *PS) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{})
id := f.Arg(0)
conf := args[0].(*config.Config)
- c, err := container.LoadAndCheck(conf.RootDir, id)
+ c, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{})
if err != nil {
Fatalf("loading sandbox: %v", err)
}
diff --git a/runsc/cmd/resume.go b/runsc/cmd/resume.go
index a00928204..856469252 100644
--- a/runsc/cmd/resume.go
+++ b/runsc/cmd/resume.go
@@ -56,7 +56,7 @@ func (r *Resume) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}
id := f.Arg(0)
conf := args[0].(*config.Config)
- cont, err := container.LoadAndCheck(conf.RootDir, id)
+ cont, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{})
if err != nil {
Fatalf("loading container: %v", err)
}
diff --git a/runsc/cmd/start.go b/runsc/cmd/start.go
index f6499cc44..964a65064 100644
--- a/runsc/cmd/start.go
+++ b/runsc/cmd/start.go
@@ -55,7 +55,7 @@ func (*Start) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s
id := f.Arg(0)
conf := args[0].(*config.Config)
- c, err := container.LoadAndCheck(conf.RootDir, id)
+ c, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{})
if err != nil {
Fatalf("loading container: %v", err)
}
diff --git a/runsc/cmd/state.go b/runsc/cmd/state.go
index d8a70dd7f..1f7913d5a 100644
--- a/runsc/cmd/state.go
+++ b/runsc/cmd/state.go
@@ -57,7 +57,7 @@ func (*State) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s
id := f.Arg(0)
conf := args[0].(*config.Config)
- c, err := container.LoadAndCheck(conf.RootDir, id)
+ c, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{})
if err != nil {
Fatalf("loading container: %v", err)
}
diff --git a/runsc/cmd/symbolize.go b/runsc/cmd/symbolize.go
new file mode 100644
index 000000000..fc0c69358
--- /dev/null
+++ b/runsc/cmd/symbolize.go
@@ -0,0 +1,91 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "bufio"
+ "context"
+ "os"
+ "strconv"
+ "strings"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/pkg/coverage"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+// Symbolize implements subcommands.Command for the "symbolize" command.
+type Symbolize struct {
+ dumpAll bool
+}
+
+// Name implements subcommands.Command.Name.
+func (*Symbolize) Name() string {
+ return "symbolize"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Symbolize) Synopsis() string {
+ return "Convert synthetic instruction pointers from kcov into positions in the runsc source code. Only used when Go coverage is enabled."
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Symbolize) Usage() string {
+ return `symbolize - converts synthetic instruction pointers into positions in the runsc source code.
+
+This command takes instruction pointers from stdin and converts them into their
+corresponding file names and line/column numbers in the runsc source code. The
+inputs are not interpreted as actual addresses, but as synthetic values that are
+exposed through /sys/kernel/debug/kcov. One can extract coverage information
+from kcov and translate those values into locations in the source code by
+running symbolize on the same runsc binary.
+`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (c *Symbolize) SetFlags(f *flag.FlagSet) {
+ f.BoolVar(&c.dumpAll, "all", false, "dump information on all coverage blocks along with their synthetic PCs")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (c *Symbolize) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if f.NArg() != 0 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+ if !coverage.KcovAvailable() {
+ return Errorf("symbolize can only be used when coverage is available.")
+ }
+ coverage.InitCoverageData()
+
+ if c.dumpAll {
+ coverage.WriteAllBlocks(os.Stdout)
+ return subcommands.ExitSuccess
+ }
+
+ scanner := bufio.NewScanner(os.Stdin)
+ for scanner.Scan() {
+ // Input is always base 16, but may or may not have a leading "0x".
+ str := strings.TrimPrefix(scanner.Text(), "0x")
+ pc, err := strconv.ParseUint(str, 16 /* base */, 64 /* bitSize */)
+ if err != nil {
+ return Errorf("Failed to symbolize \"%s\": %v", scanner.Text(), err)
+ }
+ if err := coverage.Symbolize(os.Stdout, pc); err != nil {
+ return Errorf("Failed to symbolize \"%s\": %v", scanner.Text(), err)
+ }
+ }
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/cmd/wait.go b/runsc/cmd/wait.go
index c1d6aeae2..5d55422c7 100644
--- a/runsc/cmd/wait.go
+++ b/runsc/cmd/wait.go
@@ -72,7 +72,7 @@ func (wt *Wait) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
id := f.Arg(0)
conf := args[0].(*config.Config)
- c, err := container.LoadAndCheck(conf.RootDir, id)
+ c, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{})
if err != nil {
Fatalf("loading container: %v", err)
}
diff --git a/runsc/config/config.go b/runsc/config/config.go
index b02d8e2e1..e9fd7708f 100644
--- a/runsc/config/config.go
+++ b/runsc/config/config.go
@@ -131,7 +131,7 @@ type Config struct {
NumNetworkChannels int `flag:"num-network-channels"`
// Rootless allows the sandbox to be started with a user that is not root.
- // Defense is depth measures are weaker with rootless. Specifically, the
+ // Defense in depth measures are weaker in rootless mode. Specifically, the
// sandbox and Gofer process run as root inside a user namespace with root
// mapped to the caller's user.
Rootless bool `flag:"rootless"`
diff --git a/runsc/container/BUILD b/runsc/container/BUILD
index 1900ecceb..8793c8916 100644
--- a/runsc/container/BUILD
+++ b/runsc/container/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test", "more_shards")
package(licenses = ["notice"])
@@ -49,7 +49,7 @@ go_test(
"//test/cmd/test_app",
],
library = ":container",
- shard_count = 10,
+ shard_count = more_shards,
tags = [
"requires-kvm",
],
diff --git a/runsc/container/container.go b/runsc/container/container.go
index 418a27beb..8b78660f7 100644
--- a/runsc/container/container.go
+++ b/runsc/container/container.go
@@ -128,125 +128,6 @@ type Container struct {
goferIsChild bool
}
-// loadSandbox loads all containers that belong to the sandbox with the given
-// ID.
-func loadSandbox(rootDir, id string) ([]*Container, error) {
- cids, err := List(rootDir)
- if err != nil {
- return nil, err
- }
-
- // Load the container metadata.
- var containers []*Container
- for _, cid := range cids {
- container, err := Load(rootDir, cid)
- if err != nil {
- // Container file may not exist if it raced with creation/deletion or
- // directory was left behind. Load provides a snapshot in time, so it's
- // fine to skip it.
- if os.IsNotExist(err) {
- continue
- }
- return nil, fmt.Errorf("loading container %q: %v", id, err)
- }
- if container.Sandbox.ID == id {
- containers = append(containers, container)
- }
- }
- return containers, nil
-}
-
-// Load loads a container with the given id from a metadata file. partialID may
-// be an abbreviation of the full container id, in which case Load loads the
-// container to which id unambiguously refers to. Returns ErrNotExist if
-// container doesn't exist.
-func Load(rootDir, partialID string) (*Container, error) {
- log.Debugf("Load container, rootDir: %q, partial cid: %s", rootDir, partialID)
- if err := validateID(partialID); err != nil {
- return nil, fmt.Errorf("invalid container id: %v", err)
- }
-
- id, err := findContainerID(rootDir, partialID)
- if err != nil {
- // Preserve error so that callers can distinguish 'not found' errors.
- return nil, err
- }
-
- state := StateFile{
- RootDir: rootDir,
- ID: id,
- }
- defer state.close()
-
- c := &Container{}
- if err := state.load(c); err != nil {
- if os.IsNotExist(err) {
- // Preserve error so that callers can distinguish 'not found' errors.
- return nil, err
- }
- return nil, fmt.Errorf("reading container metadata file %q: %v", state.statePath(), err)
- }
- return c, nil
-}
-
-// LoadAndCheck is similar to Load(), but also checks if the container is still
-// running to get an error earlier to the caller.
-func LoadAndCheck(rootDir, partialID string) (*Container, error) {
- c, err := Load(rootDir, partialID)
- if err != nil {
- // Preserve error so that callers can distinguish 'not found' errors.
- return nil, err
- }
-
- // If the status is "Running" or "Created", check that the sandbox/container
- // is still running, setting it to Stopped if not.
- //
- // This is inherently racy.
- switch c.Status {
- case Created:
- if !c.isSandboxRunning() {
- // Sandbox no longer exists, so this container definitely does not exist.
- c.changeStatus(Stopped)
- }
- case Running:
- if err := c.SignalContainer(syscall.Signal(0), false); err != nil {
- c.changeStatus(Stopped)
- }
- }
-
- return c, nil
-}
-
-func findContainerID(rootDir, partialID string) (string, error) {
- // Check whether the id fully specifies an existing container.
- stateFile := buildStatePath(rootDir, partialID)
- if _, err := os.Stat(stateFile); err == nil {
- return partialID, nil
- }
-
- // Now see whether id could be an abbreviation of exactly 1 of the
- // container ids. If id is ambiguous (it could match more than 1
- // container), it is an error.
- ids, err := List(rootDir)
- if err != nil {
- return "", err
- }
- rv := ""
- for _, id := range ids {
- if strings.HasPrefix(id, partialID) {
- if rv != "" {
- return "", fmt.Errorf("id %q is ambiguous and could refer to multiple containers: %q, %q", partialID, rv, id)
- }
- rv = id
- }
- }
- if rv == "" {
- return "", os.ErrNotExist
- }
- log.Debugf("abbreviated id %q resolves to full id %q", partialID, rv)
- return rv, nil
-}
-
// Args is used to configure a new container.
type Args struct {
// ID is the container unique identifier.
@@ -291,6 +172,15 @@ func New(conf *config.Config, args Args) (*Container, error) {
return nil, fmt.Errorf("creating container root directory %q: %v", conf.RootDir, err)
}
+ sandboxID := args.ID
+ if !isRoot(args.Spec) {
+ var ok bool
+ sandboxID, ok = specutils.SandboxID(args.Spec)
+ if !ok {
+ return nil, fmt.Errorf("no sandbox ID found when creating container")
+ }
+ }
+
c := &Container{
ID: args.ID,
Spec: args.Spec,
@@ -301,7 +191,10 @@ func New(conf *config.Config, args Args) (*Container, error) {
Owner: os.Getenv("USER"),
Saver: StateFile{
RootDir: conf.RootDir,
- ID: args.ID,
+ ID: FullID{
+ SandboxID: sandboxID,
+ ContainerID: args.ID,
+ },
},
}
// The Cleanup object cleans up partially created containers when an error
@@ -316,10 +209,17 @@ func New(conf *config.Config, args Args) (*Container, error) {
}
defer c.Saver.unlock()
- // If the metadata annotations indicate that this container should be
- // started in an existing sandbox, we must do so. The metadata will
- // indicate the ID of the sandbox, which is the same as the ID of the
- // init container in the sandbox.
+ // If the metadata annotations indicate that this container should be started
+ // in an existing sandbox, we must do so. These are the possible metadata
+ // annotation states:
+ // 1. No annotations: it means that there is a single container and this
+ // container is obviously the root. Both container and sandbox share the
+ // ID.
+ // 2. Container type == sandbox: it means this is the root container
+ // starting the sandbox. Both container and sandbox share the same ID.
+ // 3. Container type == container: it means this is a subcontainer of an
+ // already started sandbox. In this case, container ID is different than
+ // the sandbox ID.
if isRoot(args.Spec) {
log.Debugf("Creating new sandbox for container, cid: %s", args.ID)
@@ -358,7 +258,7 @@ func New(conf *config.Config, args Args) (*Container, error) {
// Start a new sandbox for this container. Any errors after this point
// must destroy the container.
sandArgs := &sandbox.Args{
- ID: args.ID,
+ ID: sandboxID,
Spec: args.Spec,
BundleDir: args.BundleDir,
ConsoleSocket: args.ConsoleSocket,
@@ -379,22 +279,14 @@ func New(conf *config.Config, args Args) (*Container, error) {
return nil, err
}
} else {
- // This is sort of confusing. For a sandbox with a root
- // container and a child container in it, runsc sees:
- // * A container struct whose sandbox ID is equal to the
- // container ID. This is the root container that is tied to
- // the creation of the sandbox.
- // * A container struct whose sandbox ID is equal to the above
- // container/sandbox ID, but that has a different container
- // ID. This is the child container.
- sbid, ok := specutils.SandboxID(args.Spec)
- if !ok {
- return nil, fmt.Errorf("no sandbox ID found when creating container")
- }
- log.Debugf("Creating new container, cid: %s, sandbox: %s", c.ID, sbid)
+ log.Debugf("Creating new container, cid: %s, sandbox: %s", c.ID, sandboxID)
// Find the sandbox associated with this ID.
- sb, err := LoadAndCheck(conf.RootDir, sbid)
+ fullID := FullID{
+ SandboxID: sandboxID,
+ ContainerID: sandboxID,
+ }
+ sb, err := Load(conf.RootDir, fullID, LoadOpts{Exact: true})
if err != nil {
return nil, err
}
@@ -628,7 +520,7 @@ func (c *Container) Wait() (syscall.WaitStatus, error) {
// returns its WaitStatus.
func (c *Container) WaitRootPID(pid int32) (syscall.WaitStatus, error) {
log.Debugf("Wait on process %d in sandbox, cid: %s", pid, c.Sandbox.ID)
- if !c.isSandboxRunning() {
+ if !c.IsSandboxRunning() {
return 0, fmt.Errorf("sandbox is not running")
}
return c.Sandbox.WaitPID(c.Sandbox.ID, pid)
@@ -638,7 +530,7 @@ func (c *Container) WaitRootPID(pid int32) (syscall.WaitStatus, error) {
// its WaitStatus.
func (c *Container) WaitPID(pid int32) (syscall.WaitStatus, error) {
log.Debugf("Wait on process %d in container, cid: %s", pid, c.ID)
- if !c.isSandboxRunning() {
+ if !c.IsSandboxRunning() {
return 0, fmt.Errorf("sandbox is not running")
}
return c.Sandbox.WaitPID(c.ID, pid)
@@ -658,7 +550,7 @@ func (c *Container) SignalContainer(sig syscall.Signal, all bool) error {
if err := c.requireStatus("signal", Running, Stopped); err != nil {
return err
}
- if !c.isSandboxRunning() {
+ if !c.IsSandboxRunning() {
return fmt.Errorf("sandbox is not running")
}
return c.Sandbox.SignalContainer(c.ID, sig, all)
@@ -670,7 +562,7 @@ func (c *Container) SignalProcess(sig syscall.Signal, pid int32) error {
if err := c.requireStatus("signal a process inside", Running); err != nil {
return err
}
- if !c.isSandboxRunning() {
+ if !c.IsSandboxRunning() {
return fmt.Errorf("sandbox is not running")
}
return c.Sandbox.SignalProcess(c.ID, int32(pid), sig, false)
@@ -889,7 +781,7 @@ func (c *Container) waitForStopped() error {
defer cancel()
b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx)
op := func() error {
- if c.isSandboxRunning() {
+ if c.IsSandboxRunning() {
if err := c.SignalContainer(syscall.Signal(0), false); err == nil {
return fmt.Errorf("container is still running")
}
@@ -1091,7 +983,7 @@ func (c *Container) changeStatus(s Status) {
c.Status = s
}
-func (c *Container) isSandboxRunning() bool {
+func (c *Container) IsSandboxRunning() bool {
return c.Sandbox != nil && c.Sandbox.IsRunning()
}
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index fa99e403a..a92ae046d 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -364,7 +364,7 @@ func TestLifecycle(t *testing.T) {
defer c.Destroy()
// Load the container from disk and check the status.
- c, err = LoadAndCheck(rootDir, args.ID)
+ c, err = Load(rootDir, FullID{ContainerID: args.ID}, LoadOpts{})
if err != nil {
t.Fatalf("error loading container: %v", err)
}
@@ -377,7 +377,11 @@ func TestLifecycle(t *testing.T) {
if err != nil {
t.Fatalf("error listing containers: %v", err)
}
- if got, want := ids, []string{args.ID}; !reflect.DeepEqual(got, want) {
+ fullID := FullID{
+ SandboxID: args.ID,
+ ContainerID: args.ID,
+ }
+ if got, want := ids, []FullID{fullID}; !reflect.DeepEqual(got, want) {
t.Errorf("container list got %v, want %v", got, want)
}
@@ -387,7 +391,7 @@ func TestLifecycle(t *testing.T) {
}
// Load the container from disk and check the status.
- c, err = LoadAndCheck(rootDir, args.ID)
+ c, err = Load(rootDir, fullID, LoadOpts{Exact: true})
if err != nil {
t.Fatalf("error loading container: %v", err)
}
@@ -428,7 +432,7 @@ func TestLifecycle(t *testing.T) {
}
// Load the container from disk and check the status.
- c, err = LoadAndCheck(rootDir, args.ID)
+ c, err = Load(rootDir, fullID, LoadOpts{Exact: true})
if err != nil {
t.Fatalf("error loading container: %v", err)
}
@@ -451,7 +455,7 @@ func TestLifecycle(t *testing.T) {
}
// Loading the container by id should fail.
- if _, err = LoadAndCheck(rootDir, args.ID); err == nil {
+ if _, err = Load(rootDir, fullID, LoadOpts{Exact: true}); err == nil {
t.Errorf("expected loading destroyed container to fail, but it did not")
}
})
@@ -1738,7 +1742,7 @@ func doAbbreviatedIDsTest(t *testing.T, vfs2 bool) {
cids[2]: cids[2],
}
for shortid, longid := range unambiguous {
- if _, err := LoadAndCheck(rootDir, shortid); err != nil {
+ if _, err := Load(rootDir, FullID{ContainerID: shortid}, LoadOpts{}); err != nil {
t.Errorf("%q should resolve to %q: %v", shortid, longid, err)
}
}
@@ -1749,7 +1753,7 @@ func doAbbreviatedIDsTest(t *testing.T, vfs2 bool) {
"ba",
}
for _, shortid := range ambiguous {
- if s, err := LoadAndCheck(rootDir, shortid); err == nil {
+ if s, err := Load(rootDir, FullID{ContainerID: shortid}, LoadOpts{}); err == nil {
t.Errorf("%q should be ambiguous, but resolved to %q", shortid, s.ID)
}
}
@@ -2007,7 +2011,7 @@ func doDestroyStartingTest(t *testing.T, vfs2 bool) {
// Container is not thread safe, so load another instance to run in
// concurrently.
- startCont, err := LoadAndCheck(rootDir, args.ID)
+ startCont, err := Load(rootDir, FullID{ContainerID: args.ID}, LoadOpts{})
if err != nil {
t.Fatalf("error loading container: %v", err)
}
@@ -2332,6 +2336,42 @@ func TestTTYField(t *testing.T) {
}
}
+// Test that container can run even when there are corrupt state files in the
+// root directory.
+func TestCreateWithCorruptedStateFile(t *testing.T) {
+ conf := testutil.TestConfig(t)
+ spec := testutil.NewSpecWithArgs("/bin/true")
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create corrupted state file.
+ corruptID := testutil.RandomContainerID()
+ corruptState := buildPath(conf.RootDir, FullID{SandboxID: corruptID, ContainerID: corruptID}, stateFileExtension)
+ if err := ioutil.WriteFile(corruptState, []byte("this{file(is;not[valid.json"), 0777); err != nil {
+ t.Fatalf("createCorruptStateFile(): %v", err)
+ }
+ defer os.Remove(corruptState)
+
+ if _, err := Load(conf.RootDir, FullID{ContainerID: corruptID}, LoadOpts{SkipCheck: true}); err == nil {
+ t.Fatalf("loading corrupted state file should have failed")
+ }
+
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ Attached: true,
+ }
+ if ws, err := Run(conf, args); err != nil {
+ t.Errorf("running container: %v", err)
+ } else if !ws.Exited() || ws.ExitStatus() != 0 {
+ t.Errorf("container failed, waitStatus: %v", ws)
+ }
+}
+
func execute(cont *Container, name string, arg ...string) (syscall.WaitStatus, error) {
args := &control.ExecArgs{
Filename: name,
diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go
index 45d4e6e6e..29db1b7e8 100644
--- a/runsc/container/multi_container_test.go
+++ b/runsc/container/multi_container_test.go
@@ -730,7 +730,7 @@ func TestMultiContainerKillAll(t *testing.T) {
// processes still running inside.
containers[1].SignalContainer(syscall.SIGKILL, false)
op := func() error {
- c, err := LoadAndCheck(conf.RootDir, ids[1])
+ c, err := Load(conf.RootDir, FullID{ContainerID: ids[1]}, LoadOpts{})
if err != nil {
return err
}
@@ -744,7 +744,7 @@ func TestMultiContainerKillAll(t *testing.T) {
}
}
- c, err := LoadAndCheck(conf.RootDir, ids[1])
+ c, err := Load(conf.RootDir, FullID{ContainerID: ids[1]}, LoadOpts{})
if err != nil {
t.Fatalf("failed to load child container %q: %v", c.ID, err)
}
@@ -867,7 +867,7 @@ func TestMultiContainerDestroyStarting(t *testing.T) {
// Container is not thread safe, so load another instance to run in
// concurrently.
- startCont, err := LoadAndCheck(rootDir, ids[i])
+ startCont, err := Load(rootDir, FullID{ContainerID: ids[i]}, LoadOpts{})
if err != nil {
t.Fatalf("error loading container: %v", err)
}
diff --git a/runsc/container/state_file.go b/runsc/container/state_file.go
index 17a251530..dfbf1f2d3 100644
--- a/runsc/container/state_file.go
+++ b/runsc/container/state_file.go
@@ -20,58 +20,228 @@ import (
"io/ioutil"
"os"
"path/filepath"
+ "regexp"
+ "strings"
+ "syscall"
"github.com/gofrs/flock"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sync"
)
-const stateFileExtension = ".state"
+const stateFileExtension = "state"
-// StateFile handles load from/save to container state safely from multiple
-// processes. It uses a lock file to provide synchronization between operations.
+// LoadOpts provides options for Load()ing a container.
+type LoadOpts struct {
+ // Exact tells whether the search should be exact. See Load() for more.
+ Exact bool
+
+ // SkipCheck tells Load() to skip checking if container is running.
+ SkipCheck bool
+}
+
+// Load loads a container with the given id from a metadata file. "id" may
+// be an abbreviation of the full container id in case LoadOpts.Exact is not
+// set. It also checks if the container is still running, in order to return
+// an error to the caller earlier. This check is skipped if LoadOpts.SkipCheck
+// is set.
//
-// The lock file is located at: "${s.RootDir}/${s.ID}.lock".
-// The state file is located at: "${s.RootDir}/${s.ID}.state".
-type StateFile struct {
- // RootDir is the directory containing the container metadata file.
- RootDir string `json:"rootDir"`
+// Returns ErrNotExist if no container is found. Returns an error in case more
+// than one container matching the ID prefix is found.
+func Load(rootDir string, id FullID, opts LoadOpts) (*Container, error) {
+ //log.Debugf("Load container, rootDir: %q, partial cid: %s", rootDir, partialID)
+ if !opts.Exact {
+ var err error
+ id, err = findContainerID(rootDir, id.ContainerID)
+ if err != nil {
+ // Preserve error so that callers can distinguish 'not found' errors.
+ return nil, err
+ }
+ }
- // ID is the container ID.
- ID string `json:"id"`
+ if err := id.validate(); err != nil {
+ return nil, fmt.Errorf("invalid container id: %v", err)
+ }
+ state := StateFile{
+ RootDir: rootDir,
+ ID: id,
+ }
+ defer state.close()
- //
- // Fields below this line are not saved in the state file and will not
- // be preserved across commands.
- //
+ c := &Container{}
+ if err := state.load(c); err != nil {
+ if os.IsNotExist(err) {
+ // Preserve error so that callers can distinguish 'not found' errors.
+ return nil, err
+ }
+ return nil, fmt.Errorf("reading container metadata file %q: %v", state.statePath(), err)
+ }
- once sync.Once
- flock *flock.Flock
+ if !opts.SkipCheck {
+ // If the status is "Running" or "Created", check that the sandbox/container
+ // is still running, setting it to Stopped if not.
+ //
+ // This is inherently racy.
+ switch c.Status {
+ case Created:
+ if !c.IsSandboxRunning() {
+ // Sandbox no longer exists, so this container definitely does not exist.
+ c.changeStatus(Stopped)
+ }
+ case Running:
+ if err := c.SignalContainer(syscall.Signal(0), false); err != nil {
+ c.changeStatus(Stopped)
+ }
+ }
+ }
+
+ return c, nil
}
// List returns all container ids in the given root directory.
-func List(rootDir string) ([]string, error) {
+func List(rootDir string) ([]FullID, error) {
log.Debugf("List containers %q", rootDir)
- list, err := filepath.Glob(filepath.Join(rootDir, "*"+stateFileExtension))
+ return listMatch(rootDir, FullID{})
+}
+
+// listMatch returns all container ids that match the provided id.
+func listMatch(rootDir string, id FullID) ([]FullID, error) {
+ id.SandboxID += "*"
+ id.ContainerID += "*"
+ pattern := buildPath(rootDir, id, stateFileExtension)
+ list, err := filepath.Glob(pattern)
if err != nil {
return nil, err
}
- var out []string
+ var out []FullID
for _, path := range list {
- // Filter out files that do no belong to a container.
- fileName := filepath.Base(path)
- if len(fileName) < len(stateFileExtension) {
- panic(fmt.Sprintf("invalid file match %q", path))
- }
- // Remove the extension.
- cid := fileName[:len(fileName)-len(stateFileExtension)]
- if validateID(cid) == nil {
- out = append(out, cid)
+ id, err := parseFileName(filepath.Base(path))
+ if err == nil {
+ out = append(out, id)
}
}
return out, nil
}
+// loadSandbox loads all containers that belong to the sandbox with the given
+// ID.
+func loadSandbox(rootDir, id string) ([]*Container, error) {
+ cids, err := listMatch(rootDir, FullID{SandboxID: id})
+ if err != nil {
+ return nil, err
+ }
+
+ // Load the container metadata.
+ var containers []*Container
+ for _, cid := range cids {
+ container, err := Load(rootDir, cid, LoadOpts{Exact: true, SkipCheck: true})
+ if err != nil {
+ // Container file may not exist if it raced with creation/deletion or
+ // directory was left behind. Load provides a snapshot in time, so it's
+ // fine to skip it.
+ if os.IsNotExist(err) {
+ continue
+ }
+ return nil, fmt.Errorf("loading sandbox %q, failed to load container %q: %v", id, cid, err)
+ }
+ containers = append(containers, container)
+ }
+ return containers, nil
+}
+
+func findContainerID(rootDir, partialID string) (FullID, error) {
+ // Find state files, in any sandbox, whose container ID starts with partialID.
+ pattern := buildPath(rootDir, FullID{SandboxID: "*", ContainerID: partialID + "*"}, stateFileExtension)
+ list, err := filepath.Glob(pattern)
+ if err != nil {
+ return FullID{}, err
+ }
+ switch len(list) {
+ case 0:
+ return FullID{}, os.ErrNotExist
+ case 1:
+ return parseFileName(filepath.Base(list[0]))
+ }
+
+ // Now see whether id could be an abbreviation of exactly 1 of the
+ // container ids. If id is ambiguous (it could match more than 1
+ // container), it is an error.
+ ids, err := List(rootDir)
+ if err != nil {
+ return FullID{}, err
+ }
+ var rv *FullID
+ for _, id := range ids {
+ if strings.HasPrefix(id.ContainerID, partialID) {
+ if rv != nil {
+ return FullID{}, fmt.Errorf("id %q is ambiguous and could refer to multiple containers: %q, %q", partialID, rv, id)
+ }
+ rv = &id
+ }
+ }
+ if rv == nil {
+ return FullID{}, os.ErrNotExist
+ }
+ log.Debugf("abbreviated id %q resolves to full id %v", partialID, *rv)
+ return *rv, nil
+}
+
+func parseFileName(name string) (FullID, error) {
+ re := regexp.MustCompile(`([\w+-\.]+)_sandbox:([\w+-\.]+)\.` + stateFileExtension)
+ groups := re.FindStringSubmatch(name)
+ if len(groups) != 3 {
+ return FullID{}, fmt.Errorf("invalid state file name format: %q", name)
+ }
+ id := FullID{
+ SandboxID: groups[2],
+ ContainerID: groups[1],
+ }
+ if err := id.validate(); err != nil {
+ return FullID{}, fmt.Errorf("invalid state file name %q: %w", name, err)
+ }
+ return id, nil
+}
+
+// FullID combines sandbox and container ID to identify a container. Sandbox ID
+// is used to allow all containers for a given sandbox to be loaded by matching
+// sandbox ID in the file name.
+type FullID struct {
+ SandboxID string `json:"sandboxId"`
+ ContainerID string `json:"containerId"`
+}
+
+func (f *FullID) String() string {
+ return f.SandboxID + "/" + f.ContainerID
+}
+
+func (f *FullID) validate() error {
+ if err := validateID(f.SandboxID); err != nil {
+ return err
+ }
+ return validateID(f.ContainerID)
+}
+
+// StateFile handles load from/save to container state safely from multiple
+// processes. It uses a lock file to provide synchronization between operations.
+//
+// The lock file is located at: "${s.RootDir}/${container-id}_sandbox:${sandbox-id}.lock".
+// The state file is located at: "${s.RootDir}/${container-id}_sandbox:${sandbox-id}.state".
+type StateFile struct {
+ // RootDir is the directory containing the container metadata file.
+ RootDir string `json:"rootDir"`
+
+ // ID is the sandbox+container ID.
+ ID FullID `json:"id"`
+
+ //
+ // Fields below this line are not saved in the state file and will not
+ // be preserved across commands.
+ //
+
+ once sync.Once
+ flock *flock.Flock
+}
+
// lock globally locks all locking operations for the container.
func (s *StateFile) lock() error {
s.once.Do(func() {
@@ -157,18 +327,20 @@ func (s *StateFile) close() error {
return s.flock.Close()
}
-func buildStatePath(rootDir, id string) string {
- return filepath.Join(rootDir, id+stateFileExtension)
+func buildPath(rootDir string, id FullID, extension string) string {
+ // Note: "_" and ":" are not valid in IDs.
+ name := fmt.Sprintf("%s_sandbox:%s.%s", id.ContainerID, id.SandboxID, extension)
+ return filepath.Join(rootDir, name)
}
// statePath is the full path to the state file.
func (s *StateFile) statePath() string {
- return buildStatePath(s.RootDir, s.ID)
+ return buildPath(s.RootDir, s.ID, stateFileExtension)
}
// lockPath is the full path to the lock file.
func (s *StateFile) lockPath() string {
- return filepath.Join(s.RootDir, s.ID+".lock")
+ return buildPath(s.RootDir, s.ID, "lock")
}
// destroy deletes all state created by the stateFile. It may be called with the
diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD
index 96c57a426..c56e1d4d0 100644
--- a/runsc/fsgofer/BUILD
+++ b/runsc/fsgofer/BUILD
@@ -29,9 +29,12 @@ go_test(
srcs = ["fsgofer_test.go"],
library = ":fsgofer",
deps = [
+ "//pkg/fd",
"//pkg/log",
"//pkg/p9",
"//pkg/test/testutil",
+ "//runsc/specutils",
+ "@com_github_syndtr_gocapability//capability:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go
index 0b628c8ce..3d94ffeb4 100644
--- a/runsc/fsgofer/fsgofer.go
+++ b/runsc/fsgofer/fsgofer.go
@@ -49,6 +49,21 @@ const (
allowedOpenFlags = unix.O_TRUNC
)
+var (
+ // Remember the process uid/gid to skip chown calls when file owner/group
+ // doesn't need to be changed.
+ processUID = p9.UID(os.Getuid())
+ processGID = p9.GID(os.Getgid())
+)
+
+// join is equivalent to path.Join() but skips path.Clean() which is expensive.
+func join(parent, child string) string {
+ if child == "." || child == ".." {
+ panic(fmt.Sprintf("invalid child path %q", child))
+ }
+ return parent + "/" + child
+}
+
// Config sets configuration options for each attach point.
type Config struct {
// ROMount is set to true if this is a readonly mount.
@@ -115,7 +130,7 @@ func (a *attachPoint) Attach() (p9.File, error) {
return nil, fmt.Errorf("unable to stat %q: %v", a.prefix, err)
}
- lf, err := newLocalFile(a, f, a.prefix, readable, stat)
+ lf, err := newLocalFile(a, f, a.prefix, readable, &stat)
if err != nil {
return nil, fmt.Errorf("unable to create localFile %q: %v", a.prefix, err)
}
@@ -124,7 +139,7 @@ func (a *attachPoint) Attach() (p9.File, error) {
}
// makeQID returns a unique QID for the given stat buffer.
-func (a *attachPoint) makeQID(stat unix.Stat_t) p9.QID {
+func (a *attachPoint) makeQID(stat *unix.Stat_t) p9.QID {
a.deviceMu.Lock()
defer a.deviceMu.Unlock()
@@ -245,7 +260,7 @@ func reopenProcFd(f *fd.FD, mode int) (*fd.FD, error) {
}
func openAnyFileFromParent(parent *localFile, name string) (*fd.FD, string, bool, error) {
- pathDebug := path.Join(parent.hostPath, name)
+ pathDebug := join(parent.hostPath, name)
f, readable, err := openAnyFile(pathDebug, func(mode int) (*fd.FD, error) {
return fd.OpenAt(parent.file, name, openFlags|mode, 0)
})
@@ -297,8 +312,8 @@ func openAnyFile(pathDebug string, fn func(mode int) (*fd.FD, error)) (*fd.FD, b
return nil, false, extractErrno(err)
}
-func checkSupportedFileType(stat unix.Stat_t, permitSocket bool) error {
- switch stat.Mode & unix.S_IFMT {
+func checkSupportedFileType(mode uint32, permitSocket bool) error {
+ switch mode & unix.S_IFMT {
case unix.S_IFREG, unix.S_IFDIR, unix.S_IFLNK:
return nil
@@ -313,8 +328,8 @@ func checkSupportedFileType(stat unix.Stat_t, permitSocket bool) error {
}
}
-func newLocalFile(a *attachPoint, file *fd.FD, path string, readable bool, stat unix.Stat_t) (*localFile, error) {
- if err := checkSupportedFileType(stat, a.conf.HostUDS); err != nil {
+func newLocalFile(a *attachPoint, file *fd.FD, path string, readable bool, stat *unix.Stat_t) (*localFile, error) {
+ if err := checkSupportedFileType(stat.Mode, a.conf.HostUDS); err != nil {
return nil, err
}
@@ -442,8 +457,10 @@ func (l *localFile) Create(name string, p9Flags p9.OpenFlags, perm p9.FileMode,
})
defer cu.Clean()
- if err := fchown(child.FD(), uid, gid); err != nil {
- return nil, nil, p9.QID{}, 0, extractErrno(err)
+ if uid != processUID || gid != processGID {
+ if err := fchown(child.FD(), uid, gid); err != nil {
+ return nil, nil, p9.QID{}, 0, extractErrno(err)
+ }
}
stat, err := fstat(child.FD())
if err != nil {
@@ -452,11 +469,11 @@ func (l *localFile) Create(name string, p9Flags p9.OpenFlags, perm p9.FileMode,
c := &localFile{
attachPoint: l.attachPoint,
- hostPath: path.Join(l.hostPath, name),
+ hostPath: join(l.hostPath, name),
file: child,
mode: mode,
fileType: unix.S_IFREG,
- qid: l.attachPoint.makeQID(stat),
+ qid: l.attachPoint.makeQID(&stat),
}
cu.Release()
@@ -488,8 +505,10 @@ func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID)
}
defer f.Close()
- if err := fchown(f.FD(), uid, gid); err != nil {
- return p9.QID{}, extractErrno(err)
+ if uid != processUID || gid != processGID {
+ if err := fchown(f.FD(), uid, gid); err != nil {
+ return p9.QID{}, extractErrno(err)
+ }
}
stat, err := fstat(f.FD())
if err != nil {
@@ -497,7 +516,7 @@ func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID)
}
cu.Release()
- return l.attachPoint.makeQID(stat), nil
+ return l.attachPoint.makeQID(&stat), nil
}
// Walk implements p9.File.
@@ -512,7 +531,7 @@ func (l *localFile) WalkGetAttr(names []string) ([]p9.QID, p9.File, p9.AttrMask,
if err != nil {
return nil, nil, p9.AttrMask{}, p9.Attr{}, err
}
- mask, attr := l.fillAttr(stat)
+ mask, attr := l.fillAttr(&stat)
return qids, file, mask, attr, nil
}
@@ -538,13 +557,13 @@ func (l *localFile) walk(names []string) ([]p9.QID, p9.File, unix.Stat_t, error)
file: newFile,
mode: invalidMode,
fileType: l.fileType,
- qid: l.attachPoint.makeQID(stat),
+ qid: l.attachPoint.makeQID(&stat),
controlReadable: readable,
}
return []p9.QID{c.qid}, c, stat, nil
}
- var qids []p9.QID
+ qids := make([]p9.QID, 0, len(names))
var lastStat unix.Stat_t
last := l
for _, name := range names {
@@ -560,7 +579,7 @@ func (l *localFile) walk(names []string) ([]p9.QID, p9.File, unix.Stat_t, error)
_ = f.Close()
return nil, nil, unix.Stat_t{}, extractErrno(err)
}
- c, err := newLocalFile(last.attachPoint, f, path, readable, lastStat)
+ c, err := newLocalFile(last.attachPoint, f, path, readable, &lastStat)
if err != nil {
_ = f.Close()
return nil, nil, unix.Stat_t{}, extractErrno(err)
@@ -609,11 +628,11 @@ func (l *localFile) GetAttr(_ p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error)
if err != nil {
return p9.QID{}, p9.AttrMask{}, p9.Attr{}, extractErrno(err)
}
- mask, attr := l.fillAttr(stat)
+ mask, attr := l.fillAttr(&stat)
return l.qid, mask, attr, nil
}
-func (l *localFile) fillAttr(stat unix.Stat_t) (p9.AttrMask, p9.Attr) {
+func (l *localFile) fillAttr(stat *unix.Stat_t) (p9.AttrMask, p9.Attr) {
attr := p9.Attr{
Mode: p9.FileMode(stat.Mode),
UID: p9.UID(stat.Uid),
@@ -881,8 +900,10 @@ func (l *localFile) Symlink(target, newName string, uid p9.UID, gid p9.GID) (p9.
}
defer f.Close()
- if err := fchown(f.FD(), uid, gid); err != nil {
- return p9.QID{}, extractErrno(err)
+ if uid != processUID || gid != processGID {
+ if err := fchown(f.FD(), uid, gid); err != nil {
+ return p9.QID{}, extractErrno(err)
+ }
}
stat, err := fstat(f.FD())
if err != nil {
@@ -890,7 +911,7 @@ func (l *localFile) Symlink(target, newName string, uid p9.UID, gid p9.GID) (p9.
}
cu.Release()
- return l.attachPoint.makeQID(stat), nil
+ return l.attachPoint.makeQID(&stat), nil
}
// Link implements p9.File.
@@ -938,8 +959,10 @@ func (l *localFile) Mknod(name string, mode p9.FileMode, _ uint32, _ uint32, uid
}
defer child.Close()
- if err := fchown(child.FD(), uid, gid); err != nil {
- return p9.QID{}, extractErrno(err)
+ if uid != processUID || gid != processGID {
+ if err := fchown(child.FD(), uid, gid); err != nil {
+ return p9.QID{}, extractErrno(err)
+ }
}
stat, err := fstat(child.FD())
if err != nil {
@@ -947,7 +970,7 @@ func (l *localFile) Mknod(name string, mode p9.FileMode, _ uint32, _ uint32, uid
}
cu.Release()
- return l.attachPoint.makeQID(stat), nil
+ return l.attachPoint.makeQID(&stat), nil
}
// UnlinkAt implements p9.File.
@@ -1045,7 +1068,7 @@ func (l *localFile) readDirent(f int, offset uint64, count uint32, skip uint64)
log.Warningf("Readdir is skipping file with failed stat %q, err: %v", l.hostPath, err)
continue
}
- qid := l.attachPoint.makeQID(stat)
+ qid := l.attachPoint.makeQID(&stat)
offset++
dirents = append(dirents, p9.Dirent{
QID: qid,
@@ -1139,7 +1162,7 @@ func (l *localFile) isOpen() bool {
// Renamed implements p9.Renamed.
func (l *localFile) Renamed(newDir p9.File, newName string) {
- l.hostPath = path.Join(newDir.(*localFile).hostPath, newName)
+ l.hostPath = join(newDir.(*localFile).hostPath, newName)
}
// extractErrno tries to determine the errno.
diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go
index a84206686..c5daebe5e 100644
--- a/runsc/fsgofer/fsgofer_test.go
+++ b/runsc/fsgofer/fsgofer_test.go
@@ -23,10 +23,13 @@ import (
"path/filepath"
"testing"
+ "github.com/syndtr/gocapability/capability"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/runsc/specutils"
)
var allOpenFlags = []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite}
@@ -197,10 +200,13 @@ func setup(fileType uint32) (string, string, error) {
switch fileType {
case unix.S_IFREG:
name = "file"
- _, f, _, _, err := root.Create(name, p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid()))
+ fd, f, _, _, err := root.Create(name, p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid()))
if err != nil {
return "", "", fmt.Errorf("createFile(root, %q) failed, err: %v", "test", err)
}
+ if fd != nil {
+ fd.Close()
+ }
defer f.Close()
case unix.S_IFDIR:
name = "dir"
@@ -556,7 +562,28 @@ func TestROMountChecks(t *testing.T) {
func TestWalkNotFound(t *testing.T) {
runCustom(t, []uint32{unix.S_IFDIR}, allConfs, func(t *testing.T, s state) {
if _, _, err := s.file.Walk([]string{"nobody-here"}); err != unix.ENOENT {
- t.Errorf("%v: Walk(%q) should have failed, got: %v, expected: unix.ENOENT", s, "nobody-here", err)
+ t.Errorf("Walk(%q) should have failed, got: %v, expected: unix.ENOENT", "nobody-here", err)
+ }
+ if _, _, err := s.file.Walk([]string{"nobody", "here"}); err != unix.ENOENT {
+ t.Errorf("Walk(%q) should have failed, got: %v, expected: unix.ENOENT", "nobody/here", err)
+ }
+ if !s.conf.ROMount {
+ if _, err := s.file.Mkdir("dir", 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil {
+ t.Fatalf("MkDir(dir) failed, err: %v", err)
+ }
+ if _, _, err := s.file.Walk([]string{"dir", "nobody-here"}); err != unix.ENOENT {
+ t.Errorf("Walk(%q) should have failed, got: %v, expected: unix.ENOENT", "dir/nobody-here", err)
+ }
+ }
+ })
+}
+
+// TestWalkPanic checks that walking from a directory to "." or ".." panics,
+// for every configuration in allConfs.
+//
+// NOTE(review): assertPanic is defined outside this hunk; presumably it fails
+// the test when the callback returns without panicking -- confirm.
+func TestWalkPanic(t *testing.T) {
+	runCustom(t, []uint32{unix.S_IFDIR}, allConfs, func(t *testing.T, s state) {
+		for _, name := range []string{".", ".."} {
+			assertPanic(t, func() {
+				s.file.Walk([]string{name})
+			})
 		}
 	})
 }
@@ -574,6 +601,27 @@ func TestWalkDup(t *testing.T) {
})
}
+// TestWalkMultiple checks that Walk can traverse several path components in
+// one call. Each iteration creates one more nested directory ("dir0",
+// "dir0/dir1", ...) and then walks the whole chain from the test root.
+func TestWalkMultiple(t *testing.T) {
+	runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) {
+		var names []string
+		var parent p9.File = s.file
+		for i := 0; i < 5; i++ {
+			name := fmt.Sprintf("dir%d", i)
+			names = append(names, name)
+
+			// Create the next level inside the directory reached by the
+			// previous walk (the test root on the first iteration).
+			if _, err := parent.Mkdir(name, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil {
+				t.Fatalf("MkDir(%q) failed, err: %v", name, err)
+			}
+
+			var err error
+			_, parent, err = s.file.Walk(names)
+			if err != nil {
+				// Stop here: parent must not be used after a failed walk, and
+				// the next iteration would call Mkdir on it.
+				t.Fatalf("Walk(%q): %v", names, err)
+			}
+		}
+	})
+}
+
func TestReaddir(t *testing.T) {
runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) {
name := "dir"
@@ -819,3 +867,168 @@ func TestMknod(t *testing.T) {
}
})
}
+
+// BenchmarkWalkOne measures a single-component Walk from the attach point
+// root to the entry that setup(unix.S_IFDIR) created. Every walked file is
+// kept open and closed in batches with the timer stopped, so closing is not
+// part of the measurement.
+func BenchmarkWalkOne(b *testing.B) {
+	path, name, err := setup(unix.S_IFDIR)
+	if err != nil {
+		b.Fatalf("%v", err)
+	}
+	defer os.RemoveAll(path)
+
+	a, err := NewAttachPoint(path, Config{})
+	if err != nil {
+		b.Fatalf("NewAttachPoint failed: %v", err)
+	}
+	root, err := a.Attach()
+	if err != nil {
+		b.Fatalf("Attach failed, err: %v", err)
+	}
+	defer root.Close()
+
+	names := []string{name}
+	// The capacity doubles as the batch size for closing walked files.
+	files := make([]p9.File, 0, 1000)
+
+	// Exclude the setup above from the measurement.
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, file, err := root.Walk(names)
+		if err != nil {
+			b.Fatalf("Walk(%q): %v", name, err)
+		}
+		files = append(files, file)
+
+		// Avoid running out of FDs.
+		if len(files) == cap(files) {
+			b.StopTimer()
+			for _, file := range files {
+				file.Close()
+			}
+			files = files[:0]
+			b.StartTimer()
+		}
+	}
+
+	// Close whatever is left from the last, partially filled batch.
+	b.StopTimer()
+	for _, file := range files {
+		file.Close()
+	}
+}
+
+// BenchmarkCreate measures Create of a new, uniquely named file per
+// iteration, requesting the calling process's own uid/gid as owner. Returned
+// files and host FDs are kept open and closed in batches with the timer
+// stopped, so closing is not part of the measurement.
+func BenchmarkCreate(b *testing.B) {
+	path, _, err := setup(unix.S_IFDIR)
+	if err != nil {
+		b.Fatalf("%v", err)
+	}
+	defer os.RemoveAll(path)
+
+	a, err := NewAttachPoint(path, Config{})
+	if err != nil {
+		b.Fatalf("NewAttachPoint failed: %v", err)
+	}
+	root, err := a.Attach()
+	if err != nil {
+		b.Fatalf("Attach failed, err: %v", err)
+	}
+	defer root.Close()
+
+	// The capacity of files doubles as the batch size for closing.
+	files := make([]p9.File, 0, 500)
+	fds := make([]*fd.FD, 0, 500)
+	uid := p9.UID(os.Getuid())
+	gid := p9.GID(os.Getgid())
+
+	// Exclude the setup above from the measurement.
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		name := fmt.Sprintf("same-%d", i)
+		fd, file, _, _, err := root.Create(name, p9.ReadOnly, 0777, uid, gid)
+		if err != nil {
+			b.Fatalf("Create(%q): %v", name, err)
+		}
+		files = append(files, file)
+		// Create may return a nil FD; only track the ones that need closing.
+		if fd != nil {
+			fds = append(fds, fd)
+		}
+
+		// Avoid running out of FDs.
+		if len(files) == cap(files) {
+			b.StopTimer()
+			for _, file := range files {
+				file.Close()
+			}
+			files = files[:0]
+			for _, fd := range fds {
+				fd.Close()
+			}
+			fds = fds[:0]
+			b.StartTimer()
+		}
+	}
+
+	// Close whatever is left from the last, partially filled batch.
+	b.StopTimer()
+	for _, file := range files {
+		file.Close()
+	}
+	for _, fd := range fds {
+		fd.Close()
+	}
+}
+
+// BenchmarkCreateDiffOwner is the counterpart of BenchmarkCreate that
+// requests uid 65534 (the "nobody" constant below) as the owner of each
+// created file instead of the process's own uid. It is skipped when the
+// process lacks CAP_CHOWN.
+func BenchmarkCreateDiffOwner(b *testing.B) {
+	if !specutils.HasCapabilities(capability.CAP_CHOWN) {
+		b.Skipf("Test requires CAP_CHOWN")
+	}
+
+	path, _, err := setup(unix.S_IFDIR)
+	if err != nil {
+		b.Fatalf("%v", err)
+	}
+	defer os.RemoveAll(path)
+
+	a, err := NewAttachPoint(path, Config{})
+	if err != nil {
+		b.Fatalf("NewAttachPoint failed: %v", err)
+	}
+	root, err := a.Attach()
+	if err != nil {
+		b.Fatalf("Attach failed, err: %v", err)
+	}
+	defer root.Close()
+
+	// The capacity of files doubles as the batch size for closing.
+	files := make([]p9.File, 0, 500)
+	fds := make([]*fd.FD, 0, 500)
+	gid := p9.GID(os.Getgid())
+	const nobody = 65534
+
+	// Exclude the setup above from the measurement.
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		name := fmt.Sprintf("diff-%d", i)
+		fd, file, _, _, err := root.Create(name, p9.ReadOnly, 0777, nobody, gid)
+		if err != nil {
+			b.Fatalf("Create(%q): %v", name, err)
+		}
+		files = append(files, file)
+		// Create may return a nil FD; only track the ones that need closing.
+		if fd != nil {
+			fds = append(fds, fd)
+		}
+
+		// Avoid running out of FDs.
+		if len(files) == cap(files) {
+			b.StopTimer()
+			for _, file := range files {
+				file.Close()
+			}
+			files = files[:0]
+			for _, fd := range fds {
+				fd.Close()
+			}
+			fds = fds[:0]
+			b.StartTimer()
+		}
+	}
+
+	// Close whatever is left from the last, partially filled batch.
+	b.StopTimer()
+	for _, file := range files {
+		file.Close()
+	}
+	for _, fd := range fds {
+		fd.Close()
+	}
+}
diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go
index d8112e7a2..9e429f7d5 100644
--- a/runsc/sandbox/network.go
+++ b/runsc/sandbox/network.go
@@ -279,8 +279,6 @@ func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) (
ll := syscall.SockaddrLinklayer{
Protocol: protocol,
Ifindex: iface.Index,
- Hatype: 0, // No ARP type.
- Pkttype: syscall.PACKET_OTHERHOST,
}
if err := syscall.Bind(fd, &ll); err != nil {
return nil, fmt.Errorf("unable to bind to %q: %v", iface.Name, err)
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index c84ebcd8a..cfee9e63d 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -991,7 +991,7 @@ func (s *Sandbox) Stacks() (string, error) {
}
// HeapProfile writes a heap profile to the given file.
-func (s *Sandbox) HeapProfile(f *os.File) error {
+func (s *Sandbox) HeapProfile(f *os.File, delay time.Duration) error {
log.Debugf("Heap profile %q", s.ID)
conn, err := s.sandboxConnect()
if err != nil {
@@ -999,54 +999,31 @@ func (s *Sandbox) HeapProfile(f *os.File) error {
}
defer conn.Close()
- opts := control.ProfileOpts{
- FilePayload: urpc.FilePayload{
- Files: []*os.File{f},
- },
+ opts := control.HeapProfileOpts{
+ FilePayload: urpc.FilePayload{Files: []*os.File{f}},
+ Delay: delay,
}
- if err := conn.Call(boot.HeapProfile, &opts, nil); err != nil {
- return fmt.Errorf("getting sandbox %q heap profile: %v", s.ID, err)
- }
- return nil
+ return conn.Call(boot.HeapProfile, &opts, nil)
}
-// StartCPUProfile start CPU profile writing to the given file.
-func (s *Sandbox) StartCPUProfile(f *os.File) error {
- log.Debugf("CPU profile start %q", s.ID)
+// CPUProfile collects a CPU profile.
+func (s *Sandbox) CPUProfile(f *os.File, duration time.Duration) error {
+ log.Debugf("CPU profile %q", s.ID)
conn, err := s.sandboxConnect()
if err != nil {
return err
}
defer conn.Close()
- opts := control.ProfileOpts{
- FilePayload: urpc.FilePayload{
- Files: []*os.File{f},
- },
- }
- if err := conn.Call(boot.StartCPUProfile, &opts, nil); err != nil {
- return fmt.Errorf("starting sandbox %q CPU profile: %v", s.ID, err)
+ opts := control.CPUProfileOpts{
+ FilePayload: urpc.FilePayload{Files: []*os.File{f}},
+ Duration: duration,
}
- return nil
-}
-
-// StopCPUProfile stops a previously started CPU profile.
-func (s *Sandbox) StopCPUProfile() error {
- log.Debugf("CPU profile stop %q", s.ID)
- conn, err := s.sandboxConnect()
- if err != nil {
- return err
- }
- defer conn.Close()
-
- if err := conn.Call(boot.StopCPUProfile, nil, nil); err != nil {
- return fmt.Errorf("stopping sandbox %q CPU profile: %v", s.ID, err)
- }
- return nil
+ return conn.Call(boot.CPUProfile, &opts, nil)
}
// BlockProfile writes a block profile to the given file.
-func (s *Sandbox) BlockProfile(f *os.File) error {
+func (s *Sandbox) BlockProfile(f *os.File, duration time.Duration) error {
log.Debugf("Block profile %q", s.ID)
conn, err := s.sandboxConnect()
if err != nil {
@@ -1054,19 +1031,15 @@ func (s *Sandbox) BlockProfile(f *os.File) error {
}
defer conn.Close()
- opts := control.ProfileOpts{
- FilePayload: urpc.FilePayload{
- Files: []*os.File{f},
- },
+ opts := control.BlockProfileOpts{
+ FilePayload: urpc.FilePayload{Files: []*os.File{f}},
+ Duration: duration,
}
- if err := conn.Call(boot.BlockProfile, &opts, nil); err != nil {
- return fmt.Errorf("getting sandbox %q block profile: %v", s.ID, err)
- }
- return nil
+ return conn.Call(boot.BlockProfile, &opts, nil)
}
// MutexProfile writes a mutex profile to the given file.
-func (s *Sandbox) MutexProfile(f *os.File) error {
+func (s *Sandbox) MutexProfile(f *os.File, duration time.Duration) error {
log.Debugf("Mutex profile %q", s.ID)
conn, err := s.sandboxConnect()
if err != nil {
@@ -1074,50 +1047,27 @@ func (s *Sandbox) MutexProfile(f *os.File) error {
}
defer conn.Close()
- opts := control.ProfileOpts{
- FilePayload: urpc.FilePayload{
- Files: []*os.File{f},
- },
- }
- if err := conn.Call(boot.MutexProfile, &opts, nil); err != nil {
- return fmt.Errorf("getting sandbox %q mutex profile: %v", s.ID, err)
- }
- return nil
-}
-
-// StartTrace start trace writing to the given file.
-func (s *Sandbox) StartTrace(f *os.File) error {
- log.Debugf("Trace start %q", s.ID)
- conn, err := s.sandboxConnect()
- if err != nil {
- return err
- }
- defer conn.Close()
-
- opts := control.ProfileOpts{
- FilePayload: urpc.FilePayload{
- Files: []*os.File{f},
- },
- }
- if err := conn.Call(boot.StartTrace, &opts, nil); err != nil {
- return fmt.Errorf("starting sandbox %q trace: %v", s.ID, err)
+ opts := control.MutexProfileOpts{
+ FilePayload: urpc.FilePayload{Files: []*os.File{f}},
+ Duration: duration,
}
- return nil
+ return conn.Call(boot.MutexProfile, &opts, nil)
}
-// StopTrace stops a previously started trace.
-func (s *Sandbox) StopTrace() error {
- log.Debugf("Trace stop %q", s.ID)
+// Trace collects an execution trace.
+func (s *Sandbox) Trace(f *os.File, duration time.Duration) error {
+ log.Debugf("Trace %q", s.ID)
conn, err := s.sandboxConnect()
if err != nil {
return err
}
defer conn.Close()
- if err := conn.Call(boot.StopTrace, nil, nil); err != nil {
- return fmt.Errorf("stopping sandbox %q trace: %v", s.ID, err)
+ opts := control.TraceProfileOpts{
+ FilePayload: urpc.FilePayload{Files: []*os.File{f}},
+ Duration: duration,
}
- return nil
+ return conn.Call(boot.Trace, &opts, nil)
}
// ChangeLogging changes logging options.
diff --git a/test/benchmarks/BUILD b/test/benchmarks/BUILD
new file mode 100644
index 000000000..faf310676
--- /dev/null
+++ b/test/benchmarks/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "bzl_library")
+
+package(licenses = ["notice"])
+
+bzl_library(
+ name = "defs_bzl",
+ srcs = ["defs.bzl"],
+ visibility = [
+ "//:sandbox",
+ ],
+)
diff --git a/test/benchmarks/README.md b/test/benchmarks/README.md
index d1bbabf6f..1bfb4a129 100644
--- a/test/benchmarks/README.md
+++ b/test/benchmarks/README.md
@@ -81,11 +81,8 @@ benchmarks.
In general, benchmarks should look like this:
```golang
-
-var h harness.Harness
-
func BenchmarkMyCoolOne(b *testing.B) {
- machine, err := h.GetMachine()
+ machine, err := harness.GetMachine()
// check err
defer machine.CleanUp()
@@ -95,14 +92,14 @@ func BenchmarkMyCoolOne(b *testing.B) {
b.ResetTimer()
- //Respect b.N.
+ // Respect b.N.
for i := 0; i < b.N; i++ {
out, err := container.Run(ctx, dockerutil.RunOpts{
Image: "benchmarks/my-cool-image",
Env: []string{"MY_VAR=awesome"},
other options...see dockerutil
}, "sh", "-c", "echo MY_VAR")
- //check err
+ // check err...
b.StopTimer()
// Do parsing and reporting outside of the timer.
@@ -114,16 +111,13 @@ func BenchmarkMyCoolOne(b *testing.B) {
}
func TestMain(m *testing.M) {
- h.Init()
+ harness.Init()
os.Exit(m.Run())
}
```
Some notes on the above:
-* The harness is initiated in the TestMain method and made global to test
- module. The harness will handle any presetup that needs to happen with
- flags, remote virtual machines (eventually), and other services.
* Respect `b.N` in that users of the benchmark may want to "run for an hour"
or something of the sort.
* Use the `b.ReportMetric()` method to report custom metrics.
diff --git a/test/benchmarks/base/BUILD b/test/benchmarks/base/BUILD
index b4b55317b..697ab5837 100644
--- a/test/benchmarks/base/BUILD
+++ b/test/benchmarks/base/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library")
+load("//test/benchmarks:defs.bzl", "benchmark_test")
package(licenses = ["notice"])
@@ -14,7 +15,7 @@ go_library(
],
)
-go_test(
+benchmark_test(
name = "startup_test",
size = "enormous",
srcs = ["startup_test.go"],
@@ -26,7 +27,7 @@ go_test(
],
)
-go_test(
+benchmark_test(
name = "size_test",
size = "enormous",
srcs = ["size_test.go"],
@@ -39,7 +40,7 @@ go_test(
],
)
-go_test(
+benchmark_test(
name = "sysbench_test",
size = "enormous",
srcs = ["sysbench_test.go"],
diff --git a/test/benchmarks/base/size_test.go b/test/benchmarks/base/size_test.go
index acc49cc7c..452926e5f 100644
--- a/test/benchmarks/base/size_test.go
+++ b/test/benchmarks/base/size_test.go
@@ -26,12 +26,10 @@ import (
"gvisor.dev/gvisor/test/benchmarks/tools"
)
-var testHarness harness.Harness
-
// BenchmarkSizeEmpty creates N empty containers and reads memory usage from
// /proc/meminfo.
func BenchmarkSizeEmpty(b *testing.B) {
- machine, err := testHarness.GetMachine()
+ machine, err := harness.GetMachine()
if err != nil {
b.Fatalf("failed to get machine: %v", err)
}
@@ -81,7 +79,7 @@ func BenchmarkSizeEmpty(b *testing.B) {
// BenchmarkSizeNginx starts N containers running Nginx, checks that they're
// serving, and checks memory used based on /proc/meminfo.
func BenchmarkSizeNginx(b *testing.B) {
- machine, err := testHarness.GetMachine()
+ machine, err := harness.GetMachine()
if err != nil {
b.Fatalf("failed to get machine with: %v", err)
}
@@ -126,7 +124,7 @@ func BenchmarkSizeNginx(b *testing.B) {
// BenchmarkSizeNode starts N containers running a Node app, checks that
// they're serving, and checks memory used based on /proc/meminfo.
func BenchmarkSizeNode(b *testing.B) {
- machine, err := testHarness.GetMachine()
+ machine, err := harness.GetMachine()
if err != nil {
b.Fatalf("failed to get machine with: %v", err)
}
@@ -178,6 +176,6 @@ func BenchmarkSizeNode(b *testing.B) {
// TestMain is the main method for package network.
func TestMain(m *testing.M) {
- testHarness.Init()
+ harness.Init()
os.Exit(m.Run())
}
diff --git a/test/benchmarks/base/startup_test.go b/test/benchmarks/base/startup_test.go
index 8ef9f99c4..05a43ad17 100644
--- a/test/benchmarks/base/startup_test.go
+++ b/test/benchmarks/base/startup_test.go
@@ -25,11 +25,9 @@ import (
"gvisor.dev/gvisor/test/benchmarks/harness"
)
-var testHarness harness.Harness
-
// BenchmarkStartEmpty times startup time for an empty container.
func BenchmarkStartupEmpty(b *testing.B) {
- machine, err := testHarness.GetMachine()
+ machine, err := harness.GetMachine()
if err != nil {
b.Fatalf("failed to get machine: %v", err)
}
@@ -53,7 +51,7 @@ func BenchmarkStartupEmpty(b *testing.B) {
// Time is measured from start until the first request is served.
func BenchmarkStartupNginx(b *testing.B) {
// The machine to hold Nginx and the Node Server.
- machine, err := testHarness.GetMachine()
+ machine, err := harness.GetMachine()
if err != nil {
b.Fatalf("failed to get machine with: %v", err)
}
@@ -76,7 +74,7 @@ func BenchmarkStartupNginx(b *testing.B) {
// Time is measured from start until the first request is served.
// Note that the Node app connects to a Redis instance before serving.
func BenchmarkStartupNode(b *testing.B) {
- machine, err := testHarness.GetMachine()
+ machine, err := harness.GetMachine()
if err != nil {
b.Fatalf("failed to get machine with: %v", err)
}
@@ -126,8 +124,8 @@ func runServerWorkload(ctx context.Context, b *testing.B, args base.ServerArgs)
return fmt.Errorf("failed to get ip from server: %v", err)
}
- harness.DebugLog(b, "Waiting for container to start.")
// Wait until the Client sees the server as up.
+ harness.DebugLog(b, "Waiting for container to start.")
if err := harness.WaitUntilServing(ctx, args.Machine, servingIP, args.Port); err != nil {
return fmt.Errorf("failed to wait for serving: %v", err)
}
@@ -141,6 +139,6 @@ func runServerWorkload(ctx context.Context, b *testing.B, args base.ServerArgs)
// TestMain is the main method for package network.
func TestMain(m *testing.M) {
- testHarness.Init()
+ harness.Init()
os.Exit(m.Run())
}
diff --git a/test/benchmarks/base/sysbench_test.go b/test/benchmarks/base/sysbench_test.go
index bbb797e14..80569687c 100644
--- a/test/benchmarks/base/sysbench_test.go
+++ b/test/benchmarks/base/sysbench_test.go
@@ -23,8 +23,6 @@ import (
"gvisor.dev/gvisor/test/benchmarks/tools"
)
-var testHarness harness.Harness
-
type testCase struct {
name string
test tools.Sysbench
@@ -32,42 +30,34 @@ type testCase struct {
// BenchmarSysbench runs sysbench on the runtime.
func BenchmarkSysbench(b *testing.B) {
-
testCases := []testCase{
testCase{
name: "CPU",
test: &tools.SysbenchCPU{
- Base: tools.SysbenchBase{
+ SysbenchBase: tools.SysbenchBase{
Threads: 1,
- Time: 5,
},
- MaxPrime: 50000,
},
},
testCase{
name: "Memory",
test: &tools.SysbenchMemory{
- Base: tools.SysbenchBase{
+ SysbenchBase: tools.SysbenchBase{
Threads: 1,
},
- BlockSize: "1M",
- TotalSize: "500G",
},
},
testCase{
name: "Mutex",
test: &tools.SysbenchMutex{
- Base: tools.SysbenchBase{
+ SysbenchBase: tools.SysbenchBase{
Threads: 8,
},
- Loops: 1,
- Locks: 10000000,
- Num: 4,
},
},
}
- machine, err := testHarness.GetMachine()
+ machine, err := harness.GetMachine()
if err != nil {
b.Fatalf("failed to get machine: %v", err)
}
@@ -87,12 +77,15 @@ func BenchmarkSysbench(b *testing.B) {
sysbench := machine.GetContainer(ctx, b)
defer sysbench.CleanUp(ctx)
+ cmd := tc.test.MakeCmd(b)
+ b.ResetTimer()
out, err := sysbench.Run(ctx, dockerutil.RunOpts{
Image: "benchmarks/sysbench",
- }, tc.test.MakeCmd()...)
+ }, cmd...)
if err != nil {
b.Fatalf("failed to run sysbench: %v: logs:%s", err, out)
}
+ b.StopTimer()
tc.test.Report(b, out)
})
}
diff --git a/test/benchmarks/database/BUILD b/test/benchmarks/database/BUILD
index 93b380e8a..0b1743603 100644
--- a/test/benchmarks/database/BUILD
+++ b/test/benchmarks/database/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library")
+load("//test/benchmarks:defs.bzl", "benchmark_test")
package(licenses = ["notice"])
@@ -6,19 +7,13 @@ go_library(
name = "database",
testonly = 1,
srcs = ["database.go"],
- deps = ["//test/benchmarks/harness"],
)
-go_test(
- name = "database_test",
+benchmark_test(
+ name = "redis_test",
size = "enormous",
srcs = ["redis_test.go"],
library = ":database",
- tags = [
- # Requires docker and runsc to be configured before test runs.
- "manual",
- "local",
- ],
visibility = ["//:sandbox"],
deps = [
"//pkg/test/dockerutil",
diff --git a/test/benchmarks/database/database.go b/test/benchmarks/database/database.go
index 9eeb59f9a..c15ca661c 100644
--- a/test/benchmarks/database/database.go
+++ b/test/benchmarks/database/database.go
@@ -14,18 +14,3 @@
// Package database holds benchmarks around database applications.
package database
-
-import (
- "os"
- "testing"
-
- "gvisor.dev/gvisor/test/benchmarks/harness"
-)
-
-var h harness.Harness
-
-// TestMain is the main method for package database.
-func TestMain(m *testing.M) {
- h.Init()
- os.Exit(m.Run())
-}
diff --git a/test/benchmarks/database/redis_test.go b/test/benchmarks/database/redis_test.go
index 02e67154e..f3c4522ac 100644
--- a/test/benchmarks/database/redis_test.go
+++ b/test/benchmarks/database/redis_test.go
@@ -16,6 +16,7 @@ package database
import (
"context"
+ "os"
"testing"
"time"
@@ -49,13 +50,13 @@ var operations []string = []string{
// BenchmarkRedis runs redis-benchmark against a redis instance and reports
// data in queries per second. Each is reported by named operation (e.g. LPUSH).
func BenchmarkRedis(b *testing.B) {
- clientMachine, err := h.GetMachine()
+ clientMachine, err := harness.GetMachine()
if err != nil {
b.Fatalf("failed to get machine: %v", err)
}
defer clientMachine.CleanUp()
- serverMachine, err := h.GetMachine()
+ serverMachine, err := harness.GetMachine()
if err != nil {
b.Fatalf("failed to get machine: %v", err)
}
@@ -64,7 +65,6 @@ func BenchmarkRedis(b *testing.B) {
// Redis runs on port 6379 by default.
port := 6379
ctx := context.Background()
-
for _, operation := range operations {
param := tools.Parameter{
Name: "operation",
@@ -104,28 +104,26 @@ func BenchmarkRedis(b *testing.B) {
b.Fatalf("failed to start redis with: %v", err)
}
+ client := clientMachine.GetNativeContainer(ctx, b)
+ defer client.CleanUp(ctx)
+
redis := tools.Redis{
Operation: operation,
}
-
- // Reset profiles and timer to begin the measurement.
- server.RestartProfiles()
b.ResetTimer()
- for i := 0; i < b.N; i++ {
- client := clientMachine.GetNativeContainer(ctx, b)
- defer client.CleanUp(ctx)
- out, err := client.Run(ctx, dockerutil.RunOpts{
- Image: "benchmarks/redis",
- }, redis.MakeCmd(ip, serverPort)...)
- if err != nil {
- b.Fatalf("redis-benchmark failed with: %v", err)
- }
-
- // Stop time while we parse results.
- b.StopTimer()
- redis.Report(b, out)
- b.StartTimer()
+ out, err := client.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/redis",
+ }, redis.MakeCmd(ip, serverPort, b.N /*requests*/)...)
+ if err != nil {
+ b.Fatalf("redis-benchmark failed with: %v", err)
}
+ b.StopTimer()
+ redis.Report(b, out)
})
}
}
+
+// TestMain is the main method for package database. It initializes the
+// benchmark harness before running the package's tests and exits with the
+// status returned by m.Run.
+func TestMain(m *testing.M) {
+	harness.Init()
+	os.Exit(m.Run())
+}
diff --git a/test/benchmarks/defs.bzl b/test/benchmarks/defs.bzl
new file mode 100644
index 000000000..ef44b46e3
--- /dev/null
+++ b/test/benchmarks/defs.bzl
@@ -0,0 +1,14 @@
+"""Defines a rule for benchmark test targets."""
+
+load("//tools:defs.bzl", "go_test")
+
+def benchmark_test(name, tags = [], **kwargs):
+    """Defines a go_test target for a benchmark.
+
+    The target is always tagged "local" and "manual", in addition to any tags
+    supplied by the caller.
+
+    Args:
+      name: name of the test target.
+      tags: extra tags for the target; merged with the required tags.
+      **kwargs: remaining arguments, forwarded to go_test unchanged.
+    """
+
+    # Requires docker and runsc to be configured before the test runs.
+    required = ["local", "manual"]
+    go_test(
+        name,
+        # Keep caller-supplied tags and append only the required ones that
+        # are missing, so no tag is dropped or duplicated.
+        tags = tags + [t for t in required if t not in tags],
+        **kwargs
+    )
diff --git a/test/benchmarks/fs/BUILD b/test/benchmarks/fs/BUILD
index 021fae38d..b4f967441 100644
--- a/test/benchmarks/fs/BUILD
+++ b/test/benchmarks/fs/BUILD
@@ -1,8 +1,8 @@
-load("//tools:defs.bzl", "go_test")
+load("//test/benchmarks:defs.bzl", "benchmark_test")
package(licenses = ["notice"])
-go_test(
+benchmark_test(
name = "bazel_test",
size = "enormous",
srcs = ["bazel_test.go"],
@@ -14,7 +14,7 @@ go_test(
],
)
-go_test(
+benchmark_test(
name = "fio_test",
size = "enormous",
srcs = ["fio_test.go"],
diff --git a/test/benchmarks/fs/bazel_test.go b/test/benchmarks/fs/bazel_test.go
index 53ed3f9f2..8baeff0db 100644
--- a/test/benchmarks/fs/bazel_test.go
+++ b/test/benchmarks/fs/bazel_test.go
@@ -25,8 +25,6 @@ import (
"gvisor.dev/gvisor/test/benchmarks/tools"
)
-var h harness.Harness
-
// Note: CleanCache versions of this test require running with root permissions.
func BenchmarkBuildABSL(b *testing.B) {
runBuildBenchmark(b, "benchmarks/absl", "/abseil-cpp", "absl/base/...")
@@ -41,7 +39,7 @@ func BenchmarkBuildRunsc(b *testing.B) {
func runBuildBenchmark(b *testing.B, image, workdir, target string) {
b.Helper()
// Get a machine from the Harness on which to run.
- machine, err := h.GetMachine()
+ machine, err := harness.GetMachine()
if err != nil {
b.Fatalf("failed to get machine: %v", err)
}
@@ -61,10 +59,10 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) {
for _, bm := range benchmarks {
pageCache := tools.Parameter{
Name: "page_cache",
- Value: "clean",
+ Value: "dirty",
}
if bm.clearCache {
- pageCache.Value = "dirty"
+ pageCache.Value = "clean"
}
filesystem := tools.Parameter{
@@ -102,21 +100,20 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) {
prefix = "/tmp"
}
- // Restart profiles after the copy.
- container.RestartProfiles()
b.ResetTimer()
+ b.StopTimer()
+
// Drop Caches and bazel clean should happen inside the loop as we may use
// time options with b.N. (e.g. Run for an hour.)
for i := 0; i < b.N; i++ {
- b.StopTimer()
// Drop Caches for clear cache runs.
if bm.clearCache {
if err := harness.DropCaches(machine); err != nil {
b.Skipf("failed to drop caches: %v. You probably need root.", err)
}
}
- b.StartTimer()
+ b.StartTimer()
got, err := container.Exec(ctx, dockerutil.ExecOpts{
WorkDir: prefix + workdir,
}, "bazel", "build", "-c", "opt", target)
@@ -129,14 +126,15 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) {
if !strings.Contains(got, want) {
b.Fatalf("string %s not in: %s", want, got)
}
- // Clean bazel in case we use b.N.
- _, err = container.Exec(ctx, dockerutil.ExecOpts{
- WorkDir: prefix + workdir,
- }, "bazel", "clean")
- if err != nil {
- b.Fatalf("build failed with: %v", err)
+
+ // Clean bazel in the case we are doing another run.
+ if i < b.N-1 {
+ if _, err = container.Exec(ctx, dockerutil.ExecOpts{
+ WorkDir: prefix + workdir,
+ }, "bazel", "clean"); err != nil {
+ b.Fatalf("build failed with: %v", err)
+ }
}
- b.StartTimer()
}
})
}
@@ -144,6 +142,7 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) {
// TestMain is the main method for package fs.
func TestMain(m *testing.M) {
- h.Init()
+ harness.Init()
+ harness.SetFixedBenchmarks()
os.Exit(m.Run())
}
diff --git a/test/benchmarks/fs/fio_test.go b/test/benchmarks/fs/fio_test.go
index 96340373c..83b8376a5 100644
--- a/test/benchmarks/fs/fio_test.go
+++ b/test/benchmarks/fs/fio_test.go
@@ -27,8 +27,6 @@ import (
"gvisor.dev/gvisor/test/benchmarks/tools"
)
-var h harness.Harness
-
// BenchmarkFio runs fio on the runtime under test. There are 4 basic test
// cases each run on a tmpfs mount and a bind mount. Fio requires root so that
// caches can be dropped.
@@ -36,33 +34,43 @@ func BenchmarkFio(b *testing.B) {
testCases := []tools.Fio{
tools.Fio{
Test: "write",
- Size: "5G",
- Blocksize: "1M",
- Iodepth: 4,
+ Size: b.N,
+ BlockSize: 4,
+ IODepth: 4,
+ },
+ tools.Fio{
+ Test: "write",
+ Size: b.N,
+ BlockSize: 1024,
+ IODepth: 4,
+ },
+ tools.Fio{
+ Test: "read",
+ Size: b.N,
+ BlockSize: 4,
+ IODepth: 4,
},
tools.Fio{
Test: "read",
- Size: "5G",
- Blocksize: "1M",
- Iodepth: 4,
+ Size: b.N,
+ BlockSize: 1024,
+ IODepth: 4,
},
tools.Fio{
Test: "randwrite",
- Size: "5G",
- Blocksize: "4K",
- Iodepth: 4,
- Time: 30,
+ Size: b.N,
+ BlockSize: 4,
+ IODepth: 4,
},
tools.Fio{
Test: "randread",
- Size: "5G",
- Blocksize: "4K",
- Iodepth: 4,
- Time: 30,
+ Size: b.N,
+ BlockSize: 4,
+ IODepth: 4,
},
}
- machine, err := h.GetMachine()
+ machine, err := harness.GetMachine()
if err != nil {
b.Fatalf("failed to get machine with: %v", err)
}
@@ -74,11 +82,15 @@ func BenchmarkFio(b *testing.B) {
Name: "operation",
Value: tc.Test,
}
+ blockSize := tools.Parameter{
+ Name: "blockSize",
+ Value: fmt.Sprintf("%dK", tc.BlockSize),
+ }
filesystem := tools.Parameter{
Name: "filesystem",
Value: string(fsType),
}
- name, err := tools.ParametersToName(operation, filesystem)
+ name, err := tools.ParametersToName(operation, blockSize, filesystem)
if err != nil {
b.Fatalf("Failed to parser paramters: %v", err)
}
@@ -116,7 +128,7 @@ func BenchmarkFio(b *testing.B) {
// For reads, we need a file to read so make one inside the container.
if strings.Contains(tc.Test, "read") {
- fallocateCmd := fmt.Sprintf("fallocate -l %s %s", tc.Size, outfile)
+ fallocateCmd := fmt.Sprintf("fallocate -l %dK %s", tc.Size, outfile)
if out, err := container.Exec(ctx, dockerutil.ExecOpts{},
strings.Split(fallocateCmd, " ")...); err != nil {
b.Fatalf("failed to create readable file on mount: %v, %s", err, out)
@@ -128,22 +140,24 @@ func BenchmarkFio(b *testing.B) {
b.Skipf("failed to drop caches with %v. You probably need root.", err)
}
cmd := tc.MakeCmd(outfile)
- container.RestartProfiles()
+
b.ResetTimer()
+ b.StopTimer()
+
for i := 0; i < b.N; i++ {
+ if err := harness.DropCaches(machine); err != nil {
+ b.Fatalf("failed to drop caches: %v", err)
+ }
+
// Run fio.
+ b.StartTimer()
data, err := container.Exec(ctx, dockerutil.ExecOpts{}, cmd...)
if err != nil {
b.Fatalf("failed to run cmd %v: %v", cmd, err)
}
b.StopTimer()
+ b.SetBytes(1024 * 1024) // Bytes for go reporting (Size is in megabytes).
tc.Report(b, data)
- // If b.N is used (i.e. we run for an hour), we should drop caches
- // after each run.
- if err := harness.DropCaches(machine); err != nil {
- b.Fatalf("failed to drop caches: %v", err)
- }
- b.StartTimer()
}
})
}
@@ -185,6 +199,6 @@ func makeMount(machine harness.Machine, mountType mount.Type, target string) (mo
// TestMain is the main method for package fs.
func TestMain(m *testing.M) {
- h.Init()
+ harness.Init()
os.Exit(m.Run())
}
diff --git a/test/benchmarks/harness/harness.go b/test/benchmarks/harness/harness.go
index 5c9d0e01e..a853b7ba8 100644
--- a/test/benchmarks/harness/harness.go
+++ b/test/benchmarks/harness/harness.go
@@ -28,18 +28,14 @@ var (
debug = flag.Bool("debug", false, "turns on debug messages for individual benchmarks")
)
-// Harness is a handle for managing state in benchmark runs.
-type Harness struct {
-}
-
// Init performs any harness initialization before runs.
-func (h *Harness) Init() error {
+func Init() error {
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s -- --test.bench=<regex>\n", os.Args[0])
flag.PrintDefaults()
}
flag.Parse()
- if flag.NFlag() == 0 || *help {
+ if *help {
flag.Usage()
os.Exit(0)
}
@@ -47,7 +43,15 @@ func (h *Harness) Init() error {
return nil
}
+// SetFixedBenchmarks causes all benchmarks to run once.
+//
+// This must be set if they cannot scale with N. Note that this uses 1ns
+// instead of 1x due to https://github.com/golang/go/issues/32051.
+func SetFixedBenchmarks() {
+ flag.Set("test.benchtime", "1ns")
+}
+
// GetMachine returns this run's implementation of machine.
-func (h *Harness) GetMachine() (Machine, error) {
+func GetMachine() (Machine, error) {
return &localMachine{}, nil
}
diff --git a/test/benchmarks/harness/machine.go b/test/benchmarks/harness/machine.go
index 88e5e841b..405b646e8 100644
--- a/test/benchmarks/harness/machine.go
+++ b/test/benchmarks/harness/machine.go
@@ -16,6 +16,7 @@ package harness
import (
"context"
+ "errors"
"net"
"os/exec"
@@ -66,14 +67,19 @@ func (l *localMachine) RunCommand(cmd string, args ...string) (string, error) {
// IPAddress implements Machine.IPAddress.
func (l *localMachine) IPAddress() (net.IP, error) {
- conn, err := net.Dial("udp", "8.8.8.8:80")
+ addrs, err := net.InterfaceAddrs()
if err != nil {
- return nil, err
+ return net.IP{}, err
}
- defer conn.Close()
-
- addr := conn.LocalAddr().(*net.UDPAddr)
- return addr.IP, nil
+ for _, a := range addrs {
+ if ipnet, ok := a.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
+ if ipnet.IP.To4() != nil {
+ return ipnet.IP, nil
+ }
+ }
+ }
+ // Unable to locate non-loopback address.
+ return nil, errors.New("no IPAddress available")
}
// CleanUp implements Machine.CleanUp and does nothing for localMachine.
diff --git a/test/benchmarks/media/BUILD b/test/benchmarks/media/BUILD
index bb242d385..380783f0b 100644
--- a/test/benchmarks/media/BUILD
+++ b/test/benchmarks/media/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library")
+load("//test/benchmarks:defs.bzl", "benchmark_test")
package(licenses = ["notice"])
@@ -6,12 +7,11 @@ go_library(
name = "media",
testonly = 1,
srcs = ["media.go"],
- deps = ["//test/benchmarks/harness"],
)
-go_test(
- name = "media_test",
- size = "large",
+benchmark_test(
+ name = "ffmpeg_test",
+ size = "enormous",
srcs = ["ffmpeg_test.go"],
library = ":media",
visibility = ["//:sandbox"],
diff --git a/test/benchmarks/media/ffmpeg_test.go b/test/benchmarks/media/ffmpeg_test.go
index 7822dfad7..1b99a319a 100644
--- a/test/benchmarks/media/ffmpeg_test.go
+++ b/test/benchmarks/media/ffmpeg_test.go
@@ -15,6 +15,7 @@ package media
import (
"context"
+ "os"
"strings"
"testing"
@@ -25,29 +26,36 @@ import (
// BenchmarkFfmpeg runs ffmpeg in a container and records runtime.
// BenchmarkFfmpeg should run as root to drop caches.
func BenchmarkFfmpeg(b *testing.B) {
- machine, err := h.GetMachine()
+ machine, err := harness.GetMachine()
if err != nil {
b.Fatalf("failed to get machine: %v", err)
}
defer machine.CleanUp()
ctx := context.Background()
- container := machine.GetContainer(ctx, b)
- defer container.CleanUp(ctx)
cmd := strings.Split("ffmpeg -i video.mp4 -c:v libx264 -preset veryslow output.mp4", " ")
b.ResetTimer()
+ b.StopTimer()
+
for i := 0; i < b.N; i++ {
- b.StopTimer()
+ container := machine.GetContainer(ctx, b)
+ defer container.CleanUp(ctx)
if err := harness.DropCaches(machine); err != nil {
b.Skipf("failed to drop caches: %v. You probably need root.", err)
}
- b.StartTimer()
+ b.StartTimer()
if _, err := container.Run(ctx, dockerutil.RunOpts{
Image: "benchmarks/ffmpeg",
}, cmd...); err != nil {
b.Fatalf("failed to run container: %v", err)
}
+ b.StopTimer()
}
}
+
+func TestMain(m *testing.M) {
+ harness.Init()
+ os.Exit(m.Run())
+}
diff --git a/test/benchmarks/media/media.go b/test/benchmarks/media/media.go
index c7b35b758..ed7b24651 100644
--- a/test/benchmarks/media/media.go
+++ b/test/benchmarks/media/media.go
@@ -14,18 +14,3 @@
// Package media holds benchmarks around media processing applications.
package media
-
-import (
- "os"
- "testing"
-
- "gvisor.dev/gvisor/test/benchmarks/harness"
-)
-
-var h harness.Harness
-
-// TestMain is the main method for package media.
-func TestMain(m *testing.M) {
- h.Init()
- os.Exit(m.Run())
-}
diff --git a/test/benchmarks/ml/BUILD b/test/benchmarks/ml/BUILD
index 970f52706..285ec35d9 100644
--- a/test/benchmarks/ml/BUILD
+++ b/test/benchmarks/ml/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library")
+load("//test/benchmarks:defs.bzl", "benchmark_test")
package(licenses = ["notice"])
@@ -6,12 +7,11 @@ go_library(
name = "ml",
testonly = 1,
srcs = ["ml.go"],
- deps = ["//test/benchmarks/harness"],
)
-go_test(
- name = "ml_test",
- size = "large",
+benchmark_test(
+ name = "tensorflow_test",
+ size = "enormous",
srcs = ["tensorflow_test.go"],
library = ":ml",
visibility = ["//:sandbox"],
diff --git a/test/benchmarks/ml/ml.go b/test/benchmarks/ml/ml.go
index 13282d7bb..d5fc5b7da 100644
--- a/test/benchmarks/ml/ml.go
+++ b/test/benchmarks/ml/ml.go
@@ -14,18 +14,3 @@
// Package ml holds benchmarks around machine learning performance.
package ml
-
-import (
- "os"
- "testing"
-
- "gvisor.dev/gvisor/test/benchmarks/harness"
-)
-
-var h harness.Harness
-
-// TestMain is the main method for package ml.
-func TestMain(m *testing.M) {
- h.Init()
- os.Exit(m.Run())
-}
diff --git a/test/benchmarks/ml/tensorflow_test.go b/test/benchmarks/ml/tensorflow_test.go
index f7746897d..b0e0c4720 100644
--- a/test/benchmarks/ml/tensorflow_test.go
+++ b/test/benchmarks/ml/tensorflow_test.go
@@ -15,6 +15,7 @@ package ml
import (
"context"
+ "os"
"testing"
"gvisor.dev/gvisor/pkg/test/dockerutil"
@@ -35,7 +36,7 @@ func BenchmarkTensorflow(b *testing.B) {
"NeuralNetwork": "3_NeuralNetworks/neural_network.py",
}
- machine, err := h.GetMachine()
+ machine, err := harness.GetMachine()
if err != nil {
b.Fatalf("failed to get machine: %v", err)
}
@@ -44,17 +45,19 @@ func BenchmarkTensorflow(b *testing.B) {
for name, workload := range workloads {
b.Run(name, func(b *testing.B) {
ctx := context.Background()
- container := machine.GetContainer(ctx, b)
- defer container.CleanUp(ctx)
b.ResetTimer()
+ b.StopTimer()
+
for i := 0; i < b.N; i++ {
- b.StopTimer()
+ container := machine.GetContainer(ctx, b)
+ defer container.CleanUp(ctx)
if err := harness.DropCaches(machine); err != nil {
b.Skipf("failed to drop caches: %v. You probably need root.", err)
}
- b.StartTimer()
+ // Run tensorflow.
+ b.StartTimer()
if out, err := container.Run(ctx, dockerutil.RunOpts{
Image: "benchmarks/tensorflow",
Env: []string{"PYTHONPATH=$PYTHONPATH:/TensorFlow-Examples/examples"},
@@ -62,8 +65,14 @@ func BenchmarkTensorflow(b *testing.B) {
}, "python", workload); err != nil {
b.Fatalf("failed to run container: %v logs: %s", err, out)
}
+ b.StopTimer()
}
})
}
+}
+func TestMain(m *testing.M) {
+ harness.Init()
+ harness.SetFixedBenchmarks()
+ os.Exit(m.Run())
}
diff --git a/test/benchmarks/network/BUILD b/test/benchmarks/network/BUILD
index 472b5c387..2741570f5 100644
--- a/test/benchmarks/network/BUILD
+++ b/test/benchmarks/network/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library")
+load("//test/benchmarks:defs.bzl", "benchmark_test")
package(licenses = ["notice"])
@@ -7,7 +8,6 @@ go_library(
testonly = 1,
srcs = [
"network.go",
- "static_server.go",
],
deps = [
"//pkg/test/dockerutil",
@@ -16,22 +16,74 @@ go_library(
],
)
-go_test(
- name = "network_test",
- size = "large",
+benchmark_test(
+ name = "iperf_test",
+ size = "enormous",
srcs = [
- "httpd_test.go",
"iperf_test.go",
- "nginx_test.go",
+ ],
+ library = ":network",
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//pkg/test/testutil",
+ "//test/benchmarks/harness",
+ "//test/benchmarks/tools",
+ ],
+)
+
+benchmark_test(
+ name = "node_test",
+ size = "enormous",
+ srcs = [
"node_test.go",
+ ],
+ library = ":network",
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//test/benchmarks/harness",
+ "//test/benchmarks/tools",
+ ],
+)
+
+benchmark_test(
+ name = "ruby_test",
+ size = "enormous",
+ srcs = [
"ruby_test.go",
],
library = ":network",
- tags = [
- # Requires docker and runsc to be configured before test runs.
- "manual",
- "local",
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//test/benchmarks/harness",
+ "//test/benchmarks/tools",
+ ],
+)
+
+benchmark_test(
+ name = "nginx_test",
+ size = "enormous",
+ srcs = [
+ "nginx_test.go",
+ ],
+ library = ":network",
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//test/benchmarks/harness",
+ "//test/benchmarks/tools",
],
+)
+
+benchmark_test(
+ name = "httpd_test",
+ size = "enormous",
+ srcs = [
+ "httpd_test.go",
+ ],
+ library = ":network",
visibility = ["//:sandbox"],
deps = [
"//pkg/test/dockerutil",
diff --git a/test/benchmarks/network/httpd_test.go b/test/benchmarks/network/httpd_test.go
index 8d7d5f750..629127250 100644
--- a/test/benchmarks/network/httpd_test.go
+++ b/test/benchmarks/network/httpd_test.go
@@ -14,10 +14,12 @@
package network
import (
+ "os"
"strconv"
"testing"
"gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/benchmarks/harness"
"gvisor.dev/gvisor/test/benchmarks/tools"
)
@@ -34,18 +36,20 @@ var httpdDocs = map[string]string{
// BenchmarkHttpd iterates over different sized payloads and concurrency, testing
// how well the runtime handles sending different payload sizes.
func BenchmarkHttpd(b *testing.B) {
- benchmarkHttpdDocSize(b, false /* reverse */)
+ benchmarkHttpdDocSize(b)
}
-// BenchmarkReverseHttpd iterates over different sized payloads, testing
-// how well the runtime handles receiving different payload sizes.
-func BenchmarkReverseHttpd(b *testing.B) {
- benchmarkHttpdDocSize(b, true /* reverse */)
+// BenchmarkContinuousHttpd runs specific benchmarks for continuous jobs.
+// The runtime under test is the server serving a runc client.
+func BenchmarkContinuousHttpd(b *testing.B) {
+ sizes := []string{"10Kb", "100Kb", "1Mb"}
+ threads := []int{1, 25, 100, 1000}
+ benchmarkHttpdContinuous(b, threads, sizes)
}
// benchmarkHttpdDocSize iterates through all doc sizes, running subbenchmarks
// for each size.
-func benchmarkHttpdDocSize(b *testing.B, reverse bool) {
+func benchmarkHttpdDocSize(b *testing.B) {
b.Helper()
for size, filename := range httpdDocs {
concurrency := []int{1, 25, 50, 100, 1000}
@@ -64,18 +68,49 @@ func benchmarkHttpdDocSize(b *testing.B, reverse bool) {
}
b.Run(name, func(b *testing.B) {
hey := &tools.Hey{
- Requests: c * b.N,
+ Requests: b.N,
Concurrency: c,
Doc: filename,
}
- runHttpd(b, hey, reverse)
+ runHttpd(b, hey)
+ })
+ }
+ }
+}
+
+// benchmarkHttpdContinuous iterates through given sizes and concurrencies.
+func benchmarkHttpdContinuous(b *testing.B, concurrency []int, sizes []string) {
+ for _, size := range sizes {
+ filename := httpdDocs[size]
+ for _, c := range concurrency {
+ fsize := tools.Parameter{
+ Name: "filesize",
+ Value: size,
+ }
+
+ threads := tools.Parameter{
+ Name: "concurrency",
+ Value: strconv.Itoa(c),
+ }
+
+ name, err := tools.ParametersToName(fsize, threads)
+ if err != nil {
+ b.Fatalf("Failed to parse parameters: %v", err)
+ }
+ b.Run(name, func(b *testing.B) {
+ hey := &tools.Hey{
+ Requests: b.N,
+ Concurrency: c,
+ Doc: filename,
+ }
+ runHttpd(b, hey)
})
}
}
}
// runHttpd configures the static serving methods to run httpd.
-func runHttpd(b *testing.B, hey *tools.Hey, reverse bool) {
+func runHttpd(b *testing.B, hey *tools.Hey) {
// httpd runs on port 80.
port := 80
httpdRunOpts := dockerutil.RunOpts{
@@ -91,5 +126,10 @@ func runHttpd(b *testing.B, hey *tools.Hey, reverse bool) {
},
}
httpdCmd := []string{"sh", "-c", "mkdir -p /tmp/html; cp -r /local/* /tmp/html/.; apache2 -X"}
- runStaticServer(b, httpdRunOpts, httpdCmd, port, hey, reverse)
+ runStaticServer(b, httpdRunOpts, httpdCmd, port, hey)
+}
+
+func TestMain(m *testing.M) {
+ harness.Init()
+ os.Exit(m.Run())
}
diff --git a/test/benchmarks/network/iperf_test.go b/test/benchmarks/network/iperf_test.go
index b8ab7dfb8..5e81149fe 100644
--- a/test/benchmarks/network/iperf_test.go
+++ b/test/benchmarks/network/iperf_test.go
@@ -15,6 +15,7 @@ package network
import (
"context"
+ "os"
"testing"
"gvisor.dev/gvisor/pkg/test/dockerutil"
@@ -25,16 +26,16 @@ import (
func BenchmarkIperf(b *testing.B) {
iperf := tools.Iperf{
- Time: 10, // time in seconds to run client.
+ Num: b.N,
}
- clientMachine, err := h.GetMachine()
+ clientMachine, err := harness.GetMachine()
if err != nil {
b.Fatalf("failed to get machine: %v", err)
}
defer clientMachine.CleanUp()
- serverMachine, err := h.GetMachine()
+ serverMachine, err := harness.GetMachine()
if err != nil {
b.Fatalf("failed to get machine: %v", err)
}
@@ -91,23 +92,22 @@ func BenchmarkIperf(b *testing.B) {
if err := harness.WaitUntilServing(ctx, clientMachine, ip, servingPort); err != nil {
b.Fatalf("failed to wait for server: %v", err)
}
+
// Run the client.
b.ResetTimer()
-
- // Restart the server profiles. If the server isn't being profiled
- // this does nothing.
- server.RestartProfiles()
- for i := 0; i < b.N; i++ {
- out, err := client.Run(ctx, dockerutil.RunOpts{
- Image: "benchmarks/iperf",
- }, iperf.MakeCmd(ip, servingPort)...)
- if err != nil {
- b.Fatalf("failed to run client: %v", err)
- }
- b.StopTimer()
- iperf.Report(b, out)
- b.StartTimer()
+ out, err := client.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/iperf",
+ }, iperf.MakeCmd(ip, servingPort)...)
+ if err != nil {
+ b.Fatalf("failed to run client: %v", err)
}
+ b.StopTimer()
+ iperf.Report(b, out)
})
}
}
+
+func TestMain(m *testing.M) {
+ harness.Init()
+ os.Exit(m.Run())
+}
diff --git a/test/benchmarks/network/network.go b/test/benchmarks/network/network.go
index ce17ddb94..d61002cea 100644
--- a/test/benchmarks/network/network.go
+++ b/test/benchmarks/network/network.go
@@ -16,16 +16,65 @@
package network
import (
- "os"
+ "context"
"testing"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
"gvisor.dev/gvisor/test/benchmarks/harness"
+ "gvisor.dev/gvisor/test/benchmarks/tools"
)
-var h harness.Harness
+// runStaticServer runs static serving workloads (httpd, nginx).
+func runStaticServer(b *testing.B, serverOpts dockerutil.RunOpts, serverCmd []string, port int, hey *tools.Hey) {
+ ctx := context.Background()
-// TestMain is the main method for package network.
-func TestMain(m *testing.M) {
- h.Init()
- os.Exit(m.Run())
+ // Get two machines: a client and server.
+ clientMachine, err := harness.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer clientMachine.CleanUp()
+
+ serverMachine, err := harness.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer serverMachine.CleanUp()
+
+ // Make the containers.
+ client := clientMachine.GetNativeContainer(ctx, b)
+ defer client.CleanUp(ctx)
+ server := serverMachine.GetContainer(ctx, b)
+ defer server.CleanUp(ctx)
+
+ // Start the server.
+ if err := server.Spawn(ctx, serverOpts, serverCmd...); err != nil {
+ b.Fatalf("failed to start server: %v", err)
+ }
+
+ // Get its IP.
+ ip, err := serverMachine.IPAddress()
+ if err != nil {
+ b.Fatalf("failed to find server ip: %v", err)
+ }
+
+ // Get the published port.
+ servingPort, err := server.FindPort(ctx, port)
+ if err != nil {
+ b.Fatalf("failed to find server port %d: %v", port, err)
+ }
+
+ // Make sure the server is serving.
+ harness.WaitUntilServing(ctx, clientMachine, ip, servingPort)
+
+ // Run the client.
+ b.ResetTimer()
+ out, err := client.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/hey",
+ }, hey.MakeCmd(ip, servingPort)...)
+ if err != nil {
+ b.Fatalf("run failed with: %v", err)
+ }
+ b.StopTimer()
+ hey.Report(b, out)
}
diff --git a/test/benchmarks/network/nginx_test.go b/test/benchmarks/network/nginx_test.go
index 08565d0b2..74f3578fc 100644
--- a/test/benchmarks/network/nginx_test.go
+++ b/test/benchmarks/network/nginx_test.go
@@ -14,10 +14,12 @@
package network
import (
+ "os"
"strconv"
"testing"
"gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/benchmarks/harness"
"gvisor.dev/gvisor/test/benchmarks/tools"
)
@@ -34,19 +36,21 @@ var nginxDocs = map[string]string{
// BenchmarkNginxDocSize iterates over different sized payloads, testing how
// well the runtime handles sending different payload sizes.
func BenchmarkNginxDocSize(b *testing.B) {
- benchmarkNginxDocSize(b, false /* reverse */, true /* tmpfs */)
- benchmarkNginxDocSize(b, false /* reverse */, false /* tmpfs */)
+ benchmarkNginxDocSize(b, true /* tmpfs */)
+ benchmarkNginxDocSize(b, false /* tmpfs */)
}
-// BenchmarkReverseNginxDocSize iterates over different sized payloads, testing
-// how well the runtime handles receiving different payload sizes.
-func BenchmarkReverseNginxDocSize(b *testing.B) {
- benchmarkNginxDocSize(b, true /* reverse */, true /* tmpfs */)
+// BenchmarkContinuousNginx runs specific benchmarks for continuous jobs.
+// The runtime under test is the server serving a runc client.
+func BenchmarkContinuousNginx(b *testing.B) {
+ sizes := []string{"10Kb", "100Kb", "1Mb"}
+ threads := []int{1, 25, 100, 1000}
+ benchmarkNginxContinuous(b, threads, sizes)
}
// benchmarkNginxDocSize iterates through all doc sizes, running subbenchmarks
// for each size.
-func benchmarkNginxDocSize(b *testing.B, reverse, tmpfs bool) {
+func benchmarkNginxDocSize(b *testing.B, tmpfs bool) {
for size, filename := range nginxDocs {
concurrency := []int{1, 25, 50, 100, 1000}
for _, c := range concurrency {
@@ -71,21 +75,56 @@ func benchmarkNginxDocSize(b *testing.B, reverse, tmpfs bool) {
if err != nil {
b.Fatalf("Failed to parse parameters: %v", err)
}
+ b.Run(name, func(b *testing.B) {
+ hey := &tools.Hey{
+ Requests: b.N,
+ Concurrency: c,
+ Doc: filename,
+ }
+ runNginx(b, hey, tmpfs)
+ })
+ }
+ }
+}
+
+// benchmarkNginxContinuous iterates through given sizes and concurrencies on a tmpfs mount.
+func benchmarkNginxContinuous(b *testing.B, concurrency []int, sizes []string) {
+ for _, size := range sizes {
+ filename := nginxDocs[size]
+ for _, c := range concurrency {
+ fsize := tools.Parameter{
+ Name: "filesize",
+ Value: size,
+ }
+ threads := tools.Parameter{
+ Name: "concurrency",
+ Value: strconv.Itoa(c),
+ }
+
+ fs := tools.Parameter{
+ Name: "filesystem",
+ Value: "tmpfs",
+ }
+
+ name, err := tools.ParametersToName(fsize, threads, fs)
+ if err != nil {
+ b.Fatalf("Failed to parse parameters: %v", err)
+ }
b.Run(name, func(b *testing.B) {
hey := &tools.Hey{
- Requests: c * b.N,
+ Requests: b.N,
Concurrency: c,
Doc: filename,
}
- runNginx(b, hey, reverse, tmpfs)
+ runNginx(b, hey, true /*tmpfs*/)
})
}
}
}
// runNginx configures the static serving methods to run nginx.
-func runNginx(b *testing.B, hey *tools.Hey, reverse, tmpfs bool) {
+func runNginx(b *testing.B, hey *tools.Hey, tmpfs bool) {
// nginx runs on port 80.
port := 80
nginxRunOpts := dockerutil.RunOpts{
@@ -99,5 +138,10 @@ func runNginx(b *testing.B, hey *tools.Hey, reverse, tmpfs bool) {
}
// Command copies nginxDocs to tmpfs serving directory and runs nginx.
- runStaticServer(b, nginxRunOpts, nginxCmd, port, hey, reverse)
+ runStaticServer(b, nginxRunOpts, nginxCmd, port, hey)
+}
+
+func TestMain(m *testing.M) {
+ harness.Init()
+ os.Exit(m.Run())
}
diff --git a/test/benchmarks/network/node_test.go b/test/benchmarks/network/node_test.go
index 254538899..a1fc82f95 100644
--- a/test/benchmarks/network/node_test.go
+++ b/test/benchmarks/network/node_test.go
@@ -15,6 +15,7 @@ package network
import (
"context"
+ "os"
"strconv"
"testing"
"time"
@@ -41,7 +42,7 @@ func BenchmarkNode(b *testing.B) {
}
b.Run(name, func(b *testing.B) {
hey := &tools.Hey{
- Requests: b.N * c, // Requests b.N requests per thread.
+ Requests: b.N,
Concurrency: c,
}
runNode(b, hey)
@@ -54,14 +55,14 @@ func runNode(b *testing.B, hey *tools.Hey) {
b.Helper()
// The machine to hold Redis and the Node Server.
- serverMachine, err := h.GetMachine()
+ serverMachine, err := harness.GetMachine()
if err != nil {
b.Fatalf("failed to get machine with: %v", err)
}
defer serverMachine.CleanUp()
// The machine to run 'hey'.
- clientMachine, err := h.GetMachine()
+ clientMachine, err := harness.GetMachine()
if err != nil {
b.Fatalf("failed to get machine with: %v", err)
}
@@ -116,10 +117,8 @@ func runNode(b *testing.B, hey *tools.Hey) {
heyCmd := hey.MakeCmd(servingIP, servingPort)
- nodeApp.RestartProfiles()
- b.ResetTimer()
-
// the client should run on Native.
+ b.ResetTimer()
client := clientMachine.GetNativeContainer(ctx, b)
out, err := client.Run(ctx, dockerutil.RunOpts{
Image: "benchmarks/hey",
@@ -129,7 +128,10 @@ func runNode(b *testing.B, hey *tools.Hey) {
}
// Stop the timer to parse the data and report stats.
- b.StopTimer()
hey.Report(b, out)
- b.StartTimer()
+}
+
+func TestMain(m *testing.M) {
+ harness.Init()
+ os.Exit(m.Run())
}
diff --git a/test/benchmarks/network/ruby_test.go b/test/benchmarks/network/ruby_test.go
index 0174ff3f3..b7ec16e0a 100644
--- a/test/benchmarks/network/ruby_test.go
+++ b/test/benchmarks/network/ruby_test.go
@@ -16,6 +16,7 @@ package network
import (
"context"
"fmt"
+ "os"
"strconv"
"testing"
"time"
@@ -42,7 +43,7 @@ func BenchmarkRuby(b *testing.B) {
}
b.Run(name, func(b *testing.B) {
hey := &tools.Hey{
- Requests: b.N * c, // b.N requests per thread.
+ Requests: b.N,
Concurrency: c,
}
runRuby(b, hey)
@@ -52,16 +53,15 @@ func BenchmarkRuby(b *testing.B) {
// runRuby runs the test for a given # of requests and concurrency.
func runRuby(b *testing.B, hey *tools.Hey) {
- b.Helper()
// The machine to hold Redis and the Ruby Server.
- serverMachine, err := h.GetMachine()
+ serverMachine, err := harness.GetMachine()
if err != nil {
b.Fatalf("failed to get machine with: %v", err)
}
defer serverMachine.CleanUp()
// The machine to run 'hey'.
- clientMachine, err := h.GetMachine()
+ clientMachine, err := harness.GetMachine()
if err != nil {
b.Fatalf("failed to get machine with: %v", err)
}
@@ -123,10 +123,9 @@ func runRuby(b *testing.B, hey *tools.Hey) {
b.Fatalf("failed to wait until serving: %v", err)
}
heyCmd := hey.MakeCmd(servingIP, servingPort)
- rubyApp.RestartProfiles()
- b.ResetTimer()
// the client should run on Native.
+ b.ResetTimer()
client := clientMachine.GetNativeContainer(ctx, b)
defer client.CleanUp(ctx)
out, err := client.Run(ctx, dockerutil.RunOpts{
@@ -135,9 +134,11 @@ func runRuby(b *testing.B, hey *tools.Hey) {
if err != nil {
b.Fatalf("hey container failed: %v logs: %s", err, out)
}
-
- // Stop the timer to parse the data and report stats.
b.StopTimer()
hey.Report(b, out)
- b.StartTimer()
+}
+
+func TestMain(m *testing.M) {
+ harness.Init()
+ os.Exit(m.Run())
}
diff --git a/test/benchmarks/network/static_server.go b/test/benchmarks/network/static_server.go
deleted file mode 100644
index e747a1395..000000000
--- a/test/benchmarks/network/static_server.go
+++ /dev/null
@@ -1,87 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package network
-
-import (
- "context"
- "testing"
-
- "gvisor.dev/gvisor/pkg/test/dockerutil"
- "gvisor.dev/gvisor/test/benchmarks/harness"
- "gvisor.dev/gvisor/test/benchmarks/tools"
-)
-
-// runStaticServer runs static serving workloads (httpd, nginx).
-func runStaticServer(b *testing.B, serverOpts dockerutil.RunOpts, serverCmd []string, port int, hey *tools.Hey, reverse bool) {
- ctx := context.Background()
-
- // Get two machines: a client and server.
- clientMachine, err := h.GetMachine()
- if err != nil {
- b.Fatalf("failed to get machine: %v", err)
- }
- defer clientMachine.CleanUp()
-
- serverMachine, err := h.GetMachine()
- if err != nil {
- b.Fatalf("failed to get machine: %v", err)
- }
- defer serverMachine.CleanUp()
-
- // Make the containers. 'reverse=true' specifies that the client should use the
- // runtime under test.
- var client, server *dockerutil.Container
- if reverse {
- client = clientMachine.GetContainer(ctx, b)
- server = serverMachine.GetNativeContainer(ctx, b)
- } else {
- client = clientMachine.GetNativeContainer(ctx, b)
- server = serverMachine.GetContainer(ctx, b)
- }
- defer client.CleanUp(ctx)
- defer server.CleanUp(ctx)
-
- // Start the server.
- if err := server.Spawn(ctx, serverOpts, serverCmd...); err != nil {
- b.Fatalf("failed to start server: %v", err)
- }
-
- // Get its IP.
- ip, err := serverMachine.IPAddress()
- if err != nil {
- b.Fatalf("failed to find server ip: %v", err)
- }
-
- // Get the published port.
- servingPort, err := server.FindPort(ctx, port)
- if err != nil {
- b.Fatalf("failed to find server port %d: %v", port, err)
- }
-
- // Make sure the server is serving.
- harness.WaitUntilServing(ctx, clientMachine, ip, servingPort)
- b.ResetTimer()
- server.RestartProfiles()
- out, err := client.Run(ctx, dockerutil.RunOpts{
- Image: "benchmarks/hey",
- }, hey.MakeCmd(ip, servingPort)...)
- if err != nil {
- b.Fatalf("run failed with: %v", err)
- }
-
- b.StopTimer()
- hey.Report(b, out)
- b.StartTimer()
-}
diff --git a/test/benchmarks/tools/fio.go b/test/benchmarks/tools/fio.go
index f5f60fa84..f6324c3ab 100644
--- a/test/benchmarks/tools/fio.go
+++ b/test/benchmarks/tools/fio.go
@@ -25,25 +25,20 @@ import (
// Fio makes 'fio' commands and parses their output.
type Fio struct {
Test string // test to run: read, write, randread, randwrite.
- Size string // total size to be read/written of format N[GMK] (e.g. 5G).
- Blocksize string // blocksize to be read/write of format N[GMK] (e.g. 4K).
- Iodepth int // iodepth for reads/writes.
- Time int // time to run the test in seconds, usually for rand(read/write).
+ Size int // total size to be read/written in megabytes.
+ BlockSize int // block size to be read/written in kilobytes.
+ IODepth int // I/O depth for reads/writes.
}
// MakeCmd makes a 'fio' command.
func (f *Fio) MakeCmd(filename string) []string {
cmd := []string{"fio", "--output-format=json", "--ioengine=sync"}
cmd = append(cmd, fmt.Sprintf("--name=%s", f.Test))
- cmd = append(cmd, fmt.Sprintf("--size=%s", f.Size))
- cmd = append(cmd, fmt.Sprintf("--blocksize=%s", f.Blocksize))
+ cmd = append(cmd, fmt.Sprintf("--size=%dM", f.Size))
+ cmd = append(cmd, fmt.Sprintf("--blocksize=%dK", f.BlockSize))
cmd = append(cmd, fmt.Sprintf("--filename=%s", filename))
- cmd = append(cmd, fmt.Sprintf("--iodepth=%d", f.Iodepth))
+ cmd = append(cmd, fmt.Sprintf("--iodepth=%d", f.IODepth))
cmd = append(cmd, fmt.Sprintf("--rw=%s", f.Test))
- if f.Time != 0 {
- cmd = append(cmd, "--time_based")
- cmd = append(cmd, fmt.Sprintf("--runtime=%d", f.Time))
- }
return cmd
}
diff --git a/test/benchmarks/tools/hey.go b/test/benchmarks/tools/hey.go
index b8cb938fe..de908feeb 100644
--- a/test/benchmarks/tools/hey.go
+++ b/test/benchmarks/tools/hey.go
@@ -19,7 +19,6 @@ import (
"net"
"regexp"
"strconv"
- "strings"
"testing"
)
@@ -32,8 +31,16 @@ type Hey struct {
// MakeCmd returns a 'hey' command.
func (h *Hey) MakeCmd(ip net.IP, port int) []string {
- return strings.Split(fmt.Sprintf("hey -n %d -c %d http://%s:%d/%s",
- h.Requests, h.Concurrency, ip, port, h.Doc), " ")
+ c := h.Concurrency
+ if c > h.Requests {
+ c = h.Requests
+ }
+ return []string{
+ "hey",
+ "-n", fmt.Sprintf("%d", h.Requests),
+ "-c", fmt.Sprintf("%d", c),
+ fmt.Sprintf("http://%s:%d/%s", ip.String(), port, h.Doc),
+ }
}
// Report parses output from 'hey' and reports metrics.
diff --git a/test/benchmarks/tools/iperf.go b/test/benchmarks/tools/iperf.go
index 5c4e7125b..abf296731 100644
--- a/test/benchmarks/tools/iperf.go
+++ b/test/benchmarks/tools/iperf.go
@@ -19,19 +19,27 @@ import (
"net"
"regexp"
"strconv"
- "strings"
"testing"
)
+const length = 64 * 1024
+
// Iperf is for the client side of `iperf`.
type Iperf struct {
- Time int
+ Num int
}
// MakeCmd returns a iperf client command.
func (i *Iperf) MakeCmd(ip net.IP, port int) []string {
- // iperf report in Kb realtime
- return strings.Split(fmt.Sprintf("iperf -f K --realtime --time %d -c %s -p %d", i.Time, ip, port), " ")
+ return []string{
+ "iperf",
+ "--format", "K", // Output in KBytes.
+ "--realtime", // Measured in realtime.
+ "--num", fmt.Sprintf("%d", i.Num),
+ "--length", fmt.Sprintf("%d", length),
+ "--client", ip.String(),
+ "--port", fmt.Sprintf("%d", port),
+ }
}
// Report parses output from iperf client and reports metrics.
@@ -42,6 +50,7 @@ func (i *Iperf) Report(b *testing.B, output string) {
if err != nil {
b.Fatalf("failed to parse bandwitdth from %s: %v", output, err)
}
+ b.SetBytes(length) // Measure Bytes/sec for b.N, although below is iperf output.
ReportCustomMetric(b, bW*1024, "bandwidth" /*metric name*/, "bytes_per_second" /*unit*/)
}
diff --git a/test/benchmarks/tools/redis.go b/test/benchmarks/tools/redis.go
index e35886437..12fdbc7cc 100644
--- a/test/benchmarks/tools/redis.go
+++ b/test/benchmarks/tools/redis.go
@@ -19,7 +19,6 @@ import (
"net"
"regexp"
"strconv"
- "strings"
"testing"
)
@@ -29,17 +28,29 @@ type Redis struct {
}
// MakeCmd returns a redis-benchmark client command.
-func (r *Redis) MakeCmd(ip net.IP, port int) []string {
+func (r *Redis) MakeCmd(ip net.IP, port, requests int) []string {
// There is no -t PING_BULK for redis-benchmark, so adjust the command in that case.
// Note that "ping" will run both PING_INLINE and PING_BULK.
if r.Operation == "PING_BULK" {
- return strings.Split(
- fmt.Sprintf("redis-benchmark --csv -t ping -h %s -p %d", ip, port), " ")
+ return []string{
+ "redis-benchmark",
+ "--csv",
+ "-t", "ping",
+ "-h", ip.String(),
+ "-p", fmt.Sprintf("%d", port),
+ "-n", fmt.Sprintf("%d", requests),
+ }
}
// runs redis-benchmark -t operation for 100K requests against server.
- return strings.Split(
- fmt.Sprintf("redis-benchmark --csv -t %s -h %s -p %d", r.Operation, ip, port), " ")
+ return []string{
+ "redis-benchmark",
+ "--csv",
+ "-t", r.Operation,
+ "-h", ip.String(),
+ "-p", fmt.Sprintf("%d", port),
+ "-n", fmt.Sprintf("%d", requests),
+ }
}
// Report parses output from redis-benchmark client and reports metrics.
diff --git a/test/benchmarks/tools/sysbench.go b/test/benchmarks/tools/sysbench.go
index 7ccacd8ff..350f8ec98 100644
--- a/test/benchmarks/tools/sysbench.go
+++ b/test/benchmarks/tools/sysbench.go
@@ -18,58 +18,48 @@ import (
"fmt"
"regexp"
"strconv"
- "strings"
"testing"
)
-var warmup = "sysbench --threads=8 --memory-total-size=5G memory run > /dev/null &&"
-
// Sysbench represents a 'sysbench' command.
type Sysbench interface {
- MakeCmd() []string // Makes a sysbench command.
- flags() []string
- Report(*testing.B, string) // Reports results contained in string.
+ // MakeCmd constructs the relevant command line.
+ MakeCmd(*testing.B) []string
+
+ // Report reports relevant custom metrics.
+ Report(*testing.B, string)
}
// SysbenchBase is the top level struct for sysbench and holds top-level arguments
// for sysbench. See: 'sysbench --help'
type SysbenchBase struct {
- Threads int // number of Threads for the test.
- Time int // time limit for test in seconds.
+ // Threads is the number of threads for the test.
+ Threads int
}
// baseFlags returns top level flags.
-func (s *SysbenchBase) baseFlags() []string {
+func (s *SysbenchBase) baseFlags(b *testing.B, useEvents bool) []string {
var ret []string
if s.Threads > 0 {
ret = append(ret, fmt.Sprintf("--threads=%d", s.Threads))
}
- if s.Time > 0 {
- ret = append(ret, fmt.Sprintf("--time=%d", s.Time))
+ ret = append(ret, "--time=0") // Ensure other mechanism is used.
+ if useEvents {
+ ret = append(ret, fmt.Sprintf("--events=%d", b.N))
}
return ret
}
// SysbenchCPU is for 'sysbench [flags] cpu run' and holds CPU specific arguments.
type SysbenchCPU struct {
- Base SysbenchBase
- MaxPrime int // upper limit for primes generator [10000].
+ SysbenchBase
}
// MakeCmd makes commands for SysbenchCPU.
-func (s *SysbenchCPU) MakeCmd() []string {
- cmd := []string{warmup, "sysbench"}
- cmd = append(cmd, s.flags()...)
- cmd = append(cmd, "cpu run")
- return []string{"sh", "-c", strings.Join(cmd, " ")}
-}
-
-// flags makes flags for SysbenchCPU cmds.
-func (s *SysbenchCPU) flags() []string {
- cmd := s.Base.baseFlags()
- if s.MaxPrime > 0 {
- return append(cmd, fmt.Sprintf("--cpu-max-prime=%d", s.MaxPrime))
- }
+func (s *SysbenchCPU) MakeCmd(b *testing.B) []string {
+ cmd := []string{"sysbench"}
+ cmd = append(cmd, s.baseFlags(b, true /* useEvents */)...)
+ cmd = append(cmd, "cpu", "run")
return cmd
}
@@ -96,9 +86,8 @@ func (s *SysbenchCPU) parseEvents(data string) (float64, error) {
// SysbenchMemory is for 'sysbench [FLAGS] memory run' and holds Memory specific arguments.
type SysbenchMemory struct {
- Base SysbenchBase
- BlockSize string // size of test memory block [1K].
- TotalSize string // size of data to transfer [100G].
+ SysbenchBase
+ BlockSize int // size of test memory block in megabytes [1].
Scope string // memory access scope {global, local} [global].
HugeTLB bool // allocate memory from HugeTLB [off].
OperationType string // type of memory ops {read, write, none} [write].
@@ -106,21 +95,18 @@ type SysbenchMemory struct {
}
// MakeCmd makes commands for SysbenchMemory.
-func (s *SysbenchMemory) MakeCmd() []string {
- cmd := []string{warmup, "sysbench"}
- cmd = append(cmd, s.flags()...)
- cmd = append(cmd, "memory run")
- return []string{"sh", "-c", strings.Join(cmd, " ")}
+func (s *SysbenchMemory) MakeCmd(b *testing.B) []string {
+ cmd := []string{"sysbench"}
+ cmd = append(cmd, s.flags(b)...)
+ cmd = append(cmd, "memory", "run")
+ return cmd
}
// flags makes flags for SysbenchMemory cmds.
-func (s *SysbenchMemory) flags() []string {
- cmd := s.Base.baseFlags()
- if s.BlockSize != "" {
- cmd = append(cmd, fmt.Sprintf("--memory-block-size=%s", s.BlockSize))
- }
- if s.TotalSize != "" {
- cmd = append(cmd, fmt.Sprintf("--memory-total-size=%s", s.TotalSize))
+func (s *SysbenchMemory) flags(b *testing.B) []string {
+ cmd := s.baseFlags(b, false /* useEvents */)
+ if s.BlockSize != 0 {
+ cmd = append(cmd, fmt.Sprintf("--memory-block-size=%dM", s.BlockSize))
}
if s.Scope != "" {
cmd = append(cmd, fmt.Sprintf("--memory-scope=%s", s.Scope))
@@ -134,6 +120,10 @@ func (s *SysbenchMemory) flags() []string {
if s.AccessMode != "" {
cmd = append(cmd, fmt.Sprintf("--memory-access-mode=%s", s.AccessMode))
}
+ // Sysbench ignores events for memory tests, and uses the total
+ // size parameter to determine when the test is done. We scale
+ // with this instead.
+ cmd = append(cmd, fmt.Sprintf("--memory-total-size=%dG", b.N))
return cmd
}
@@ -147,7 +137,7 @@ func (s *SysbenchMemory) Report(b *testing.B, output string) {
ReportCustomMetric(b, result, "memory_operations" /*metric name*/, "ops_per_second" /*unit*/)
}
-var memoryOperationsRE = regexp.MustCompile(`Total\soperations:\s+\d*\s*\((\d*\.\d*)\sper\ssecond\)`)
+var memoryOperationsRE = regexp.MustCompile(`Total\s+operations:\s+\d+\s+\((\s*\d+\.\d+\s*)\s+per\s+second\)`)
// parseOperations parses memory operations per second form sysbench memory ouput.
func (s *SysbenchMemory) parseOperations(data string) (float64, error) {
@@ -160,33 +150,34 @@ func (s *SysbenchMemory) parseOperations(data string) (float64, error) {
// SysbenchMutex is for 'sysbench [FLAGS] mutex run' and holds Mutex specific arguments.
type SysbenchMutex struct {
- Base SysbenchBase
+ SysbenchBase
Num int // total size of mutex array [4096].
- Locks int // number of mutex locks per thread [50K].
- Loops int // number of loops to do outside mutex lock [10K].
+ Loops int // number of loops to do outside mutex lock [10000].
}
// MakeCmd makes commands for SysbenchMutex.
-func (s *SysbenchMutex) MakeCmd() []string {
- cmd := []string{warmup, "sysbench"}
- cmd = append(cmd, s.flags()...)
- cmd = append(cmd, "mutex run")
- return []string{"sh", "-c", strings.Join(cmd, " ")}
+func (s *SysbenchMutex) MakeCmd(b *testing.B) []string {
+ cmd := []string{"sysbench"}
+ cmd = append(cmd, s.flags(b)...)
+ cmd = append(cmd, "mutex", "run")
+ return cmd
}
// flags makes flags for SysbenchMutex commands.
-func (s *SysbenchMutex) flags() []string {
+func (s *SysbenchMutex) flags(b *testing.B) []string {
var cmd []string
- cmd = append(cmd, s.Base.baseFlags()...)
+ cmd = append(cmd, s.baseFlags(b, false /* useEvents */)...)
if s.Num > 0 {
cmd = append(cmd, fmt.Sprintf("--mutex-num=%d", s.Num))
}
- if s.Locks > 0 {
- cmd = append(cmd, fmt.Sprintf("--mutex-locks=%d", s.Locks))
- }
if s.Loops > 0 {
cmd = append(cmd, fmt.Sprintf("--mutex-loops=%d", s.Loops))
}
+ // Sysbench does not respect --events for mutex tests. From [1]:
+ // "Here --time or --events are completely ignored. Sysbench always
+ // runs one event per thread."
+ // [1] https://tomfern.com/posts/sysbench-guide-1
+ cmd = append(cmd, fmt.Sprintf("--mutex-locks=%d", b.N))
return cmd
}
diff --git a/test/cmd/test_app/fds.go b/test/cmd/test_app/fds.go
index a7658eefd..d4354f0d3 100644
--- a/test/cmd/test_app/fds.go
+++ b/test/cmd/test_app/fds.go
@@ -16,6 +16,7 @@ package main
import (
"context"
+ "io"
"io/ioutil"
"log"
"os"
@@ -168,8 +169,8 @@ func (fdr *fdReceiver) Execute(ctx context.Context, f *flag.FlagSet, args ...int
file := os.NewFile(uintptr(fd), "received file")
defer file.Close()
- if _, err := file.Seek(0, os.SEEK_SET); err != nil {
- log.Fatalf("Seek(0, 0) failed: %v", err)
+ if _, err := file.Seek(0, io.SeekStart); err != nil {
+ log.Fatalf("Error from seek(0, 0): %v", err)
}
got, err := ioutil.ReadAll(file)
diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go
index 8425abecb..d07ed6ba5 100644
--- a/test/e2e/integration_test.go
+++ b/test/e2e/integration_test.go
@@ -260,12 +260,10 @@ func TestMemLimit(t *testing.T) {
d := dockerutil.MakeContainer(ctx, t)
defer d.CleanUp(ctx)
- // N.B. Because the size of the memory file may grow in large chunks,
- // there is a minimum threshold of 1GB for the MemTotal figure.
- allocMemory := 1024 * 1024 // In kb.
+ allocMemoryKb := 50 * 1024
out, err := d.Run(ctx, dockerutil.RunOpts{
Image: "basic/alpine",
- Memory: allocMemory * 1024, // In bytes.
+ Memory: allocMemoryKb * 1024, // In bytes.
}, "sh", "-c", "cat /proc/meminfo | grep MemTotal: | awk '{print $2}'")
if err != nil {
t.Fatalf("docker run failed: %v", err)
@@ -285,7 +283,7 @@ func TestMemLimit(t *testing.T) {
if err != nil {
t.Fatalf("failed to parse %q: %v", out, err)
}
- if want := uint64(allocMemory); got != want {
+ if want := uint64(allocMemoryKb); got != want {
t.Errorf("MemTotal got: %d, want: %d", got, want)
}
}
@@ -494,6 +492,55 @@ func TestLink(t *testing.T) {
}
}
+// This test ensures we can run ping without errors.
+func TestPing4Loopback(t *testing.T) {
+ if testutil.IsRunningWithHostNet() {
+ // TODO(gvisor.dev/issue/5011): support ICMP sockets in hostnet and enable
+ // this test.
+ t.Skip("hostnet only supports TCP/UDP sockets, so ping is not supported.")
+ }
+
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ if got, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/ping4test",
+ }, "/root/ping4.sh"); err != nil {
+ t.Fatalf("docker run failed: %s", err)
+ } else if got != "" {
+ t.Errorf("test failed:\n%s", got)
+ }
+}
+
+// This test ensures we can enable ipv6 on loopback and run ping6 without
+// errors.
+func TestPing6Loopback(t *testing.T) {
+ if testutil.IsRunningWithHostNet() {
+ // TODO(gvisor.dev/issue/5011): support ICMP sockets in hostnet and enable
+ // this test.
+ t.Skip("hostnet only supports TCP/UDP sockets, so ping6 is not supported.")
+ }
+
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ if got, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/ping6test",
+ // The CAP_NET_ADMIN capability is required to use the `ip` utility, which
+ // we use to enable ipv6 on loopback.
+ //
+ // By default, ipv6 loopback is not enabled by runsc, because docker does
+ // not assign an ipv6 address to the test container.
+ CapAdd: []string{"NET_ADMIN"},
+ }, "/root/ping6.sh"); err != nil {
+ t.Fatalf("docker run failed: %s", err)
+ } else if got != "" {
+ t.Errorf("test failed:\n%s", got)
+ }
+}
+
func TestMain(m *testing.M) {
dockerutil.EnsureSupportedDockerVersion()
flag.Parse()
diff --git a/test/e2e/regression_test.go b/test/e2e/regression_test.go
index 70bbe5121..84564cdaa 100644
--- a/test/e2e/regression_test.go
+++ b/test/e2e/regression_test.go
@@ -35,7 +35,7 @@ func TestBindOverlay(t *testing.T) {
// Run the container.
got, err := d.Run(ctx, dockerutil.RunOpts{
Image: "basic/ubuntu",
- }, "bash", "-c", "nc -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -U /var/run/sock && wait $p")
+ }, "bash", "-c", "nc -q -1 -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -q 0 -U /var/run/sock && wait $p")
if err != nil {
t.Fatalf("docker run failed: %v", err)
}
diff --git a/test/fuse/BUILD b/test/fuse/BUILD
index 8e31fdd41..74500ec84 100644
--- a/test/fuse/BUILD
+++ b/test/fuse/BUILD
@@ -71,3 +71,8 @@ syscall_test(
fuse = "True",
test = "//test/fuse/linux:setstat_test",
)
+
+syscall_test(
+ fuse = "True",
+ test = "//test/fuse/linux:mount_test",
+)
diff --git a/test/fuse/linux/BUILD b/test/fuse/linux/BUILD
index 7673252ec..2f745bd47 100644
--- a/test/fuse/linux/BUILD
+++ b/test/fuse/linux/BUILD
@@ -228,3 +228,16 @@ cc_binary(
"//test/util:test_util",
],
)
+
+cc_binary(
+ name = "mount_test",
+ testonly = 1,
+ srcs = ["mount_test.cc"],
+ deps = [
+ gtest,
+ "//test/util:mount_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
diff --git a/test/fuse/linux/mount_test.cc b/test/fuse/linux/mount_test.cc
new file mode 100644
index 000000000..8a5478116
--- /dev/null
+++ b/test/fuse/linux/mount_test.cc
@@ -0,0 +1,83 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <sys/mount.h>
+
+#include "gtest/gtest.h"
+#include "test/util/mount_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(FuseMount, Success) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_WRONLY));
+ std::string mopts = absl::StrCat("fd=", std::to_string(fd.get()));
+
+ const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ const auto mount =
+ ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "fuse", 0, mopts, 0));
+}
+
+TEST(FuseMount, FDNotParsable) {
+ int devfd;
+ EXPECT_THAT(devfd = open("/dev/fuse", O_RDWR), SyscallSucceeds());
+ std::string mount_opts = "fd=thiscantbeparsed";
+ TempPath mount_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(mount("fuse", mount_dir.path().c_str(), "fuse",
+ MS_NODEV | MS_NOSUID, mount_opts.c_str()),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(FuseMount, NoDevice) {
+ const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ EXPECT_THAT(mount("", dir.path().c_str(), "fuse", 0, ""),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(FuseMount, ClosedFD) {
+ FileDescriptor f = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_WRONLY));
+ int fd = f.release();
+ close(fd);
+ std::string mopts = absl::StrCat("fd=", std::to_string(fd));
+
+ const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ EXPECT_THAT(mount("", dir.path().c_str(), "fuse", 0, mopts.c_str()),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(FuseMount, BadFD) {
+ const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+ std::string mopts = absl::StrCat("fd=", std::to_string(fd.get()));
+
+ EXPECT_THAT(mount("", dir.path().c_str(), "fuse", 0, mopts.c_str()),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go
index d3e5efd4f..f4af45e96 100644
--- a/test/iptables/filter_output.go
+++ b/test/iptables/filter_output.go
@@ -248,7 +248,7 @@ func (FilterOutputOwnerFail) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (FilterOutputOwnerFail) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "-j", "ACCEPT"); err == nil {
- return fmt.Errorf("Invalid argument")
+ return fmt.Errorf("invalid argument")
}
return nil
diff --git a/test/packetdrill/BUILD b/test/packetdrill/BUILD
index 49642f282..5d95516ee 100644
--- a/test/packetdrill/BUILD
+++ b/test/packetdrill/BUILD
@@ -38,6 +38,15 @@ packetdrill_test(
scripts = ["tcp_defer_accept_timeout.pkt"],
)
+test_suite(
+ name = "all_tests",
+ tags = [
+ "manual",
+ "packetdrill",
+ ],
+ tests = existing_rules(),
+)
+
bzl_library(
name = "defs_bzl",
srcs = ["defs.bzl"],
diff --git a/test/packetdrill/defs.bzl b/test/packetdrill/defs.bzl
index fc28ce9ba..a6cbcc376 100644
--- a/test/packetdrill/defs.bzl
+++ b/test/packetdrill/defs.bzl
@@ -15,7 +15,7 @@ def _packetdrill_test_impl(ctx):
# Make sure that everything is readable here.
"find . -type f -exec chmod a+rx {} \\;",
"find . -type d -exec chmod a+rx {} \\;",
- "%s %s --init_script %s $@ -- %s\n" % (
+ "%s %s --init_script %s \"$@\" -- %s\n" % (
test_runner.short_path,
" ".join(ctx.attr.flags),
ctx.files._init_script[0].short_path,
@@ -80,9 +80,7 @@ def packetdrill_netstack_test(name, **kwargs):
kwargs["tags"] = PACKETDRILL_TAGS
_packetdrill_test(
name = name,
- # This is the default runtime unless
- # "--test_arg=--runtime=OTHER_RUNTIME" is used to override the value.
- flags = ["--dut_platform", "netstack", "--runtime", "runsc-d"],
+ flags = ["--dut_platform", "netstack"],
**kwargs
)
diff --git a/test/packetdrill/packetdrill_test.sh b/test/packetdrill/packetdrill_test.sh
index 922547d65..d25cad83a 100755
--- a/test/packetdrill/packetdrill_test.sh
+++ b/test/packetdrill/packetdrill_test.sh
@@ -29,7 +29,7 @@ function failure() {
}
trap 'failure ${LINENO} "$BASH_COMMAND"' ERR
-declare -r LONGOPTS="dut_platform:,init_script:,runtime:"
+declare -r LONGOPTS="dut_platform:,init_script:,runtime:,partition:,total_partitions:"
# Don't use declare below so that the error from getopt will end the script.
PARSED=$(getopt --options "" --longoptions=$LONGOPTS --name "$0" -- "$@")
@@ -48,12 +48,17 @@ while true; do
shift 2
;;
--runtime)
- # Not readonly because there might be multiple --runtime arguments and we
- # want to use just the last one. Only used if --dut_platform is
- # "netstack".
declare RUNTIME="$2"
shift 2
;;
+ --partition)
+ # Ignored.
+ shift 2
+ ;;
+ --total_partitions)
+ # Ignored.
+ shift 2
+ ;;
--)
shift
break
diff --git a/test/packetimpact/runner/BUILD b/test/packetimpact/runner/BUILD
index 605dd4972..888c44343 100644
--- a/test/packetimpact/runner/BUILD
+++ b/test/packetimpact/runner/BUILD
@@ -32,6 +32,7 @@ go_library(
deps = [
"//pkg/test/dockerutil",
"//test/packetimpact/netdevs",
+ "//test/packetimpact/testbench",
"@com_github_docker_docker//api/types/mount:go_default_library",
],
)
diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl
index 60f0ebae3..c6c95546a 100644
--- a/test/packetimpact/runner/defs.bzl
+++ b/test/packetimpact/runner/defs.bzl
@@ -12,10 +12,11 @@ def _packetimpact_test_impl(ctx):
# current user, and no other users will be mapped in that namespace.
# Make sure that everything is readable here.
"find . -type f -or -type d -exec chmod a+rx {} \\;",
- "%s %s --testbench_binary %s $@\n" % (
+ "%s %s --testbench_binary %s --num_duts %d $@\n" % (
test_runner.short_path,
" ".join(ctx.attr.flags),
ctx.files.testbench_binary[0].short_path,
+ ctx.attr.num_duts,
),
])
ctx.actions.write(bench, bench_content, is_executable = True)
@@ -51,6 +52,10 @@ _packetimpact_test = rule(
mandatory = False,
default = [],
),
+ "num_duts": attr.int(
+ mandatory = False,
+ default = 1,
+ ),
},
test = True,
implementation = _packetimpact_test_impl,
@@ -110,24 +115,27 @@ def packetimpact_netstack_test(
**kwargs
)
-def packetimpact_go_test(name, expect_native_failure = False, expect_netstack_failure = False):
+def packetimpact_go_test(name, expect_native_failure = False, expect_netstack_failure = False, num_duts = 1):
"""Add packetimpact tests written in go.
Args:
name: name of the test
expect_native_failure: the test must fail natively
expect_netstack_failure: the test must fail for Netstack
+ num_duts: how many DUTs are needed for the test
"""
testbench_binary = name + "_test"
packetimpact_native_test(
name = name,
expect_failure = expect_native_failure,
testbench_binary = testbench_binary,
+ num_duts = num_duts,
)
packetimpact_netstack_test(
name = name,
expect_failure = expect_netstack_failure,
testbench_binary = testbench_binary,
+ num_duts = num_duts,
)
def packetimpact_testbench(name, size = "small", pure = True, **kwargs):
@@ -153,7 +161,7 @@ def packetimpact_testbench(name, size = "small", pure = True, **kwargs):
PacketimpactTestInfo = provider(
doc = "Provide information for packetimpact tests",
- fields = ["name", "expect_netstack_failure"],
+ fields = ["name", "expect_netstack_failure", "num_duts"],
)
ALL_TESTS = [
@@ -216,6 +224,9 @@ ALL_TESTS = [
name = "tcp_user_timeout",
),
PacketimpactTestInfo(
+ name = "tcp_zero_receive_window",
+ ),
+ PacketimpactTestInfo(
name = "tcp_queue_receive_in_syn_sent",
),
PacketimpactTestInfo(
@@ -255,6 +266,7 @@ ALL_TESTS = [
),
PacketimpactTestInfo(
name = "ipv6_fragment_icmp_error",
+ num_duts = 3,
),
PacketimpactTestInfo(
name = "udp_send_recv_dgram",
diff --git a/test/packetimpact/runner/dut.go b/test/packetimpact/runner/dut.go
index ad1d73de2..3da265b78 100644
--- a/test/packetimpact/runner/dut.go
+++ b/test/packetimpact/runner/dut.go
@@ -17,6 +17,7 @@ package runner
import (
"context"
+ "encoding/json"
"flag"
"fmt"
"io/ioutil"
@@ -34,6 +35,7 @@ import (
"github.com/docker/docker/api/types/mount"
"gvisor.dev/gvisor/pkg/test/dockerutil"
"gvisor.dev/gvisor/test/packetimpact/netdevs"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
// stringList implements flag.Value.
@@ -56,9 +58,10 @@ var (
tshark = false
extraTestArgs = stringList{}
expectFailure = false
+ numDUTs = 1
- // DutAddr is the IP addres for DUT.
- DutAddr = net.IPv4(0, 0, 0, 10)
+ // DUTAddr is the IP address for DUT.
+ DUTAddr = net.IPv4(0, 0, 0, 10)
testbenchAddr = net.IPv4(0, 0, 0, 20)
)
@@ -71,10 +74,15 @@ func RegisterFlags(fs *flag.FlagSet) {
fs.BoolVar(&tshark, "tshark", false, "use more verbose tshark in logs instead of tcpdump")
fs.Var(&extraTestArgs, "extra_test_arg", "extra arguments to pass to the testbench")
fs.BoolVar(&expectFailure, "expect_failure", false, "expect that the test will fail when run")
+ fs.IntVar(&numDUTs, "num_duts", numDUTs, "the number of duts to create")
}
-// CtrlPort is the port that posix_server listens on.
-const CtrlPort = "40000"
+const (
+ // CtrlPort is the port that posix_server listens on.
+ CtrlPort uint16 = 40000
+ // testOutputDir is the directory in each container that holds test output.
+ testOutputDir = "/tmp/testoutput"
+)
// logger implements testutil.Logger.
//
@@ -95,16 +103,21 @@ func (l logger) Logf(format string, args ...interface{}) {
}
}
-// TestWithDUT runs a packetimpact test with the given information.
-func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Container) DUT) {
- if testbenchBinary == "" {
- t.Fatal("--testbench_binary is missing")
- }
- dockerutil.EnsureSupportedDockerVersion()
+// dutInfo encapsulates all the essential information needed to set up the
+// testbench container.
+type dutInfo struct {
+ dut DUT
+ ctrlNet, testNet *dockerutil.Network
+ netInfo *testbench.DUTTestNet
+}
- // Create the networks needed for the test. One control network is needed for
- // the gRPC control packets and one test network on which to transmit the test
- // packets.
+// setUpDUT will set up one DUT and return information for setting up the
+// container for testbench.
+func setUpDUT(ctx context.Context, t *testing.T, id int, mkDevice func(*dockerutil.Container) DUT) (dutInfo, error) {
+ // Create the networks needed for the test. One control network is needed
+ // for the gRPC control packets and one test network on which to transmit
+ // the test packets.
+ var info dutInfo
ctrlNet := dockerutil.NewNetwork(ctx, logger("ctrlNet"))
testNet := dockerutil.NewNetwork(ctx, logger("testNet"))
for _, dn := range []*dockerutil.Network{ctrlNet, testNet} {
@@ -113,8 +126,8 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co
t.Log("creating docker network:", err)
const wait = 100 * time.Millisecond
t.Logf("sleeping %s and will try creating docker network again", wait)
- // This can fail if another docker network claimed the same IP so we'll
- // just try again.
+ // This can fail if another docker network claimed the same IP so we
+ // will just try again.
time.Sleep(wait)
continue
}
@@ -128,114 +141,203 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co
})
// Sanity check.
if inspect, err := dn.Inspect(ctx); err != nil {
- t.Fatalf("failed to inspect network %s: %v", dn.Name, err)
+ return dutInfo{}, fmt.Errorf("failed to inspect network %s: %w", dn.Name, err)
} else if inspect.Name != dn.Name {
- t.Fatalf("name mismatch for network want: %s got: %s", dn.Name, inspect.Name)
+ return dutInfo{}, fmt.Errorf("name mismatch for network want: %s got: %s", dn.Name, inspect.Name)
}
}
-
- tmpDir, err := ioutil.TempDir("", "container-output")
- if err != nil {
- t.Fatal("creating temp dir:", err)
- }
- t.Cleanup(func() {
- if err := exec.Command("/bin/cp", "-r", tmpDir, os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR")).Run(); err != nil {
- t.Errorf("unable to copy container output files: %s", err)
- }
- if err := os.RemoveAll(tmpDir); err != nil {
- t.Errorf("failed to remove tmpDir %s: %s", tmpDir, err)
- }
- })
-
- const testOutputDir = "/tmp/testoutput"
+ info.ctrlNet = ctrlNet
+ info.testNet = testNet
// Create the Docker container for the DUT.
- var dut *dockerutil.Container
+ var dut DUT
if native {
- dut = dockerutil.MakeNativeContainer(ctx, logger("dut"))
+ dut = mkDevice(dockerutil.MakeNativeContainer(ctx, logger(fmt.Sprintf("dut-%d", id))))
} else {
- dut = dockerutil.MakeContainer(ctx, logger("dut"))
+ dut = mkDevice(dockerutil.MakeContainer(ctx, logger(fmt.Sprintf("dut-%d", id))))
}
- t.Cleanup(func() {
- dut.CleanUp(ctx)
- })
+ info.dut = dut
runOpts := dockerutil.RunOpts{
Image: "packetimpact",
CapAdd: []string{"NET_ADMIN"},
- Mounts: []mount.Mount{{
- Type: mount.TypeBind,
- Source: tmpDir,
- Target: testOutputDir,
- ReadOnly: false,
- }},
+ }
+ if _, err := MountTempDirectory(t, &runOpts, "dut-output", testOutputDir); err != nil {
+ return dutInfo{}, err
}
- device := mkDevice(dut)
- remoteIPv6, remoteMAC, dutDeviceID, dutTestNetDev := device.Prepare(ctx, t, runOpts, ctrlNet, testNet)
+ ipv4PrefixLength, _ := testNet.Subnet.Mask.Size()
+ remoteIPv6, remoteMAC, dutDeviceID, dutTestNetDev, err := dut.Prepare(ctx, t, runOpts, ctrlNet, testNet)
+ if err != nil {
+ return dutInfo{}, err
+ }
+ info.netInfo = &testbench.DUTTestNet{
+ RemoteMAC: remoteMAC,
+ RemoteIPv4: AddressInSubnet(DUTAddr, *testNet.Subnet),
+ RemoteIPv6: remoteIPv6,
+ RemoteDevID: dutDeviceID,
+ RemoteDevName: dutTestNetDev,
+ LocalIPv4: AddressInSubnet(testbenchAddr, *testNet.Subnet),
+ IPv4PrefixLength: ipv4PrefixLength,
+ POSIXServerIP: AddressInSubnet(DUTAddr, *ctrlNet.Subnet),
+ POSIXServerPort: CtrlPort,
+ }
+ return info, nil
+}
- // Create the Docker container for the testbench.
- testbench := dockerutil.MakeNativeContainer(ctx, logger("testbench"))
+// TestWithDUT runs a packetimpact test with the given information.
+func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Container) DUT) {
+ if testbenchBinary == "" {
+ t.Fatal("--testbench_binary is missing")
+ }
+ dockerutil.EnsureSupportedDockerVersion()
- tbb := path.Base(testbenchBinary)
- containerTestbenchBinary := filepath.Join("/packetimpact", tbb)
- testbench.CopyFiles(&runOpts, "/packetimpact", filepath.Join("test/packetimpact/tests", tbb))
-
- // snifferNetDev is a network device on the test orchestrator that we will
- // run sniffer (tcpdump or tshark) on and inject traffic to, not to be
- // confused with the device on the DUT.
- const snifferNetDev = "eth2"
- // Run tcpdump in the test bench unbuffered, without DNS resolution, just on
- // the interface with the test packets.
- snifferArgs := []string{
- "tcpdump",
- "-S", "-vvv", "-U", "-n",
- "-i", snifferNetDev,
- "-w", testOutputDir + "/dump.pcap",
+ dutInfoChan := make(chan dutInfo, numDUTs)
+ errChan := make(chan error, numDUTs)
+ var dockerNetworks []*dockerutil.Network
+ var dutTestNets []*testbench.DUTTestNet
+ var duts []DUT
+
+ setUpCtx, cancelSetup := context.WithCancel(ctx)
+ t.Cleanup(cancelSetup)
+ for i := 0; i < numDUTs; i++ {
+ go func(i int) {
+ info, err := setUpDUT(setUpCtx, t, i, mkDevice)
+ if err != nil {
+ errChan <- err
+ } else {
+ dutInfoChan <- info
+ }
+ }(i)
}
- snifferRegex := "tcpdump: listening.*\n"
- if tshark {
- // Run tshark in the test bench unbuffered, without DNS resolution, just on
- // the interface with the test packets.
- snifferArgs = []string{
- "tshark", "-V", "-l", "-n", "-i", snifferNetDev,
- "-o", "tcp.check_checksum:TRUE",
- "-o", "udp.check_checksum:TRUE",
+ for i := 0; i < numDUTs; i++ {
+ select {
+ case info := <-dutInfoChan:
+ dockerNetworks = append(dockerNetworks, info.ctrlNet, info.testNet)
+ dutTestNets = append(dutTestNets, info.netInfo)
+ duts = append(duts, info.dut)
+ case err := <-errChan:
+ t.Fatal(err)
}
- snifferRegex = "Capturing on.*\n"
}
+ // Create the Docker container for the testbench.
+ testbenchContainer := dockerutil.MakeNativeContainer(ctx, logger("testbench"))
+
+ runOpts := dockerutil.RunOpts{
+ Image: "packetimpact",
+ CapAdd: []string{"NET_ADMIN"},
+ }
+ if _, err := MountTempDirectory(t, &runOpts, "testbench-output", testOutputDir); err != nil {
+ t.Fatal(err)
+ }
+ tbb := path.Base(testbenchBinary)
+ containerTestbenchBinary := filepath.Join("/packetimpact", tbb)
+ testbenchContainer.CopyFiles(&runOpts, "/packetimpact", filepath.Join("test/packetimpact/tests", tbb))
+
if err := StartContainer(
ctx,
runOpts,
- testbench,
+ testbenchContainer,
testbenchAddr,
- []*dockerutil.Network{ctrlNet, testNet},
- snifferArgs...,
+ dockerNetworks,
+ "tail", "-f", "/dev/null",
); err != nil {
- t.Fatalf("failed to start docker container for testbench sniffer: %s", err)
+ t.Fatalf("cannot start testbench container: %s", err)
}
- // Kill so that it will flush output.
- t.Cleanup(func() {
- time.Sleep(1 * time.Second)
- testbench.Exec(ctx, dockerutil.ExecOpts{}, "killall", snifferArgs[0])
- })
- if _, err := testbench.WaitForOutput(ctx, snifferRegex, 60*time.Second); err != nil {
- t.Fatalf("sniffer on %s never listened: %s", dut.Name, err)
+ for i := range dutTestNets {
+ name, info, err := deviceByIP(ctx, testbenchContainer, dutTestNets[i].LocalIPv4)
+ if err != nil {
+ t.Fatalf("failed to get the device name associated with %s: %s", dutTestNets[i].LocalIPv4, err)
+ }
+ dutTestNets[i].LocalDevName = name
+ dutTestNets[i].LocalDevID = info.ID
+ dutTestNets[i].LocalMAC = info.MAC
+ localIPv6, err := getOrAssignIPv6Addr(ctx, testbenchContainer, name)
+ if err != nil {
+ t.Fatalf("failed to get IPV6 address on %s: %s", testbenchContainer.Name, err)
+ }
+ dutTestNets[i].LocalIPv6 = localIPv6
+ }
+ dutTestNetsBytes, err := json.Marshal(dutTestNets)
+ if err != nil {
+ t.Fatalf("failed to marshal %v into json: %s", dutTestNets, err)
}
- // When the Linux kernel receives a SYN-ACK for a SYN it didn't send, it
- // will respond with an RST. In most packetimpact tests, the SYN is sent
- // by the raw socket and the kernel knows nothing about the connection, this
- // behavior will break lots of TCP related packetimpact tests. To prevent
- // this, we can install the following iptables rules. The raw socket that
- // packetimpact tests use will still be able to see everything.
- for _, bin := range []string{"iptables", "ip6tables"} {
- if logs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, bin, "-A", "INPUT", "-i", snifferNetDev, "-p", "tcp", "-j", "DROP"); err != nil {
- t.Fatalf("unable to Exec %s on container %s: %s, logs from testbench:\n%s", bin, testbench.Name, err, logs)
+ baseSnifferArgs := []string{
+ "tcpdump",
+ "-vvv",
+ "--absolute-tcp-sequence-numbers",
+ "--packet-buffered",
+ // Disable DNS resolution.
+ "-n",
+ // Run tcpdump as root since the output directory is owned by root. From
+ // `man tcpdump`:
+ //
+ // -Z user
+ // --relinquish-privileges=user
+ // If tcpdump is running as root, after opening the capture device
+ // or input savefile, change the user ID to user and the group ID to
+ // the primary group of user.
+ // This behavior is enabled by default (-Z tcpdump), and can be
+ // disabled by -Z root.
+ "-Z", "root",
+ }
+ if tshark {
+ baseSnifferArgs = []string{
+ "tshark",
+ "-V",
+ "-o", "tcp.check_checksum:TRUE",
+ "-o", "udp.check_checksum:TRUE",
+ // Disable buffering.
+ "-l",
+ // Disable DNS resolution.
+ "-n",
}
}
+ for _, n := range dutTestNets {
+ snifferArgs := append(baseSnifferArgs, "-i", n.LocalDevName)
+ if !tshark {
+ snifferArgs = append(
+ snifferArgs,
+ "-w",
+ filepath.Join(testOutputDir, fmt.Sprintf("%s.pcap", n.LocalDevName)),
+ )
+ }
+ p, err := testbenchContainer.ExecProcess(ctx, dockerutil.ExecOpts{}, snifferArgs...)
+ if err != nil {
+ t.Fatalf("failed to start exec a sniffer on %s: %s", n.LocalDevName, err)
+ }
+ t.Cleanup(func() {
+ if snifferOut, err := p.Logs(); err != nil {
+ t.Errorf("sniffer logs failed: %s\n%s", err, snifferOut)
+ } else {
+ t.Logf("sniffer logs:\n%s", snifferOut)
+ }
+ })
+ // When the Linux kernel receives a SYN-ACK for a SYN it didn't send, it
+ // will respond with an RST. In most packetimpact tests, the SYN is sent
+ // by the raw socket and the kernel knows nothing about the connection, so
+ // this behavior will break lots of TCP related packetimpact tests. To
+ // prevent this, we install the following iptables rules. The raw socket
+ // that packetimpact tests use will still be able to see everything.
+ for _, bin := range []string{"iptables", "ip6tables"} {
+ if logs, err := testbenchContainer.Exec(ctx, dockerutil.ExecOpts{}, bin, "-A", "INPUT", "-i", n.LocalDevName, "-p", "tcp", "-j", "DROP"); err != nil {
+ t.Fatalf("unable to Exec %s on container %s: %s, logs from testbench:\n%s", bin, testbenchContainer.Name, err, logs)
+ }
+ }
+ }
+
+ t.Cleanup(func() {
+ // Wait 1 second before killing tcpdump to give it time to flush
+ // any packets. On linux tests killing it immediately can
+ // sometimes result in partial pcaps.
+ time.Sleep(1 * time.Second)
+ if logs, err := testbenchContainer.Exec(ctx, dockerutil.ExecOpts{}, "killall", baseSnifferArgs[0]); err != nil {
+ t.Errorf("failed to kill all sniffers: %s, logs: %s", err, logs)
+ }
+ })
// FIXME(b/156449515): Some piece of the system has a race. The old
// bash script version had a sleep, so we have one too. The race should
@@ -248,31 +350,29 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co
testArgs := []string{containerTestbenchBinary}
testArgs = append(testArgs, extraTestArgs...)
testArgs = append(testArgs,
- "--posix_server_ip", AddressInSubnet(DutAddr, *ctrlNet.Subnet).String(),
- "--posix_server_port", CtrlPort,
- "--remote_ipv4", AddressInSubnet(DutAddr, *testNet.Subnet).String(),
- "--local_ipv4", AddressInSubnet(testbenchAddr, *testNet.Subnet).String(),
- "--remote_ipv6", remoteIPv6.String(),
- "--remote_mac", remoteMAC.String(),
- "--remote_interface_id", fmt.Sprintf("%d", dutDeviceID),
- "--local_device", snifferNetDev,
- "--remote_device", dutTestNetDev,
fmt.Sprintf("--native=%t", native),
+ "--dut_test_nets_json", string(dutTestNetsBytes),
)
- testbenchLogs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, testArgs...)
+ testbenchLogs, err := testbenchContainer.Exec(ctx, dockerutil.ExecOpts{}, testArgs...)
if (err != nil) != expectFailure {
var dutLogs string
- if logs, err := device.Logs(ctx); err != nil {
- dutLogs = fmt.Sprintf("failed to fetch DUT logs: %s", err)
- } else {
- dutLogs = logs
+ for i, dut := range duts {
+ logs, err := dut.Logs(ctx)
+ if err != nil {
+ logs = fmt.Sprintf("failed to fetch DUT logs: %s", err)
+ }
+ dutLogs = fmt.Sprintf(`%s====== Begin of DUT-%d Logs ======
+
+%s
+
+====== End of DUT-%d Logs ======
+
+`, dutLogs, i, logs, i)
}
t.Errorf(`test error: %v, expect failure: %t
-%s
-
-====== Begin of Testbench Logs ======
+%s====== Begin of Testbench Logs ======
%s
@@ -285,7 +385,9 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co
type DUT interface {
// Prepare prepares the dut, starts posix_server and returns the IPv6, MAC
// address, the interface ID, and the interface name for the testNet on DUT.
- Prepare(ctx context.Context, t *testing.T, runOpts dockerutil.RunOpts, ctrlNet, testNet *dockerutil.Network) (net.IP, net.HardwareAddr, uint32, string)
+ // The t parameter must only be used for t.Cleanup. Prepare is invoked from a
+ // goroutine other than the one running the test, so failures must be reported
+ // through the returned error instead of t.Fatal/FailNow.
+ Prepare(ctx context.Context, t *testing.T, runOpts dockerutil.RunOpts, ctrlNet, testNet *dockerutil.Network) (net.IP, net.HardwareAddr, uint32, string, error)
// Logs retrieves the logs from the dut.
Logs(ctx context.Context) (string, error)
}
@@ -303,7 +405,7 @@ func NewDockerDUT(c *dockerutil.Container) DUT {
}
// Prepare implements DUT.Prepare.
-func (dut *DockerDUT) Prepare(ctx context.Context, t *testing.T, runOpts dockerutil.RunOpts, ctrlNet, testNet *dockerutil.Network) (net.IP, net.HardwareAddr, uint32, string) {
+func (dut *DockerDUT) Prepare(ctx context.Context, _ *testing.T, runOpts dockerutil.RunOpts, ctrlNet, testNet *dockerutil.Network) (net.IP, net.HardwareAddr, uint32, string, error) {
const containerPosixServerBinary = "/packetimpact/posix_server"
dut.c.CopyFiles(&runOpts, "/packetimpact", "test/packetimpact/dut/posix_server")
@@ -311,45 +413,31 @@ func (dut *DockerDUT) Prepare(ctx context.Context, t *testing.T, runOpts dockeru
ctx,
runOpts,
dut.c,
- DutAddr,
+ DUTAddr,
[]*dockerutil.Network{ctrlNet, testNet},
containerPosixServerBinary,
"--ip=0.0.0.0",
- "--port="+CtrlPort,
+ fmt.Sprintf("--port=%d", CtrlPort),
); err != nil {
- t.Fatalf("failed to start docker container for DUT: %s", err)
+ return nil, nil, 0, "", fmt.Errorf("failed to start docker container for DUT: %w", err)
}
if _, err := dut.c.WaitForOutput(ctx, "Server listening.*\n", 60*time.Second); err != nil {
- t.Fatalf("%s on container %s never listened: %s", containerPosixServerBinary, dut.c.Name, err)
+ return nil, nil, 0, "", fmt.Errorf("%s on container %s never listened: %s", containerPosixServerBinary, dut.c.Name, err)
}
- dutTestDevice, dutDeviceInfo, err := deviceByIP(ctx, dut.c, AddressInSubnet(DutAddr, *testNet.Subnet))
+ dutTestDevice, dutDeviceInfo, err := deviceByIP(ctx, dut.c, AddressInSubnet(DUTAddr, *testNet.Subnet))
if err != nil {
- t.Fatal(err)
+ return nil, nil, 0, "", err
}
- remoteMAC := dutDeviceInfo.MAC
- remoteIPv6 := dutDeviceInfo.IPv6Addr
- // Netstack as DUT doesn't assign IPv6 addresses automatically so do it if
- // needed.
- if remoteIPv6 == nil {
- if _, err := dut.c.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "add", netdevs.MACToIP(remoteMAC).String(), "scope", "link", "dev", dutTestDevice); err != nil {
- t.Fatalf("unable to ip addr add on container %s: %s", dut.c.Name, err)
- }
- // Now try again, to make sure that it worked.
- _, dutDeviceInfo, err = deviceByIP(ctx, dut.c, AddressInSubnet(DutAddr, *testNet.Subnet))
- if err != nil {
- t.Fatal(err)
- }
- remoteIPv6 = dutDeviceInfo.IPv6Addr
- if remoteIPv6 == nil {
- t.Fatalf("unable to set IPv6 address on container %s", dut.c.Name)
- }
+ remoteIPv6, err := getOrAssignIPv6Addr(ctx, dut.c, dutTestDevice)
+ if err != nil {
+ return nil, nil, 0, "", fmt.Errorf("failed to get IPv6 address on %s: %s", dut.c.Name, err)
}
const testNetDev = "eth2"
- return remoteIPv6, dutDeviceInfo.MAC, dutDeviceInfo.ID, testNetDev
+ return remoteIPv6, dutDeviceInfo.MAC, dutDeviceInfo.ID, testNetDev, nil
}
// Logs implements DUT.Logs.
@@ -358,11 +446,7 @@ func (dut *DockerDUT) Logs(ctx context.Context) (string, error) {
if err != nil {
return "", err
}
- return fmt.Sprintf(`====== Begin of DUT Logs ======
-
-%s
-
-====== End of DUT Logs ======`, logs), nil
+ return logs, nil
}
// AddNetworks connects docker network with the container and assigns the specific IP.
@@ -378,25 +462,35 @@ func AddNetworks(ctx context.Context, d *dockerutil.Container, addr net.IP, netw
}
// AddressInSubnet combines the subnet provided with the address and returns a
-// new address. The return address bits come from the subnet where the mask is 1
-// and from the ip address where the mask is 0.
+// new address. Bits where the mask is 1 come from the subnet; bits where the
+// mask is 0 come from addr. Only the four IPv4 octets are combined.
func AddressInSubnet(addr net.IP, subnet net.IPNet) net.IP {
- var octets []byte
+ var octets net.IP
for i := 0; i < 4; i++ {
octets = append(octets, (subnet.IP.To4()[i]&subnet.Mask[i])+(addr.To4()[i]&(^subnet.Mask[i])))
}
- return net.IP(octets)
+ return octets
}
-// deviceByIP finds a deviceInfo and device name from an IP address.
-func deviceByIP(ctx context.Context, d *dockerutil.Container, ip net.IP) (string, netdevs.DeviceInfo, error) {
+// devicesInfo will run "ip addr show" on the container and parse the output
+// to a map[string]netdevs.DeviceInfo.
+func devicesInfo(ctx context.Context, d *dockerutil.Container) (map[string]netdevs.DeviceInfo, error) {
out, err := d.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "show")
if err != nil {
- return "", netdevs.DeviceInfo{}, fmt.Errorf("listing devices on %s container: %w\n%s", d.Name, err, out)
+ return map[string]netdevs.DeviceInfo{}, fmt.Errorf("listing devices on %s container: %w\n%s", d.Name, err, out)
}
devs, err := netdevs.ParseDevices(out)
if err != nil {
- return "", netdevs.DeviceInfo{}, fmt.Errorf("parsing devices from %s container: %w\n%s", d.Name, err, out)
+ return map[string]netdevs.DeviceInfo{}, fmt.Errorf("parsing devices from %s container: %w\n%s", d.Name, err, out)
+ }
+ return devs, nil
+}
+
+// deviceByIP finds a deviceInfo and device name from an IP address.
+func deviceByIP(ctx context.Context, d *dockerutil.Container, ip net.IP) (string, netdevs.DeviceInfo, error) {
+ devs, err := devicesInfo(ctx, d)
+ if err != nil {
+ return "", netdevs.DeviceInfo{}, err
}
testDevice, deviceInfo, err := netdevs.FindDeviceByIP(ip, devs)
if err != nil {
@@ -405,6 +499,36 @@ func deviceByIP(ctx context.Context, d *dockerutil.Container, ip net.IP) (string
return testDevice, deviceInfo, nil
}
+// getOrAssignIPv6Addr returns the IPv6 address of interface iface in container
+// d. If the interface has no IPv6 address yet, a link-local address derived
+// from its MAC address is added with "ip addr add" and the devices are listed
+// again to confirm that the address is now present.
+//
+// It returns an error if the devices cannot be listed, if iface is missing or
+// has no MAC address, or if the address is still absent after the assignment.
+func getOrAssignIPv6Addr(ctx context.Context, d *dockerutil.Container, iface string) (net.IP, error) {
+	devs, err := devicesInfo(ctx, d)
+	if err != nil {
+		return nil, err
+	}
+	// Distinguish "no such interface" from "interface without a MAC" so that
+	// the error points at the real problem.
+	info, ok := devs[iface]
+	if !ok {
+		return nil, fmt.Errorf("unable to find device %s on container %s", iface, d.Name)
+	}
+	if info.IPv6Addr != nil {
+		return info.IPv6Addr, nil
+	}
+	if info.MAC == nil {
+		return nil, fmt.Errorf("unable to find MAC address of %s", iface)
+	}
+	if logs, err := d.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "add", netdevs.MACToIP(info.MAC).String(), "scope", "link", "dev", iface); err != nil {
+		return nil, fmt.Errorf("unable to ip addr add on container %s: %w, logs: %s", d.Name, err, logs)
+	}
+	// List the devices again to make sure that the assignment worked.
+	devs, err = devicesInfo(ctx, d)
+	if err != nil {
+		return nil, err
+	}
+	// The container may carry several test interfaces, so name the one that
+	// failed.
+	if info = devs[iface]; info.IPv6Addr == nil {
+		return nil, fmt.Errorf("unable to set IPv6 address on %s in container %s", iface, d.Name)
+	}
+	return info.IPv6Addr, nil
+}
+
// createDockerNetwork makes a randomly-named network that will start with the
// namePrefix. The network will be a random /24 subnet.
func createDockerNetwork(ctx context.Context, n *dockerutil.Network) error {
@@ -427,7 +551,7 @@ func StartContainer(ctx context.Context, runOpts dockerutil.RunOpts, c *dockerut
hostconf.AutoRemove = true
hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"}
- if err := c.CreateFrom(ctx, conf, hostconf, nil); err != nil {
+ if err := c.CreateFrom(ctx, runOpts.Image, conf, hostconf, nil); err != nil {
return fmt.Errorf("unable to create container %s: %w", c.Name, err)
}
@@ -440,3 +564,30 @@ func StartContainer(ctx context.Context, runOpts dockerutil.RunOpts, c *dockerut
}
return nil
}
+
+// MountTempDirectory creates a temporary directory on the host, using
+// hostDirTemplate as the name pattern, and appends a read-write bind mount of
+// it at containerDir to runOpts. The host path of the temporary directory is
+// returned.
+//
+// A cleanup is registered on t that copies the directory into
+// TEST_UNDECLARED_OUTPUTS_DIR, when that variable is set, and then removes the
+// directory.
+func MountTempDirectory(t *testing.T, runOpts *dockerutil.RunOpts, hostDirTemplate, containerDir string) (string, error) {
+	t.Helper()
+	tmpDir, err := ioutil.TempDir("", hostDirTemplate)
+	if err != nil {
+		return "", fmt.Errorf("failed to create a temp dir: %w", err)
+	}
+	t.Cleanup(func() {
+		// With no output directory configured there is nowhere to save the
+		// files; running cp with an empty destination would only fail and
+		// mark the test as failed for an unrelated reason.
+		if outDir := os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); outDir != "" {
+			if err := exec.Command("/bin/cp", "-r", tmpDir, outDir).Run(); err != nil {
+				t.Errorf("unable to copy container output files: %s", err)
+			}
+		}
+		if err := os.RemoveAll(tmpDir); err != nil {
+			t.Errorf("failed to remove tmpDir %s: %s", tmpDir, err)
+		}
+	})
+	runOpts.Mounts = append(runOpts.Mounts, mount.Mount{
+		Type:     mount.TypeBind,
+		Source:   tmpDir,
+		Target:   containerDir,
+		ReadOnly: false,
+	})
+	return tmpDir, nil
+}
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD
index 5a0ee1367..983c2c030 100644
--- a/test/packetimpact/testbench/BUILD
+++ b/test/packetimpact/testbench/BUILD
@@ -21,7 +21,6 @@ go_library(
"//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
"//pkg/usermem",
- "//test/packetimpact/netdevs",
"//test/packetimpact/proto:posix_server_go_proto",
"@com_github_google_go_cmp//cmp:go_default_library",
"@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index 266a8601c..576577310 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -306,11 +306,11 @@ func (s *tcpState) incoming(received Layer) Layer {
if s.remoteSeqNum != nil {
newIn.SeqNum = Uint32(uint32(*s.remoteSeqNum))
}
- if s.localSeqNum != nil && (*tcpReceived.Flags&header.TCPFlagAck) != 0 {
+ if seq, flags := s.localSeqNum, tcpReceived.Flags; seq != nil && flags != nil && *flags&header.TCPFlagAck != 0 {
// The caller didn't specify an AckNum so we'll expect the calculated one,
// but only if the ACK flag is set because the AckNum is not valid in a
// header if ACK is not set.
- newIn.AckNum = Uint32(uint32(*s.localSeqNum))
+ newIn.AckNum = Uint32(uint32(*seq))
}
return &newIn
}
@@ -598,14 +598,14 @@ func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Du
var errs error
for {
var gotLayers Layers
- if timeout = time.Until(deadline); timeout > 0 {
+ if timeout := time.Until(deadline); timeout > 0 {
gotLayers = conn.recvFrame(t, timeout)
}
if gotLayers == nil {
if errs == nil {
- return nil, fmt.Errorf("got no frames matching %v during %s", layers, timeout)
+ return nil, fmt.Errorf("got no frames matching %s during %s", layers, timeout)
}
- return nil, fmt.Errorf("got frames %w want %v during %s", errs, layers, timeout)
+ return nil, fmt.Errorf("got frames:\n%w want %s during %s", errs, layers, timeout)
}
if conn.match(layers, gotLayers) {
for i, s := range conn.layerStates {
@@ -615,7 +615,12 @@ func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Du
}
return gotLayers, nil
}
- errs = multierr.Combine(errs, &layersError{got: gotLayers, want: conn.incoming(gotLayers)})
+ want := conn.incoming(layers)
+ if err := want.merge(layers); err != nil {
+ errs = multierr.Combine(errs, err)
+ } else {
+ errs = multierr.Combine(errs, &layersError{got: gotLayers, want: want})
+ }
}
}
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
index 7401a1991..19e6b8d7d 100644
--- a/test/packetimpact/testbench/layers.go
+++ b/test/packetimpact/testbench/layers.go
@@ -298,14 +298,12 @@ func (l *IPv4) ToBytes() ([]byte, error) {
// An IPv4 header is variable length depending on the size of the Options.
hdrLen := header.IPv4MinimumSize
if l.Options != nil {
- hdrLen += l.Options.SizeWithPadding()
+ if len(*l.Options)%4 != 0 {
+ return nil, fmt.Errorf("invalid header options '%x (len=%d)'; must be 32 bit aligned", *l.Options, len(*l.Options))
+ }
+ hdrLen += len(*l.Options)
if hdrLen > header.IPv4MaximumHeaderSize {
- // While ToBytes can be called on packets that were received as well
- // as packets locally generated, it is physically impossible for a
- // received packet to overflow this value so any such failure must
- // be the result of a local programming error and not remotely
- // triggered. A panic is therefore appropriate.
- panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", len(*l.Options), header.IPv4MaximumOptionsSize))
+ return nil, fmt.Errorf("IPv4 Options %d bytes, Max %d", len(*l.Options), header.IPv4MaximumOptionsSize)
}
}
b := make([]byte, hdrLen)
@@ -323,10 +321,6 @@ func (l *IPv4) ToBytes() ([]byte, error) {
DstAddr: tcpip.Address(""),
Options: nil,
}
- // Leave an empty options slice as nil.
- if hdrLen > header.IPv4MinimumSize {
- fields.Options = *l.Options
- }
if l.TOS != nil {
fields.TOS = *l.TOS
}
@@ -373,18 +367,31 @@ func (l *IPv4) ToBytes() ([]byte, error) {
if l.DstAddr != nil {
fields.DstAddr = *l.DstAddr
}
- if l.Checksum != nil {
- fields.Checksum = *l.Checksum
- }
+
h.Encode(fields)
- if l.Checksum == nil {
- h.SetChecksum(^h.CalculateChecksum())
+
+ // Put raw option bytes from test definition in header. Options as raw bytes
+ // allows us to serialize malformed options, which is not possible with
+ // the provided serialization functions.
+ if l.Options != nil {
+ h.SetHeaderLength(h.HeaderLength() + uint8(len(*l.Options)))
+ if got, want := copy(h.Options(), *l.Options), len(*l.Options); got != want {
+ return nil, fmt.Errorf("failed to copy option bytes into header, got %d want %d", got, want)
+ }
}
+
// Encode cannot set this incorrectly so we need to overwrite what it wrote
// in order to test handling of a bad IHL value.
if l.IHL != nil {
h.SetHeaderLength(*l.IHL)
}
+
+ if l.Checksum == nil {
+ h.SetChecksum(^h.CalculateChecksum())
+ } else {
+ h.SetChecksum(*l.Checksum)
+ }
+
return h, nil
}
@@ -498,13 +505,13 @@ func (l *IPv6) ToBytes() ([]byte, error) {
}
}
if l.NextHeader != nil {
- fields.NextHeader = *l.NextHeader
+ fields.TransportProtocol = tcpip.TransportProtocolNumber(*l.NextHeader)
} else {
nh, err := nextHeaderByLayer(l.next())
if err != nil {
return nil, err
}
- fields.NextHeader = nh
+ fields.TransportProtocol = tcpip.TransportProtocolNumber(nh)
}
if l.HopLimit != nil {
fields.HopLimit = *l.HopLimit
@@ -830,7 +837,9 @@ func (l *ICMPv6) ToBytes() ([]byte, error) {
if l.Code != nil {
h.SetCode(*l.Code)
}
- copy(h.NDPPayload(), l.Payload)
+ if n := copy(h.MessageBody(), l.Payload); n != len(l.Payload) {
+ panic(fmt.Sprintf("copied %d bytes, expected to copy %d bytes", n, len(l.Payload)))
+ }
if l.Checksum != nil {
h.SetChecksum(*l.Checksum)
} else {
@@ -876,7 +885,7 @@ func parseICMPv6(b []byte) (Layer, layerParser) {
Type: ICMPv6Type(h.Type()),
Code: ICMPv6Code(h.Code()),
Checksum: Uint16(h.Checksum()),
- Payload: h.NDPPayload(),
+ Payload: h.MessageBody(),
}
return &icmpv6, nil
}
diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go
index 92200add9..891897d55 100644
--- a/test/packetimpact/testbench/testbench.go
+++ b/test/packetimpact/testbench/testbench.go
@@ -17,15 +17,13 @@
package testbench
import (
+ "encoding/json"
"flag"
"fmt"
"math/rand"
"net"
- "os/exec"
"testing"
"time"
-
- "gvisor.dev/gvisor/test/packetimpact/netdevs"
)
var (
@@ -36,25 +34,12 @@ var (
// RPCTimeout is the gRPC timeout.
RPCTimeout = 100 * time.Millisecond
+ // dutTestNetsJSON is the JSON string describing every DUT test network that
+ // is available to the testbench.
+ dutTestNetsJSON string
// dutTestNets is the pool among which the testbench can choose a DUT to work
// with.
dutTestNets chan *DUTTestNet
-
- // TODO(zeling): Remove the following variables once the test runner side is
- // ready.
- localDevice = ""
- remoteDevice = ""
- localIPv4 = ""
- remoteIPv4 = ""
- ipv4PrefixLength = 0
- localIPv6 = ""
- remoteIPv6 = ""
- localInterfaceID uint32
- remoteInterfaceID uint64
- localMAC = ""
- remoteMAC = ""
- posixServerIP = ""
- posixServerPort = 40000
)
// DUTTestNet describes the test network setup on dut and how the testbench
@@ -98,19 +83,10 @@ type DUTTestNet struct {
// exported variables above. It should be called by tests in their init
// functions.
func registerFlags(fs *flag.FlagSet) {
- fs.StringVar(&posixServerIP, "posix_server_ip", posixServerIP, "ip address to listen to for UDP commands")
- fs.IntVar(&posixServerPort, "posix_server_port", posixServerPort, "port to listen to for UDP commands")
- fs.StringVar(&localIPv4, "local_ipv4", localIPv4, "local IPv4 address for test packets")
- fs.StringVar(&remoteIPv4, "remote_ipv4", remoteIPv4, "remote IPv4 address for test packets")
- fs.StringVar(&remoteIPv6, "remote_ipv6", remoteIPv6, "remote IPv6 address for test packets")
- fs.StringVar(&remoteMAC, "remote_mac", remoteMAC, "remote mac address for test packets")
- fs.StringVar(&localDevice, "local_device", localDevice, "local device to inject traffic")
- fs.StringVar(&remoteDevice, "remote_device", remoteDevice, "remote device on the DUT")
- fs.Uint64Var(&remoteInterfaceID, "remote_interface_id", remoteInterfaceID, "remote interface ID for test packets")
-
fs.BoolVar(&Native, "native", Native, "whether the test is running natively")
fs.DurationVar(&RPCTimeout, "rpc_timeout", RPCTimeout, "gRPC timeout")
fs.DurationVar(&RPCKeepalive, "rpc_keepalive", RPCKeepalive, "gRPC keepalive")
+ // The flag value is the JSON document itself (it is unmarshalled directly),
+ // not the path of a file containing it.
+ fs.StringVar(&dutTestNetsJSON, "dut_test_nets_json", dutTestNetsJSON, "JSON string describing the DUT test networks available to the testbench")
}
// Initialize initializes the testbench, it parse the flags and sets up the
@@ -118,61 +94,27 @@ func registerFlags(fs *flag.FlagSet) {
func Initialize(fs *flag.FlagSet) {
registerFlags(fs)
flag.Parse()
- if err := genPseudoFlags(); err != nil {
- panic(err)
- }
- var dut DUTTestNet
- var err error
- dut.LocalMAC, err = net.ParseMAC(localMAC)
- if err != nil {
- panic(err)
- }
- dut.RemoteMAC, err = net.ParseMAC(remoteMAC)
- if err != nil {
+ if err := loadDUTTestNets(); err != nil {
panic(err)
}
- dut.LocalIPv4 = net.ParseIP(localIPv4).To4()
- dut.LocalIPv6 = net.ParseIP(localIPv6).To16()
- dut.RemoteIPv4 = net.ParseIP(remoteIPv4).To4()
- dut.RemoteIPv6 = net.ParseIP(remoteIPv6).To16()
- dut.LocalDevID = uint32(localInterfaceID)
- dut.RemoteDevID = uint32(remoteInterfaceID)
- dut.LocalDevName = localDevice
- dut.RemoteDevName = remoteDevice
- dut.POSIXServerIP = net.ParseIP(posixServerIP)
- dut.POSIXServerPort = uint16(posixServerPort)
- dut.IPv4PrefixLength = ipv4PrefixLength
-
- dutTestNets = make(chan *DUTTestNet, 1)
- dutTestNets <- &dut
}
-// genPseudoFlags populates flag-like global config based on real flags.
-//
-// genPseudoFlags must only be called after flag.Parse.
-func genPseudoFlags() error {
- out, err := exec.Command("ip", "addr", "show").CombinedOutput()
- if err != nil {
- return fmt.Errorf("listing devices: %q: %w", string(out), err)
- }
- devs, err := netdevs.ParseDevices(string(out))
- if err != nil {
- return fmt.Errorf("parsing devices: %w", err)
+// loadDUTTestNets loads the available DUT test networks from the JSON string
+// in dutTestNetsJSON. It must be called after flag.Parse().
+func loadDUTTestNets() error {
+ var parsedTestNets []DUTTestNet
+ if err := json.Unmarshal([]byte(dutTestNetsJSON), &parsedTestNets); err != nil {
+ return fmt.Errorf("failed to unmarshal JSON: %w", err)
}
-
- _, deviceInfo, err := netdevs.FindDeviceByIP(net.ParseIP(localIPv4), devs)
- if err != nil {
- return fmt.Errorf("can't find deviceInfo: %w", err)
+ if got, want := len(parsedTestNets), 1; got < want {
+ return fmt.Errorf("got %d DUTs, the test requires at least %d DUTs", got, want)
}
-
- localMAC = deviceInfo.MAC.String()
- localIPv6 = deviceInfo.IPv6Addr.String()
- localInterfaceID = deviceInfo.ID
-
- if deviceInfo.IPv4Net != nil {
- ipv4PrefixLength, _ = deviceInfo.IPv4Net.Mask.Size()
- } else {
- ipv4PrefixLength, _ = net.ParseIP(localIPv4).DefaultMask().Size()
+ // Using a buffered channel as semaphore
+ dutTestNets = make(chan *DUTTestNet, len(parsedTestNets))
+ for i := range parsedTestNets {
+ parsedTestNets[i].LocalIPv4 = parsedTestNets[i].LocalIPv4.To4()
+ parsedTestNets[i].RemoteIPv4 = parsedTestNets[i].RemoteIPv4.To4()
+ dutTestNets <- &parsedTestNets[i]
}
return nil
}
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
index 33bd070c1..b1b3c578b 100644
--- a/test/packetimpact/tests/BUILD
+++ b/test/packetimpact/tests/BUILD
@@ -366,9 +366,29 @@ packetimpact_testbench(
],
)
+packetimpact_testbench(
+ name = "tcp_zero_receive_window",
+ srcs = ["tcp_zero_receive_window_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
validate_all_tests()
[packetimpact_go_test(
name = t.name,
expect_netstack_failure = hasattr(t, "expect_netstack_failure"),
+ num_duts = t.num_duts if hasattr(t, "num_duts") else 1,
) for t in ALL_TESTS]
+
+test_suite(
+ name = "all_tests",
+ tags = [
+ "manual",
+ "packetimpact",
+ ],
+ tests = existing_rules(),
+)
diff --git a/test/packetimpact/tests/ipv4_fragment_reassembly_test.go b/test/packetimpact/tests/ipv4_fragment_reassembly_test.go
index e00a7aba2..d2203082d 100644
--- a/test/packetimpact/tests/ipv4_fragment_reassembly_test.go
+++ b/test/packetimpact/tests/ipv4_fragment_reassembly_test.go
@@ -34,10 +34,10 @@ type fragmentInfo struct {
offset uint16
size uint16
more uint8
+ id uint16
}
func TestIPv4FragmentReassembly(t *testing.T) {
- const fragmentID = 42
icmpv4ProtoNum := uint8(header.ICMPv4ProtocolNumber)
tests := []struct {
@@ -45,28 +45,75 @@ func TestIPv4FragmentReassembly(t *testing.T) {
ipPayloadLen int
fragments []fragmentInfo
expectReply bool
+ skip bool
+ skipReason string
}{
{
description: "basic reassembly",
- ipPayloadLen: 2000,
+ ipPayloadLen: 3000,
fragments: []fragmentInfo{
- {offset: 0, size: 1000, more: header.IPv4FlagMoreFragments},
- {offset: 1000, size: 1000, more: 0},
+ {offset: 0, size: 1000, id: 5, more: header.IPv4FlagMoreFragments},
+ {offset: 1000, size: 1000, id: 5, more: header.IPv4FlagMoreFragments},
+ {offset: 2000, size: 1000, id: 5, more: 0},
},
expectReply: true,
},
{
description: "out of order fragments",
- ipPayloadLen: 2000,
+ ipPayloadLen: 3000,
fragments: []fragmentInfo{
- {offset: 1000, size: 1000, more: 0},
- {offset: 0, size: 1000, more: header.IPv4FlagMoreFragments},
+ {offset: 2000, size: 1000, id: 6, more: 0},
+ {offset: 0, size: 1000, id: 6, more: header.IPv4FlagMoreFragments},
+ {offset: 1000, size: 1000, id: 6, more: header.IPv4FlagMoreFragments},
},
expectReply: true,
},
+ {
+ description: "duplicated fragments",
+ ipPayloadLen: 3000,
+ fragments: []fragmentInfo{
+ {offset: 0, size: 1000, id: 7, more: header.IPv4FlagMoreFragments},
+ {offset: 1000, size: 1000, id: 7, more: header.IPv4FlagMoreFragments},
+ {offset: 1000, size: 1000, id: 7, more: header.IPv4FlagMoreFragments},
+ {offset: 2000, size: 1000, id: 7, more: 0},
+ },
+ expectReply: true,
+ skip: true,
+ skipReason: "gvisor.dev/issues/4971",
+ },
+ {
+ description: "fragment subset",
+ ipPayloadLen: 3000,
+ fragments: []fragmentInfo{
+ {offset: 0, size: 1000, id: 8, more: header.IPv4FlagMoreFragments},
+ {offset: 1000, size: 1000, id: 8, more: header.IPv4FlagMoreFragments},
+ {offset: 512, size: 256, id: 8, more: header.IPv4FlagMoreFragments},
+ {offset: 2000, size: 1000, id: 8, more: 0},
+ },
+ expectReply: true,
+ skip: true,
+ skipReason: "gvisor.dev/issues/4971",
+ },
+ {
+ description: "fragment overlap",
+ ipPayloadLen: 3000,
+ fragments: []fragmentInfo{
+ {offset: 0, size: 1000, id: 9, more: header.IPv4FlagMoreFragments},
+ {offset: 1512, size: 1000, id: 9, more: header.IPv4FlagMoreFragments},
+ {offset: 1000, size: 1000, id: 9, more: header.IPv4FlagMoreFragments},
+ {offset: 2000, size: 1000, id: 9, more: 0},
+ },
+ expectReply: false,
+ skip: true,
+ skipReason: "gvisor.dev/issues/4971",
+ },
}
for _, test := range tests {
+ if test.skip { // Known-failing case: report it and move on to the next table entry.
+ t.Logf("%s test skipped: %s", test.description, test.skipReason)
+ continue // t.Skip here would not format its arguments and would stop the parent test.
+ }
t.Run(test.description, func(t *testing.T) {
dut := testbench.NewDUT(t)
conn := dut.Net.NewIPv4Conn(t, testbench.IPv4{}, testbench.IPv4{})
@@ -95,7 +142,7 @@ func TestIPv4FragmentReassembly(t *testing.T) {
Protocol: &icmpv4ProtoNum,
FragmentOffset: testbench.Uint16(fragment.offset),
Flags: testbench.Uint8(fragment.more),
- ID: testbench.Uint16(fragmentID),
+ ID: testbench.Uint16(fragment.id),
},
&testbench.Payload{
Bytes: data[fragment.offset:][:fragment.size],
@@ -114,7 +161,7 @@ func TestIPv4FragmentReassembly(t *testing.T) {
}, time.Second)
if err != nil {
// Either an unexpected frame was received, or none at all.
- if bytesReceived < test.ipPayloadLen {
+ if test.expectReply && bytesReceived < test.ipPayloadLen {
t.Fatalf("received %d bytes out of %d, then conn.ExpectFrame(_, _, time.Second) failed with %s", bytesReceived, test.ipPayloadLen, err)
}
break
diff --git a/test/packetimpact/tests/ipv6_fragment_icmp_error_test.go b/test/packetimpact/tests/ipv6_fragment_icmp_error_test.go
index 0ddc1526f..a37867e85 100644
--- a/test/packetimpact/tests/ipv6_fragment_icmp_error_test.go
+++ b/test/packetimpact/tests/ipv6_fragment_icmp_error_test.go
@@ -119,6 +119,7 @@ func TestIPv6ICMPEchoRequestFragmentReassembly(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
dut := testbench.NewDUT(t)
ipv6Conn := dut.Net.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
conn := (*testbench.Connection)(&ipv6Conn)
@@ -220,6 +221,7 @@ func TestIPv6FragmentReassemblyTimeout(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
dut := testbench.NewDUT(t)
ipv6Conn := dut.Net.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
conn := (*testbench.Connection)(&ipv6Conn)
@@ -315,6 +317,7 @@ func TestIPv6FragmentParamProblem(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
dut := testbench.NewDUT(t)
ipv6Conn := dut.Net.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
conn := (*testbench.Connection)(&ipv6Conn)
diff --git a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go
index 65f742f55..dd98ee7a1 100644
--- a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go
+++ b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go
@@ -35,10 +35,10 @@ type fragmentInfo struct {
offset uint16
size uint16
more bool
+ id uint32
}
func TestIPv6FragmentReassembly(t *testing.T) {
- const fragmentID = 42
icmpv6ProtoNum := header.IPv6ExtensionHeaderIdentifier(header.ICMPv6ProtocolNumber)
tests := []struct {
@@ -49,10 +49,11 @@ func TestIPv6FragmentReassembly(t *testing.T) {
}{
{
description: "basic reassembly",
- ipPayloadLen: 1500,
+ ipPayloadLen: 3000,
fragments: []fragmentInfo{
- {offset: 0, size: 760, more: true},
- {offset: 760, size: 740, more: false},
+ {offset: 0, size: 1000, id: 100, more: true},
+ {offset: 1000, size: 1000, id: 100, more: true},
+ {offset: 2000, size: 1000, id: 100, more: false},
},
expectReply: true,
},
@@ -60,12 +61,45 @@ func TestIPv6FragmentReassembly(t *testing.T) {
description: "out of order fragments",
ipPayloadLen: 3000,
fragments: []fragmentInfo{
- {offset: 0, size: 1024, more: true},
- {offset: 2048, size: 952, more: false},
- {offset: 1024, size: 1024, more: true},
+ {offset: 0, size: 1000, id: 101, more: true},
+ {offset: 2000, size: 1000, id: 101, more: false},
+ {offset: 1000, size: 1000, id: 101, more: true},
+ },
+ expectReply: true,
+ },
+ {
+ description: "duplicated fragments",
+ ipPayloadLen: 3000,
+ fragments: []fragmentInfo{
+ {offset: 0, size: 1000, id: 102, more: true},
+ {offset: 1000, size: 1000, id: 102, more: true},
+ {offset: 1000, size: 1000, id: 102, more: true},
+ {offset: 2000, size: 1000, id: 102, more: false},
+ },
+ expectReply: true,
+ },
+ {
+ description: "fragment subset",
+ ipPayloadLen: 3000,
+ fragments: []fragmentInfo{
+ {offset: 0, size: 1000, id: 103, more: true},
+ {offset: 1000, size: 1000, id: 103, more: true},
+ {offset: 512, size: 256, id: 103, more: true},
+ {offset: 2000, size: 1000, id: 103, more: false},
},
expectReply: true,
},
+ {
+ description: "fragment overlap",
+ ipPayloadLen: 3000,
+ fragments: []fragmentInfo{
+ {offset: 0, size: 1000, id: 104, more: true},
+ {offset: 1512, size: 1000, id: 104, more: true},
+ {offset: 1000, size: 1000, id: 104, more: true},
+ {offset: 2000, size: 1000, id: 104, more: false},
+ },
+ expectReply: false,
+ },
}
for _, test := range tests {
@@ -101,7 +135,7 @@ func TestIPv6FragmentReassembly(t *testing.T) {
NextHeader: &icmpv6ProtoNum,
FragmentOffset: testbench.Uint16(fragment.offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit),
MoreFragments: testbench.Bool(fragment.more),
- Identification: testbench.Uint32(fragmentID),
+ Identification: testbench.Uint32(fragment.id),
},
&testbench.Payload{
Bytes: data[fragment.offset:][:fragment.size],
@@ -118,7 +152,7 @@ func TestIPv6FragmentReassembly(t *testing.T) {
}, time.Second)
if err != nil {
// Either an unexpected frame was received, or none at all.
- if bytesReceived < test.ipPayloadLen {
+ if test.expectReply && bytesReceived < test.ipPayloadLen {
t.Fatalf("received %d bytes out of %d, then conn.ExpectFrame(_, _, time.Second) failed with %s", bytesReceived, test.ipPayloadLen, err)
}
break
diff --git a/test/packetimpact/tests/tcp_zero_receive_window_test.go b/test/packetimpact/tests/tcp_zero_receive_window_test.go
new file mode 100644
index 000000000..d06690705
--- /dev/null
+++ b/test/packetimpact/tests/tcp_zero_receive_window_test.go
@@ -0,0 +1,125 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_zero_receive_window_test
+
+import (
+ "flag"
+ "fmt"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.Initialize(flag.CommandLine)
+}
+
+// TestZeroReceiveWindow checks that the DUT eventually advertises a zero receive window when the accepted socket is read from only once.
+func TestZeroReceiveWindow(t *testing.T) {
+ for _, payloadLen := range []int{64, 512, 1024} {
+ t.Run(fmt.Sprintf("TestZeroReceiveWindow_with_%dbytes_payload", payloadLen), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ conn.Connect(t)
+ acceptFd, _ := dut.Accept(t, listenFd)
+ defer dut.Close(t, acceptFd)
+
+ dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ samplePayload := &testbench.Payload{Bytes: testbench.GenerateRandomPayload(t, payloadLen)}
+ // Keep sending until the DUT advertises a zero receive window. The loop
+ // has no other non-fatal exit, so the test would time out otherwise.
+ for readOnce := false; ; {
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("expected packet was not received: %s", err)
+ }
+ // Read once to trigger the subsequent window update from the
+ // DUT to grow the right edge of the receive window from what
+ // was advertised in the SYN-ACK. This ensures that we test
+ // for the full default buffer size (1MB on gVisor at the time
+ // of writing this comment), thus testing for cases when the
+ // scaled receive window size ends up > 65535 (0xffff).
+ if !readOnce {
+ if got := dut.Recv(t, acceptFd, int32(payloadLen), 0); len(got) != payloadLen {
+ t.Fatalf("got dut.Recv(t, %d, %d, 0) = %d, want %d", acceptFd, payloadLen, len(got), payloadLen)
+ }
+ readOnce = true
+ }
+ windowSize := *gotTCP.WindowSize
+ t.Logf("got window size = %d", windowSize)
+ if windowSize == 0 {
+ break
+ }
+ }
+ })
+ }
+}
+
+// TestNonZeroReceiveWindow checks that the DUT never advertises a zero receive
+// window while every payload it receives is read back out of the socket buffer.
+func TestNonZeroReceiveWindow(t *testing.T) {
+ for _, payloadLen := range []int{64, 512, 1024} {
+ t.Run(fmt.Sprintf("TestNonZeroReceiveWindow_with_%dbytes_payload", payloadLen), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ conn.Connect(t)
+ acceptFd, _ := dut.Accept(t, listenFd)
+ defer dut.Close(t, acceptFd)
+
+ dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ samplePayload := &testbench.Payload{Bytes: testbench.GenerateRandomPayload(t, payloadLen)}
+ var rcvWindow uint16
+ initRcv := false
+ // rcvWindow starts at the window value carried by the first ACK and is
+ // reduced by payloadLen per iteration; the loop ends once that budget is
+ // used up, with every ACK along the way required to be non-zero.
+ for {
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("expected packet was not received: %s", err)
+ }
+ if got := dut.Recv(t, acceptFd, int32(payloadLen), 0); len(got) != payloadLen {
+ t.Fatalf("got dut.Recv(t, %d, %d, 0) = %d, want %d", acceptFd, payloadLen, len(got), payloadLen)
+ }
+ if *gotTCP.WindowSize == 0 {
+ t.Fatalf("expected non-zero receive window.")
+ }
+ if !initRcv {
+ rcvWindow = uint16(*gotTCP.WindowSize)
+ initRcv = true
+ }
+ if rcvWindow <= uint16(payloadLen) {
+ break
+ }
+ rcvWindow -= uint16(payloadLen)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/udp_recv_mcast_bcast_test.go b/test/packetimpact/tests/udp_recv_mcast_bcast_test.go
index 71cde6cde..b29c07825 100644
--- a/test/packetimpact/tests/udp_recv_mcast_bcast_test.go
+++ b/test/packetimpact/tests/udp_recv_mcast_bcast_test.go
@@ -44,7 +44,7 @@ func TestUDPRecvMcastBcast(t *testing.T) {
{bound: subnetBcastAddr, to: subnetBcastAddr},
- // FIXME(gvisor.dev/issues/4896): Previously by the time subnetBcastAddr is
+ // FIXME(gvisor.dev/issue/4896): Previously by the time subnetBcastAddr is
// created, IPv4PrefixLength is still 0 because genPseudoFlags is not called
// yet, it was only called in NewDUT, so the test didn't do what the author
// original intended to and becomes failing because we process all flags at
diff --git a/test/perf/BUILD b/test/perf/BUILD
index b763be50e..e25f090ae 100644
--- a/test/perf/BUILD
+++ b/test/perf/BUILD
@@ -1,3 +1,4 @@
+load("//tools:defs.bzl", "more_shards")
load("//test/runner:defs.bzl", "syscall_test")
package(licenses = ["notice"])
@@ -37,7 +38,7 @@ syscall_test(
syscall_test(
size = "enormous",
debug = False,
- shard_count = 10,
+ shard_count = more_shards,
tags = ["nogotsan"],
test = "//test/perf/linux:getdents_benchmark",
)
diff --git a/test/root/BUILD b/test/root/BUILD
index a9130b34f..8d9fff578 100644
--- a/test/root/BUILD
+++ b/test/root/BUILD
@@ -1,5 +1,4 @@
load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/vm:defs.bzl", "vm_test")
package(licenses = ["notice"])
@@ -24,12 +23,8 @@ go_test(
],
library = ":root",
tags = [
- # Requires docker and runsc to be configured before the test runs.
- # Also, the test needs to be run as root. Note that below, the
- # root_vm_test relies on the default runtime 'runsc' being installed by
- # the default installer.
- "manual",
"local",
+ "manual",
],
visibility = ["//:sandbox"],
deps = [
@@ -46,10 +41,3 @@ go_test(
"@org_golang_x_sys//unix:go_default_library",
],
)
-
-vm_test(
- name = "root_vm_test",
- size = "large",
- shard_count = 1,
- targets = [":root_test"],
-)
diff --git a/test/root/cgroup_test.go b/test/root/cgroup_test.go
index a26b83081..a74d6b1c1 100644
--- a/test/root/cgroup_test.go
+++ b/test/root/cgroup_test.go
@@ -249,12 +249,11 @@ func TestCgroup(t *testing.T) {
case "pids-limit":
val := attr.value
hostconf.Resources.PidsLimit = &val
-
}
}
// Create container.
- if err := d.CreateFrom(ctx, conf, hostconf, nil); err != nil {
+ if err := d.CreateFrom(ctx, "basic/alpine", conf, hostconf, nil); err != nil {
t.Fatalf("create failed with: %v", err)
}
@@ -323,7 +322,7 @@ func TestCgroupParent(t *testing.T) {
}, "sleep", "10000")
hostconf.Resources.CgroupParent = parent
- if err := d.CreateFrom(ctx, conf, hostconf, nil); err != nil {
+ if err := d.CreateFrom(ctx, "basic/alpine", conf, hostconf, nil); err != nil {
t.Fatalf("create failed with: %v", err)
}
diff --git a/test/root/crictl_test.go b/test/root/crictl_test.go
index 735dff107..df52dd381 100644
--- a/test/root/crictl_test.go
+++ b/test/root/crictl_test.go
@@ -315,7 +315,7 @@ const (
// v1 is the containerd API v1.
v1 string = "v1"
- // v1 is the containerd API v21.
+ // v2 is the containerd API v2.
v2 string = "v2"
)
diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl
index 7618f6a21..829247657 100644
--- a/test/runner/defs.bzl
+++ b/test/runner/defs.bzl
@@ -12,7 +12,7 @@ def _runner_test_impl(ctx):
" mkdir -p \"${TEST_UNDECLARED_OUTPUTS_DIR}\"",
" chmod a+rwx \"${TEST_UNDECLARED_OUTPUTS_DIR}\"",
"fi",
- "exec %s %s %s\n" % (
+ "exec %s %s \"$@\" %s\n" % (
ctx.files.runner[0].short_path,
" ".join(ctx.attr.runner_args),
ctx.files.test[0].short_path,
@@ -52,8 +52,6 @@ _runner_test = rule(
def _syscall_test(
test,
- shard_count,
- size,
platform,
use_tmpfs,
tags,
@@ -63,7 +61,8 @@ def _syscall_test(
overlay = False,
add_uds_tree = False,
vfs2 = False,
- fuse = False):
+ fuse = False,
+ **kwargs):
# Prepend "runsc" to non-native platform names.
full_platform = platform if platform == "native" else "runsc_" + platform
@@ -126,15 +125,12 @@ def _syscall_test(
name = name,
test = test,
runner_args = runner_args,
- size = size,
tags = tags,
- shard_count = shard_count,
+ **kwargs
)
def syscall_test(
test,
- shard_count = 5,
- size = "small",
use_tmpfs = False,
add_overlay = False,
add_uds_tree = False,
@@ -142,18 +138,21 @@ def syscall_test(
vfs2 = True,
fuse = False,
debug = True,
- tags = None):
+ tags = None,
+ **kwargs):
"""syscall_test is a macro that will create targets for all platforms.
Args:
test: the test target.
- shard_count: shards for defined tests.
- size: the defined test size.
use_tmpfs: use tmpfs in the defined tests.
add_overlay: add an overlay test.
add_uds_tree: add a UDS test.
add_hostinet: add a hostinet test.
+ vfs2: enable VFS2 support.
+ fuse: enable FUSE support.
+ debug: enable debug output.
tags: starting test tags.
+ **kwargs: additional test arguments.
"""
if not tags:
tags = []
@@ -173,8 +172,6 @@ def syscall_test(
_syscall_test(
test = test,
- shard_count = shard_count,
- size = size,
platform = default_platform,
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
@@ -182,6 +179,7 @@ def syscall_test(
debug = debug,
vfs2 = True,
fuse = fuse,
+ **kwargs
)
if fuse:
# Only generate *_vfs2_fuse target if fuse parameter is enabled.
@@ -189,38 +187,35 @@ def syscall_test(
_syscall_test(
test = test,
- shard_count = shard_count,
- size = size,
platform = "native",
use_tmpfs = False,
add_uds_tree = add_uds_tree,
tags = list(tags),
debug = debug,
+ **kwargs
)
for (platform, platform_tags) in platforms.items():
_syscall_test(
test = test,
- shard_count = shard_count,
- size = size,
platform = platform,
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
tags = platform_tags + tags,
debug = debug,
+ **kwargs
)
if add_overlay:
_syscall_test(
test = test,
- shard_count = shard_count,
- size = size,
platform = default_platform,
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
tags = platforms[default_platform] + tags,
debug = debug,
overlay = True,
+ **kwargs
)
# TODO(gvisor.dev/issue/4407): Remove tags to enable VFS2 overlay tests.
@@ -230,8 +225,6 @@ def syscall_test(
overlay_vfs2_tags.append("notap")
_syscall_test(
test = test,
- shard_count = shard_count,
- size = size,
platform = default_platform,
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
@@ -239,38 +232,35 @@ def syscall_test(
debug = debug,
overlay = True,
vfs2 = True,
+ **kwargs
)
if add_hostinet:
_syscall_test(
test = test,
- shard_count = shard_count,
- size = size,
platform = default_platform,
use_tmpfs = use_tmpfs,
network = "host",
add_uds_tree = add_uds_tree,
tags = platforms[default_platform] + tags,
debug = debug,
+ **kwargs
)
if not use_tmpfs:
# Also test shared gofer access.
_syscall_test(
test = test,
- shard_count = shard_count,
- size = size,
platform = default_platform,
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
tags = platforms[default_platform] + tags,
debug = debug,
file_access = "shared",
+ **kwargs
)
_syscall_test(
test = test,
- shard_count = shard_count,
- size = size,
platform = default_platform,
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
@@ -278,4 +268,5 @@ def syscall_test(
debug = debug,
file_access = "shared",
vfs2 = True,
+ **kwargs
)
diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD
index 22b526f59..510ffe013 100644
--- a/test/runtimes/BUILD
+++ b/test/runtimes/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "bzl_library")
+load("//tools:defs.bzl", "bzl_library", "more_shards", "most_shards")
load("//test/runtimes:defs.bzl", "runtime_test")
package(licenses = ["notice"])
@@ -7,7 +7,7 @@ runtime_test(
name = "go1.12",
exclude_file = "exclude/go1.12.csv",
lang = "go",
- shard_count = 8,
+ shard_count = more_shards,
)
runtime_test(
@@ -15,28 +15,28 @@ runtime_test(
batch = 100,
exclude_file = "exclude/java11.csv",
lang = "java",
- shard_count = 16,
+ shard_count = most_shards,
)
runtime_test(
name = "nodejs12.4.0",
exclude_file = "exclude/nodejs12.4.0.csv",
lang = "nodejs",
- shard_count = 8,
+ shard_count = most_shards,
)
runtime_test(
name = "php7.3.6",
exclude_file = "exclude/php7.3.6.csv",
lang = "php",
- shard_count = 8,
+ shard_count = more_shards,
)
runtime_test(
name = "python3.7.3",
exclude_file = "exclude/python3.7.3.csv",
lang = "python",
- shard_count = 8,
+ shard_count = more_shards,
)
bzl_library(
diff --git a/test/runtimes/runner/lib/lib.go b/test/runtimes/runner/lib/lib.go
index 64e6e14db..f2db5f9ea 100644
--- a/test/runtimes/runner/lib/lib.go
+++ b/test/runtimes/runner/lib/lib.go
@@ -34,12 +34,7 @@ import (
// RunTests is a helper that is called by main. It exists so that we can run
// defered functions before exiting. It returns an exit code that should be
// passed to os.Exit.
-func RunTests(lang, image, excludeFile string, partitionNum, totalPartitions, batchSize int, timeout time.Duration) int {
- if partitionNum <= 0 || totalPartitions <= 0 || partitionNum > totalPartitions {
- fmt.Fprintf(os.Stderr, "invalid partition %d of %d", partitionNum, totalPartitions)
- return 1
- }
-
+func RunTests(lang, image, excludeFile string, batchSize int, timeout time.Duration) int {
// TODO(gvisor.dev/issue/1624): Remove those tests from all exclude lists
// that only fail with VFS1.
@@ -63,7 +58,7 @@ func RunTests(lang, image, excludeFile string, partitionNum, totalPartitions, ba
// Get a slice of tests to run. This will also start a single Docker
// container that will be used to run each test. The final test will
// stop the Docker container.
- tests, err := getTests(ctx, d, lang, image, partitionNum, totalPartitions, batchSize, timeout, excludes)
+ tests, err := getTests(ctx, d, lang, image, batchSize, timeout, excludes)
if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err.Error())
return 1
@@ -74,7 +69,7 @@ func RunTests(lang, image, excludeFile string, partitionNum, totalPartitions, ba
}
// getTests executes all tests as table tests.
-func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, partitionNum, totalPartitions, batchSize int, timeout time.Duration, excludes map[string]struct{}) ([]testing.InternalTest, error) {
+func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, batchSize int, timeout time.Duration, excludes map[string]struct{}) ([]testing.InternalTest, error) {
// Start the container.
opts := dockerutil.RunOpts{
Image: fmt.Sprintf("runtimes/%s", image),
@@ -90,18 +85,9 @@ func getTests(ctx context.Context, d *dockerutil.Container, lang, image string,
return nil, fmt.Errorf("docker exec failed: %v", err)
}
- // Calculate a subset of tests to run corresponding to the current
- // shard.
+ // Calculate a subset of tests.
tests := strings.Fields(list)
sort.Strings(tests)
-
- partitionSize := len(tests) / totalPartitions
- if partitionNum == totalPartitions {
- tests = tests[(partitionNum-1)*partitionSize:]
- } else {
- tests = tests[(partitionNum-1)*partitionSize : partitionNum*partitionSize]
- }
-
indices, err := testutil.TestIndicesForShard(len(tests))
if err != nil {
return nil, fmt.Errorf("TestsForShard() failed: %v", err)
@@ -122,6 +108,10 @@ func getTests(ctx context.Context, d *dockerutil.Container, lang, image string,
}
tcs = append(tcs, tests[tc])
}
+ if len(tcs) == 0 {
+ // No tests to add to this batch.
+ continue
+ }
itests = append(itests, testing.InternalTest{
Name: strings.Join(tcs, ", "),
F: func(t *testing.T) {
@@ -206,3 +196,4 @@ func (f testDeps) WriteProfileTo(string, io.Writer, int) error { return nil }
func (f testDeps) ImportPath() string { return "" }
func (f testDeps) StartTestLog(io.Writer) {}
func (f testDeps) StopTestLog() error { return nil }
+func (f testDeps) SetPanicOnExit0(bool) {}
diff --git a/test/runtimes/runner/main.go b/test/runtimes/runner/main.go
index 5b3443e36..ec79a22c2 100644
--- a/test/runtimes/runner/main.go
+++ b/test/runtimes/runner/main.go
@@ -25,13 +25,11 @@ import (
)
var (
- lang = flag.String("lang", "", "language runtime to test")
- image = flag.String("image", "", "docker image with runtime tests")
- excludeFile = flag.String("exclude_file", "", "file containing list of tests to exclude, in CSV format with fields: test name, bug id, comment")
- partition = flag.Int("partition", 1, "partition number, this is 1-indexed")
- totalPartitions = flag.Int("total_partitions", 1, "total number of partitions")
- batchSize = flag.Int("batch", 50, "number of test cases run in one command")
- timeout = flag.Duration("timeout", 90*time.Minute, "batch timeout")
+ lang = flag.String("lang", "", "language runtime to test")
+ image = flag.String("image", "", "docker image with runtime tests")
+ excludeFile = flag.String("exclude_file", "", "file containing list of tests to exclude, in CSV format with fields: test name, bug id, comment")
+ batchSize = flag.Int("batch", 50, "number of test cases run in one command")
+ timeout = flag.Duration("timeout", 90*time.Minute, "batch timeout")
)
func main() {
@@ -40,5 +38,5 @@ func main() {
fmt.Fprintf(os.Stderr, "lang and image flags must not be empty\n")
os.Exit(1)
}
- os.Exit(lib.RunTests(*lang, *image, *excludeFile, *partition, *totalPartitions, *batchSize, *timeout))
+ os.Exit(lib.RunTests(*lang, *image, *excludeFile, *batchSize, *timeout))
}
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index b5a4ef4df..0da35f7be 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -1,3 +1,4 @@
+load("//tools:defs.bzl", "more_shards", "most_shards")
load("//test/runner:defs.bzl", "syscall_test")
package(licenses = ["notice"])
@@ -12,7 +13,7 @@ syscall_test(
syscall_test(
size = "large",
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:accept_bind_test",
)
@@ -32,7 +33,7 @@ syscall_test(
syscall_test(
size = "medium",
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:alarm_test",
)
@@ -66,7 +67,7 @@ syscall_test(
size = "large",
# Produce too many logs in the debug mode.
debug = False,
- shard_count = 50,
+ shard_count = most_shards,
# Takes too long for TSAN. Since this is kind of a stress test that doesn't
# involve much concurrency, TSAN's usefulness here is limited anyway.
tags = ["nogotsan"],
@@ -211,7 +212,7 @@ syscall_test(
syscall_test(
size = "medium",
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:futex_test",
)
@@ -258,7 +259,7 @@ syscall_test(
syscall_test(
size = "large",
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:itimer_test",
)
@@ -313,7 +314,7 @@ syscall_test(
syscall_test(
size = "medium",
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:mmap_test",
)
@@ -347,6 +348,7 @@ syscall_test(
syscall_test(
add_overlay = True,
+ shard_count = more_shards,
test = "//test/syscalls/linux:open_test",
)
@@ -376,7 +378,7 @@ syscall_test(
syscall_test(
size = "large",
add_overlay = True,
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:pipe_test",
)
@@ -448,7 +450,7 @@ syscall_test(
syscall_test(
size = "medium",
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:pty_test",
)
@@ -475,6 +477,7 @@ syscall_test(
)
syscall_test(
+ shard_count = more_shards,
test = "//test/syscalls/linux:raw_socket_test",
)
@@ -490,7 +493,7 @@ syscall_test(
syscall_test(
size = "medium",
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:readv_socket_test",
)
@@ -539,7 +542,7 @@ syscall_test(
)
syscall_test(
- shard_count = 20,
+ shard_count = more_shards,
test = "//test/syscalls/linux:semaphore_test",
)
@@ -594,7 +597,7 @@ syscall_test(
syscall_test(
size = "large",
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_abstract_test",
)
@@ -605,7 +608,7 @@ syscall_test(
syscall_test(
size = "large",
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_domain_test",
)
@@ -618,55 +621,62 @@ syscall_test(
syscall_test(
size = "large",
add_overlay = True,
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_filesystem_test",
)
syscall_test(
size = "large",
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_inet_loopback_test",
)
syscall_test(
size = "large",
- shard_count = 50,
+ shard_count = most_shards,
# Takes too long for TSAN. Creates a lot of TCP sockets.
tags = ["nogotsan"],
test = "//test/syscalls/linux:socket_inet_loopback_nogotsan_test",
)
syscall_test(
+ test = "//test/syscalls/linux:socket_ipv4_udp_unbound_external_networking_test",
+)
+
+syscall_test(
size = "large",
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_ip_tcp_generic_loopback_test",
)
syscall_test(
size = "medium",
+ add_hostinet = True,
test = "//test/syscalls/linux:socket_ip_tcp_loopback_non_blocking_test",
)
syscall_test(
size = "large",
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_ip_tcp_loopback_test",
)
syscall_test(
size = "medium",
- shard_count = 50,
+ add_hostinet = True,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_ip_tcp_udp_generic_loopback_test",
)
syscall_test(
size = "medium",
+ add_hostinet = True,
test = "//test/syscalls/linux:socket_ip_udp_loopback_non_blocking_test",
)
syscall_test(
size = "large",
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_ip_udp_loopback_test",
)
@@ -677,6 +687,13 @@ syscall_test(
syscall_test(
size = "medium",
+ test = "//test/syscalls/linux:socket_ipv6_udp_unbound_loopback_test",
+)
+
+syscall_test(
+ size = "medium",
+ add_hostinet = True,
+ shard_count = more_shards,
# Takes too long under gotsan to run.
tags = ["nogotsan"],
test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_nogotsan_test",
@@ -691,6 +708,7 @@ syscall_test(
)
syscall_test(
+ shard_count = more_shards,
test = "//test/syscalls/linux:socket_ip_unbound_test",
)
@@ -723,6 +741,7 @@ syscall_test(
)
syscall_test(
+ add_hostinet = True,
test = "//test/syscalls/linux:socket_non_stream_blocking_local_test",
)
@@ -753,7 +772,7 @@ syscall_test(
syscall_test(
# NOTE(b/116636318): Large sendmsg may stall a long time.
size = "enormous",
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:socket_unix_dgram_local_test",
)
@@ -765,14 +784,14 @@ syscall_test(
syscall_test(
size = "large",
add_overlay = True,
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_unix_pair_test",
)
syscall_test(
# NOTE(b/116636318): Large sendmsg may stall a long time.
size = "enormous",
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:socket_unix_seqpacket_local_test",
)
@@ -798,13 +817,13 @@ syscall_test(
syscall_test(
size = "medium",
- shard_count = 10,
+ shard_count = more_shards,
test = "//test/syscalls/linux:socket_unix_unbound_seqpacket_test",
)
syscall_test(
size = "large",
- shard_count = 50,
+ shard_count = most_shards,
test = "//test/syscalls/linux:socket_unix_unbound_stream_test",
)
@@ -858,7 +877,7 @@ syscall_test(
syscall_test(
size = "medium",
- shard_count = 10,
+ shard_count = more_shards,
test = "//test/syscalls/linux:tcp_socket_test",
)
@@ -867,6 +886,7 @@ syscall_test(
)
syscall_test(
+ shard_count = more_shards,
test = "//test/syscalls/linux:timerfd_test",
)
@@ -897,13 +917,14 @@ syscall_test(
)
syscall_test(
+ add_hostinet = True,
test = "//test/syscalls/linux:udp_bind_test",
)
syscall_test(
size = "medium",
add_hostinet = True,
- shard_count = 10,
+ shard_count = more_shards,
test = "//test/syscalls/linux:udp_socket_test",
)
@@ -947,7 +968,7 @@ syscall_test(
syscall_test(
size = "medium",
- shard_count = 5,
+ shard_count = more_shards,
test = "//test/syscalls/linux:wait_test",
)
@@ -961,6 +982,7 @@ syscall_test(
)
syscall_test(
+ add_hostinet = True,
test = "//test/syscalls/linux:proc_net_tcp_test",
)
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 946d06cd6..017f997de 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -20,7 +20,9 @@ exports_files(
"socket_ip_udp_loopback_nonblock.cc",
"socket_ip_unbound.cc",
"socket_ipv4_udp_unbound_external_networking_test.cc",
+ "socket_ipv6_udp_unbound_external_networking_test.cc",
"socket_ipv4_udp_unbound_loopback.cc",
+ "socket_ipv6_udp_unbound_loopback.cc",
"socket_ipv4_udp_unbound_loopback_nogotsan.cc",
"tcp_socket.cc",
"udp_bind.cc",
@@ -621,10 +623,7 @@ cc_binary(
cc_binary(
name = "exceptions_test",
testonly = 1,
- srcs = select_arch(
- amd64 = ["exceptions.cc"],
- arm64 = [],
- ),
+ srcs = ["exceptions.cc"],
linkstatic = 1,
deps = [
gtest,
@@ -799,8 +798,8 @@ cc_binary(
deps = [
":socket_test_util",
"//test/util:cleanup",
- "//test/util:epoll_util",
"//test/util:eventfd_util",
+ "//test/util:file_descriptor",
"//test/util:fs_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/flags:flag",
@@ -811,6 +810,7 @@ cc_binary(
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:save_util",
+ "//test/util:signal_util",
"//test/util:temp_path",
"//test/util:test_util",
"//test/util:thread_util",
@@ -946,6 +946,7 @@ cc_binary(
"//test/util:eventfd_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/container:node_hash_set",
"@com_google_absl//absl/strings",
gtest,
@@ -2453,6 +2454,27 @@ cc_library(
)
cc_library(
+ name = "socket_ipv6_udp_unbound_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_ipv6_udp_unbound.cc",
+ ],
+ hdrs = [
+ "socket_ipv6_udp_unbound.h",
+ ],
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ "@com_google_absl//absl/memory",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:save_util",
+ "//test/util:test_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
name = "socket_ipv4_udp_unbound_netlink_test_cases",
testonly = 1,
srcs = [
@@ -2490,6 +2512,23 @@ cc_library(
)
cc_library(
+ name = "socket_ip_udp_unbound_external_networking",
+ testonly = 1,
+ srcs = [
+ "socket_ip_udp_unbound_external_networking.cc",
+ ],
+ hdrs = [
+ "socket_ip_udp_unbound_external_networking.h",
+ ],
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ "//test/util:test_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
name = "socket_ipv4_udp_unbound_external_networking_test_cases",
testonly = 1,
srcs = [
@@ -2499,10 +2538,24 @@ cc_library(
"socket_ipv4_udp_unbound_external_networking.h",
],
deps = [
- ":ip_socket_test_util",
- ":socket_test_util",
+ ":socket_ip_udp_unbound_external_networking",
+ gtest,
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "socket_ipv6_udp_unbound_external_networking_test_cases",
+ testonly = 1,
+ srcs = [
+ "socket_ipv6_udp_unbound_external_networking.cc",
+ ],
+ hdrs = [
+ "socket_ipv6_udp_unbound_external_networking.h",
+ ],
+ deps = [
+ ":socket_ip_udp_unbound_external_networking",
gtest,
- "//test/util:test_util",
],
alwayslink = 1,
)
@@ -2702,6 +2755,22 @@ cc_binary(
)
cc_binary(
+ name = "socket_ipv6_udp_unbound_external_networking_test",
+ testonly = 1,
+ srcs = [
+ "socket_ipv6_udp_unbound_external_networking_test.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_ipv6_udp_unbound_external_networking_test_cases",
+ ":socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
name = "socket_bind_to_device_test",
testonly = 1,
srcs = [
@@ -2792,6 +2861,22 @@ cc_binary(
)
cc_binary(
+ name = "socket_ipv6_udp_unbound_loopback_test",
+ testonly = 1,
+ srcs = [
+ "socket_ipv6_udp_unbound_loopback.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_ipv6_udp_unbound_test_cases",
+ ":socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
name = "socket_ipv4_udp_unbound_loopback_nogotsan_test",
testonly = 1,
srcs = [
@@ -3288,6 +3373,7 @@ cc_binary(
":socket_test_util",
":unix_domain_socket_test_util",
gtest,
+ "//test/util:file_descriptor",
"//test/util:test_main",
"//test/util:test_util",
],
diff --git a/test/syscalls/linux/chown.cc b/test/syscalls/linux/chown.cc
index 82223f997..5530ad18f 100644
--- a/test/syscalls/linux/chown.cc
+++ b/test/syscalls/linux/chown.cc
@@ -75,7 +75,16 @@ TEST_P(ChownParamTest, ChownFileSucceeds) {
if (num_groups > 0) {
std::vector<gid_t> list(num_groups);
EXPECT_THAT(getgroups(list.size(), list.data()), SyscallSucceeds());
- gid = list[0];
+ // Scan the list of groups for a valid gid. Note that if a group is not
+ // defined in this local user namespace, then we will see 65534, and the
+ // chown below will not work as expected for that group. So only change if
+ // we find a valid group in this list.
+ for (const gid_t other_gid : list) {
+ if (other_gid != 65534) {
+ gid = other_gid;
+ break;
+ }
+ }
}
EXPECT_NO_ERRNO(GetParam()(file.path(), geteuid(), gid));
diff --git a/test/syscalls/linux/exceptions.cc b/test/syscalls/linux/exceptions.cc
index 420b9543f..11dc1c651 100644
--- a/test/syscalls/linux/exceptions.cc
+++ b/test/syscalls/linux/exceptions.cc
@@ -23,6 +23,7 @@
namespace gvisor {
namespace testing {
+#if defined(__x86_64__)
// Default value for the x87 FPU control word. See Intel SDM Vol 1, Ch 8.1.5
// "x87 FPU Control Word".
constexpr uint16_t kX87ControlWordDefault = 0x37f;
@@ -93,6 +94,9 @@ void InIOHelper(int width, int value) {
},
::testing::KilledBySignal(SIGSEGV), "");
}
+#elif defined(__aarch64__)
+void inline Halt() { asm("hlt #0\r\n"); }
+#endif
TEST(ExceptionTest, Halt) {
// In order to prevent the regular handler from messing with things (and
@@ -102,9 +106,14 @@ TEST(ExceptionTest, Halt) {
sa.sa_handler = SIG_DFL;
auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGSEGV, sa));
+#if defined(__x86_64__)
EXPECT_EXIT(Halt(), ::testing::KilledBySignal(SIGSEGV), "");
+#elif defined(__aarch64__)
+ EXPECT_EXIT(Halt(), ::testing::KilledBySignal(SIGILL), "");
+#endif
}
+#if defined(__x86_64__)
TEST(ExceptionTest, DivideByZero) {
// See above.
struct sigaction sa = {};
@@ -362,6 +371,7 @@ TEST(ExceptionTest, Int3Compact) {
EXPECT_EXIT(Int3Compact(), ::testing::KilledBySignal(SIGTRAP), "");
}
+#endif
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc
index 34016d4bd..4b581045b 100644
--- a/test/syscalls/linux/fcntl.cc
+++ b/test/syscalls/linux/fcntl.cc
@@ -14,10 +14,13 @@
#include <fcntl.h>
#include <signal.h>
+#include <sys/epoll.h>
#include <sys/types.h>
#include <syscall.h>
#include <unistd.h>
+#include <atomic>
+#include <deque>
#include <iostream>
#include <list>
#include <string>
@@ -34,25 +37,27 @@
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/cleanup.h"
#include "test/util/eventfd_util.h"
+#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
#include "test/util/multiprocess_util.h"
#include "test/util/posix_error.h"
#include "test/util/save_util.h"
+#include "test/util/signal_util.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
#include "test/util/timer_util.h"
-ABSL_FLAG(std::string, child_setlock_on, "",
+ABSL_FLAG(std::string, child_set_lock_on, "",
"Contains the path to try to set a file lock on.");
-ABSL_FLAG(bool, child_setlock_write, false,
+ABSL_FLAG(bool, child_set_lock_write, false,
"Whether to set a writable lock (otherwise readable)");
ABSL_FLAG(bool, blocking, false,
"Whether to set a blocking lock (otherwise non-blocking).");
ABSL_FLAG(bool, retry_eintr, false,
"Whether to retry in the subprocess on EINTR.");
-ABSL_FLAG(uint64_t, child_setlock_start, 0, "The value of struct flock start");
-ABSL_FLAG(uint64_t, child_setlock_len, 0, "The value of struct flock len");
+ABSL_FLAG(uint64_t, child_set_lock_start, 0, "The value of struct flock start");
+ABSL_FLAG(uint64_t, child_set_lock_len, 0, "The value of struct flock len");
ABSL_FLAG(int32_t, socket_fd, -1,
"A socket to use for communicating more state back "
"to the parent.");
@@ -60,6 +65,11 @@ ABSL_FLAG(int32_t, socket_fd, -1,
namespace gvisor {
namespace testing {
+std::function<void(int, siginfo_t*, void*)> setsig_signal_handle;
+void setsig_signal_handler(int signum, siginfo_t* siginfo, void* ucontext) {
+ setsig_signal_handle(signum, siginfo, ucontext);
+}
+
class FcntlLockTest : public ::testing::Test {
public:
void SetUp() override {
@@ -84,18 +94,93 @@ class FcntlLockTest : public ::testing::Test {
int fds_[2] = {};
};
+struct SignalDelivery {
+ int num;
+ siginfo_t info;
+};
+
+class FcntlSignalTest : public ::testing::Test {
+ public:
+ void SetUp() override {
+ int pipe_fds[2];
+ ASSERT_THAT(pipe2(pipe_fds, O_NONBLOCK), SyscallSucceeds());
+ pipe_read_fd_ = pipe_fds[0];
+ pipe_write_fd_ = pipe_fds[1];
+ }
+
+ PosixErrorOr<Cleanup> RegisterSignalHandler(int signum) {
+ struct sigaction handler;
+ handler.sa_sigaction = setsig_signal_handler;
+ setsig_signal_handle = [&](int signum, siginfo_t* siginfo,
+ void* unused_ucontext) {
+ SignalDelivery sig;
+ sig.num = signum;
+ sig.info = *siginfo;
+ signals_received_.push_back(sig);
+ num_signals_received_++;
+ };
+ sigemptyset(&handler.sa_mask);
+ handler.sa_flags = SA_SIGINFO;
+ return ScopedSigaction(signum, handler);
+ }
+
+ void FlushAndCloseFD(int fd) {
+ char buf;
+ int read_bytes;
+ do {
+ read_bytes = read(fd, &buf, 1);
+ } while (read_bytes > 0);
+ // read() can also fail with EWOULDBLOCK since the pipe is open in
+ // non-blocking mode. This is not an error.
+ EXPECT_TRUE(read_bytes == 0 || (read_bytes == -1 && errno == EWOULDBLOCK));
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+ }
+
+ void DupReadFD() {
+ ASSERT_THAT(pipe_read_fd_dup_ = dup(pipe_read_fd_), SyscallSucceeds());
+ max_expected_signals++;
+ }
+
+ void RegisterFD(int fd, int signum) {
+ ASSERT_THAT(fcntl(fd, F_SETOWN, getpid()), SyscallSucceeds());
+ ASSERT_THAT(fcntl(fd, F_SETSIG, signum), SyscallSucceeds());
+ int old_flags;
+ ASSERT_THAT(old_flags = fcntl(fd, F_GETFL), SyscallSucceeds());
+ ASSERT_THAT(fcntl(fd, F_SETFL, old_flags | O_ASYNC), SyscallSucceeds());
+ }
+
+ void GenerateIOEvent() {
+ ASSERT_THAT(write(pipe_write_fd_, "test", 4), SyscallSucceedsWithValue(4));
+ }
+
+ void WaitForSignalDelivery(absl::Duration timeout) {
+ absl::Time wait_start = absl::Now();
+ while (num_signals_received_ < max_expected_signals &&
+ absl::Now() - wait_start < timeout) {
+ absl::SleepFor(absl::Milliseconds(10));
+ }
+ }
+
+ int pipe_read_fd_ = -1;
+ int pipe_read_fd_dup_ = -1;
+ int pipe_write_fd_ = -1;
+ int max_expected_signals = 1;
+ std::deque<SignalDelivery> signals_received_;
+ std::atomic<int> num_signals_received_ = 0;
+};
+
namespace {
PosixErrorOr<Cleanup> SubprocessLock(std::string const& path, bool for_write,
bool blocking, bool retry_eintr, int fd,
off_t start, off_t length, pid_t* child) {
std::vector<std::string> args = {
- "/proc/self/exe", "--child_setlock_on", path,
- "--child_setlock_start", absl::StrCat(start), "--child_setlock_len",
- absl::StrCat(length), "--socket_fd", absl::StrCat(fd)};
+ "/proc/self/exe", "--child_set_lock_on", path,
+ "--child_set_lock_start", absl::StrCat(start), "--child_set_lock_len",
+ absl::StrCat(length), "--socket_fd", absl::StrCat(fd)};
if (for_write) {
- args.push_back("--child_setlock_write");
+ args.push_back("--child_set_lock_write");
}
if (blocking) {
@@ -965,7 +1050,6 @@ TEST(FcntlTest, GetOwnNone) {
// into F_{GET,SET}OWN_EX.
EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
SyscallSucceedsWithValue(0));
- MaybeSave();
}
TEST(FcntlTest, GetOwnExNone) {
@@ -1009,7 +1093,6 @@ TEST(FcntlTest, SetOwnPid) {
EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
SyscallSucceedsWithValue(pid));
- MaybeSave();
}
TEST(FcntlTest, SetOwnPgrp) {
@@ -1030,7 +1113,6 @@ TEST(FcntlTest, SetOwnPgrp) {
SyscallSucceedsWithValue(0));
EXPECT_EQ(got_owner.type, F_OWNER_PGRP);
EXPECT_EQ(got_owner.pid, pgid);
- MaybeSave();
}
TEST(FcntlTest, SetOwnUnset) {
@@ -1058,7 +1140,6 @@ TEST(FcntlTest, SetOwnUnset) {
EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
SyscallSucceedsWithValue(0));
- MaybeSave();
}
// F_SETOWN flips the sign of negative values, an operation that is guarded
@@ -1130,7 +1211,6 @@ TEST(FcntlTest, SetOwnExTid) {
EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
SyscallSucceedsWithValue(owner.pid));
- MaybeSave();
}
TEST(FcntlTest, SetOwnExPid) {
@@ -1146,7 +1226,6 @@ TEST(FcntlTest, SetOwnExPid) {
EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
SyscallSucceedsWithValue(owner.pid));
- MaybeSave();
}
TEST(FcntlTest, SetOwnExPgrp) {
@@ -1168,7 +1247,6 @@ TEST(FcntlTest, SetOwnExPgrp) {
SyscallSucceedsWithValue(0));
EXPECT_EQ(got_owner.type, set_owner.type);
EXPECT_EQ(got_owner.pid, set_owner.pid);
- MaybeSave();
}
TEST(FcntlTest, SetOwnExUnset) {
@@ -1201,7 +1279,6 @@ TEST(FcntlTest, SetOwnExUnset) {
EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
SyscallSucceedsWithValue(0));
- MaybeSave();
}
TEST(FcntlTest, GetOwnExTid) {
@@ -1258,9 +1335,269 @@ TEST(FcntlTest, GetOwnExPgrp) {
EXPECT_EQ(got_owner.pid, set_owner.pid);
}
+TEST(FcntlTest, SetSig) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETSIG, SIGUSR1),
+ SyscallSucceedsWithValue(0));
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETSIG),
+ SyscallSucceedsWithValue(SIGUSR1));
+}
+
+TEST(FcntlTest, SetSigDefaultsToZero) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ // Defaults to returning the zero value, indicating default behavior (SIGIO).
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETSIG),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST(FcntlTest, SetSigToDefault) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETSIG, SIGIO),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETSIG),
+ SyscallSucceedsWithValue(SIGIO));
+
+ // Can be reset to the default behavior.
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETSIG, 0),
+ SyscallSucceedsWithValue(0));
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETSIG),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST(FcntlTest, SetSigInvalid) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETSIG, SIGRTMAX + 1),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETSIG),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST(FcntlTest, SetSigInvalidDoesNotResetPreviousChoice) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETSIG, SIGUSR1),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETSIG, SIGRTMAX + 1),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETSIG),
+ SyscallSucceedsWithValue(SIGUSR1));
+}
+
+TEST_F(FcntlSignalTest, SetSigDefault) {
+ const auto signal_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGIO));
+ RegisterFD(pipe_read_fd_, 0); // Zero = default behavior
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ signals_received_.pop_front();
+ EXPECT_EQ(sig.num, SIGIO);
+ EXPECT_EQ(sig.info.si_signo, SIGIO);
+ // siginfo contents are undefined in this case.
+}
+
+TEST_F(FcntlSignalTest, SetSigCustom) {
+ const auto signal_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR1));
+ RegisterFD(pipe_read_fd_, SIGUSR1);
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ signals_received_.pop_front();
+ EXPECT_EQ(sig.num, SIGUSR1);
+ EXPECT_EQ(sig.info.si_signo, SIGUSR1);
+ EXPECT_EQ(sig.info.si_fd, pipe_read_fd_);
+ EXPECT_EQ(sig.info.si_band, EPOLLIN | EPOLLRDNORM);
+}
+
+TEST_F(FcntlSignalTest, SetSigUnregisterStillGetsSigio) {
+ const auto sigio_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGIO));
+ const auto sigusr1_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR1));
+ RegisterFD(pipe_read_fd_, SIGUSR1);
+ RegisterFD(pipe_read_fd_, 0);
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ signals_received_.pop_front();
+ EXPECT_EQ(sig.num, SIGIO);
+ // siginfo contents are undefined in this case.
+}
+
+TEST_F(FcntlSignalTest, SetSigWithSigioStillGetsSiginfo) {
+ const auto signal_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGIO));
+ RegisterFD(pipe_read_fd_, SIGIO);
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ EXPECT_EQ(sig.num, SIGIO);
+ EXPECT_EQ(sig.info.si_signo, SIGIO);
+ EXPECT_EQ(sig.info.si_fd, pipe_read_fd_);
+ EXPECT_EQ(sig.info.si_band, EPOLLIN | EPOLLRDNORM);
+}
+
+TEST_F(FcntlSignalTest, SetSigDupThenCloseOld) {
+ const auto sigusr1_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR1));
+ RegisterFD(pipe_read_fd_, SIGUSR1);
+ DupReadFD();
+ FlushAndCloseFD(pipe_read_fd_);
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ // We get a signal with the **old** FD (even though it is closed).
+ EXPECT_EQ(sig.num, SIGUSR1);
+ EXPECT_EQ(sig.info.si_signo, SIGUSR1);
+ EXPECT_EQ(sig.info.si_fd, pipe_read_fd_);
+ EXPECT_EQ(sig.info.si_band, EPOLLIN | EPOLLRDNORM);
+}
+
+TEST_F(FcntlSignalTest, SetSigDupThenCloseNew) {
+ const auto sigusr1_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR1));
+ RegisterFD(pipe_read_fd_, SIGUSR1);
+ DupReadFD();
+ FlushAndCloseFD(pipe_read_fd_dup_);
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ // We get a signal with the old FD.
+ EXPECT_EQ(sig.num, SIGUSR1);
+ EXPECT_EQ(sig.info.si_signo, SIGUSR1);
+ EXPECT_EQ(sig.info.si_fd, pipe_read_fd_);
+ EXPECT_EQ(sig.info.si_band, EPOLLIN | EPOLLRDNORM);
+}
+
+TEST_F(FcntlSignalTest, SetSigDupOldRegistered) {
+ const auto sigusr1_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR1));
+ RegisterFD(pipe_read_fd_, SIGUSR1);
+ DupReadFD();
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ // We get a signal with the old FD.
+ EXPECT_EQ(sig.num, SIGUSR1);
+ EXPECT_EQ(sig.info.si_signo, SIGUSR1);
+ EXPECT_EQ(sig.info.si_fd, pipe_read_fd_);
+ EXPECT_EQ(sig.info.si_band, EPOLLIN | EPOLLRDNORM);
+}
+
+TEST_F(FcntlSignalTest, SetSigDupNewRegistered) {
+ const auto sigusr2_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR2));
+ DupReadFD();
+ RegisterFD(pipe_read_fd_dup_, SIGUSR2);
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ // We get a signal with the new FD.
+ EXPECT_EQ(sig.num, SIGUSR2);
+ EXPECT_EQ(sig.info.si_signo, SIGUSR2);
+ EXPECT_EQ(sig.info.si_fd, pipe_read_fd_dup_);
+ EXPECT_EQ(sig.info.si_band, EPOLLIN | EPOLLRDNORM);
+}
+
+TEST_F(FcntlSignalTest, SetSigDupBothRegistered) {
+ const auto sigusr1_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR1));
+ const auto sigusr2_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR2));
+ RegisterFD(pipe_read_fd_, SIGUSR1);
+ DupReadFD();
+ RegisterFD(pipe_read_fd_dup_, SIGUSR2);
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ // We get a signal with the **new** signal number, but the **old** FD.
+ EXPECT_EQ(sig.num, SIGUSR2);
+ EXPECT_EQ(sig.info.si_signo, SIGUSR2);
+ EXPECT_EQ(sig.info.si_fd, pipe_read_fd_);
+ EXPECT_EQ(sig.info.si_band, EPOLLIN | EPOLLRDNORM);
+}
+
+TEST_F(FcntlSignalTest, SetSigDupBothRegisteredAfterDup) {
+ const auto sigusr1_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR1));
+ const auto sigusr2_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR2));
+ DupReadFD();
+ RegisterFD(pipe_read_fd_, SIGUSR1);
+ RegisterFD(pipe_read_fd_dup_, SIGUSR2);
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ // We get a signal with the **new** signal number, but the **old** FD.
+ EXPECT_EQ(sig.num, SIGUSR2);
+ EXPECT_EQ(sig.info.si_signo, SIGUSR2);
+ EXPECT_EQ(sig.info.si_fd, pipe_read_fd_);
+ EXPECT_EQ(sig.info.si_band, EPOLLIN | EPOLLRDNORM);
+}
+
+TEST_F(FcntlSignalTest, SetSigDupUnregisterOld) {
+ const auto sigio_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGIO));
+ const auto sigusr1_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR1));
+ const auto sigusr2_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR2));
+ RegisterFD(pipe_read_fd_, SIGUSR1);
+ DupReadFD();
+ RegisterFD(pipe_read_fd_dup_, SIGUSR2);
+ RegisterFD(pipe_read_fd_, 0); // Should go back to SIGIO behavior.
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ // We get a signal with SIGIO.
+ EXPECT_EQ(sig.num, SIGIO);
+ // siginfo is undefined in this case.
+}
+
+TEST_F(FcntlSignalTest, SetSigDupUnregisterNew) {
+ const auto sigio_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGIO));
+ const auto sigusr1_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR1));
+ const auto sigusr2_cleanup =
+ ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGUSR2));
+ RegisterFD(pipe_read_fd_, SIGUSR1);
+ DupReadFD();
+ RegisterFD(pipe_read_fd_dup_, SIGUSR2);
+ RegisterFD(pipe_read_fd_dup_, 0); // Should go back to SIGIO behavior.
+ GenerateIOEvent();
+ WaitForSignalDelivery(absl::Seconds(1));
+ ASSERT_EQ(num_signals_received_, 1);
+ SignalDelivery sig = signals_received_.front();
+ // We get a signal with SIGIO.
+ EXPECT_EQ(sig.num, SIGIO);
+ // siginfo is undefined in this case.
+}
+
// Make sure that making multiple concurrent changes to async signal generation
// does not cause any race issues.
-TEST(FcntlTest, SetFlSetOwnDoNotRace) {
+TEST(FcntlTest, SetFlSetOwnSetSigDoNotRace) {
FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
@@ -1268,32 +1605,40 @@ TEST(FcntlTest, SetFlSetOwnDoNotRace) {
EXPECT_THAT(pid = getpid(), SyscallSucceeds());
constexpr absl::Duration runtime = absl::Milliseconds(300);
- auto setAsync = [&s, &runtime] {
+ auto set_async = [&s, &runtime] {
for (auto start = absl::Now(); absl::Now() - start < runtime;) {
ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETFL, O_ASYNC),
SyscallSucceeds());
sched_yield();
}
};
- auto resetAsync = [&s, &runtime] {
+ auto reset_async = [&s, &runtime] {
for (auto start = absl::Now(); absl::Now() - start < runtime;) {
ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETFL, 0), SyscallSucceeds());
sched_yield();
}
};
- auto setOwn = [&s, &pid, &runtime] {
+ auto set_own = [&s, &pid, &runtime] {
for (auto start = absl::Now(); absl::Now() - start < runtime;) {
ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, pid),
SyscallSucceeds());
sched_yield();
}
};
+ auto set_sig = [&s, &runtime] {
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETSIG, SIGUSR1),
+ SyscallSucceeds());
+ sched_yield();
+ }
+ };
std::list<ScopedThread> threads;
for (int i = 0; i < 10; i++) {
- threads.emplace_back(setAsync);
- threads.emplace_back(resetAsync);
- threads.emplace_back(setOwn);
+ threads.emplace_back(set_async);
+ threads.emplace_back(reset_async);
+ threads.emplace_back(set_own);
+ threads.emplace_back(set_sig);
}
}
@@ -1302,57 +1647,60 @@ TEST(FcntlTest, SetFlSetOwnDoNotRace) {
} // namespace testing
} // namespace gvisor
-int main(int argc, char** argv) {
- gvisor::testing::TestInit(&argc, &argv);
-
- const std::string setlock_on = absl::GetFlag(FLAGS_child_setlock_on);
- if (!setlock_on.empty()) {
- int socket_fd = absl::GetFlag(FLAGS_socket_fd);
- int fd = open(setlock_on.c_str(), O_RDWR, 0666);
- if (fd == -1 && errno != 0) {
- int err = errno;
- std::cerr << "CHILD open " << setlock_on << " failed " << err
- << std::endl;
- exit(err);
- }
+int set_lock() {
+ const std::string set_lock_on = absl::GetFlag(FLAGS_child_set_lock_on);
+ int socket_fd = absl::GetFlag(FLAGS_socket_fd);
+ int fd = open(set_lock_on.c_str(), O_RDWR, 0666);
+ if (fd == -1 && errno != 0) {
+ int err = errno;
+ std::cerr << "CHILD open " << set_lock_on << " failed: " << err
+ << std::endl;
+ return err;
+ }
- struct flock fl;
- if (absl::GetFlag(FLAGS_child_setlock_write)) {
- fl.l_type = F_WRLCK;
- } else {
- fl.l_type = F_RDLCK;
- }
- fl.l_whence = SEEK_SET;
- fl.l_start = absl::GetFlag(FLAGS_child_setlock_start);
- fl.l_len = absl::GetFlag(FLAGS_child_setlock_len);
+ struct flock fl;
+ if (absl::GetFlag(FLAGS_child_set_lock_write)) {
+ fl.l_type = F_WRLCK;
+ } else {
+ fl.l_type = F_RDLCK;
+ }
+ fl.l_whence = SEEK_SET;
+ fl.l_start = absl::GetFlag(FLAGS_child_set_lock_start);
+ fl.l_len = absl::GetFlag(FLAGS_child_set_lock_len);
+
+ // Test the fcntl.
+ int err = 0;
+ int ret = 0;
+
+ gvisor::testing::MonotonicTimer timer;
+ timer.Start();
+ do {
+ ret = fcntl(fd, absl::GetFlag(FLAGS_blocking) ? F_SETLKW : F_SETLK, &fl);
+ } while (absl::GetFlag(FLAGS_retry_eintr) && ret == -1 && errno == EINTR);
+ auto usec = absl::ToInt64Microseconds(timer.Duration());
+
+ if (ret == -1 && errno != 0) {
+ err = errno;
+ std::cerr << "CHILD lock " << set_lock_on << " failed " << err << std::endl;
+ }
- // Test the fcntl.
- int err = 0;
- int ret = 0;
+ // If there is a socket fd let's send back the time in microseconds it took
+ // to execute this syscall.
+ if (socket_fd != -1) {
+ gvisor::testing::WriteFd(socket_fd, reinterpret_cast<void*>(&usec),
+ sizeof(usec));
+ close(socket_fd);
+ }
- gvisor::testing::MonotonicTimer timer;
- timer.Start();
- do {
- ret = fcntl(fd, absl::GetFlag(FLAGS_blocking) ? F_SETLKW : F_SETLK, &fl);
- } while (absl::GetFlag(FLAGS_retry_eintr) && ret == -1 && errno == EINTR);
- auto usec = absl::ToInt64Microseconds(timer.Duration());
-
- if (ret == -1 && errno != 0) {
- err = errno;
- std::cerr << "CHILD lock " << setlock_on << " failed " << err
- << std::endl;
- }
+ close(fd);
+ return err;
+}
- // If there is a socket fd let's send back the time in microseconds it took
- // to execute this syscall.
- if (socket_fd != -1) {
- gvisor::testing::WriteFd(socket_fd, reinterpret_cast<void*>(&usec),
- sizeof(usec));
- close(socket_fd);
- }
+int main(int argc, char** argv) {
+ gvisor::testing::TestInit(&argc, &argv);
- close(fd);
- exit(err);
+ if (!absl::GetFlag(FLAGS_child_set_lock_on).empty()) {
+ exit(set_lock());
}
return gvisor::testing::RunAllTests();
diff --git a/test/syscalls/linux/getdents.cc b/test/syscalls/linux/getdents.cc
index b040cdcf7..93c692dd6 100644
--- a/test/syscalls/linux/getdents.cc
+++ b/test/syscalls/linux/getdents.cc
@@ -32,6 +32,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/container/node_hash_map.h"
#include "absl/container/node_hash_set.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
@@ -381,7 +382,7 @@ TYPED_TEST(GetdentsTest, PartialBuffer) {
// getdents iterates correctly despite mutation of /proc/self/fd.
TYPED_TEST(GetdentsTest, ProcSelfFd) {
constexpr size_t kNfds = 10;
- std::unordered_map<int, FileDescriptor> fds;
+ absl::node_hash_map<int, FileDescriptor> fds;
fds.reserve(kNfds);
for (size_t i = 0; i < kNfds; i++) {
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD());
diff --git a/test/syscalls/linux/kill.cc b/test/syscalls/linux/kill.cc
index db29bd59c..5d1735853 100644
--- a/test/syscalls/linux/kill.cc
+++ b/test/syscalls/linux/kill.cc
@@ -58,6 +58,12 @@ void SigHandler(int sig, siginfo_t* info, void* context) { _exit(0); }
// If pid equals -1, then sig is sent to every process for which the calling
// process has permission to send signals, except for process 1 (init).
TEST(KillTest, CanKillAllPIDs) {
+ // If we're not running inside the sandbox, then we skip this test
+ // as our namespace may contain many more processes that cannot tolerate
+ // the signal below. We also cannot reliably create a new pid namespace
+ // for ourselves and test the same functionality.
+ SKIP_IF(!IsRunningOnGvisor());
+
int pipe_fds[2];
ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds());
FileDescriptor read_fd(pipe_fds[0]);
diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc
index d65b7d031..15b645fb7 100644
--- a/test/syscalls/linux/mount.cc
+++ b/test/syscalls/linux/mount.cc
@@ -345,42 +345,6 @@ TEST(MountTest, RenameRemoveMountPoint) {
ASSERT_THAT(rmdir(dir.path().c_str()), SyscallFailsWithErrno(EBUSY));
}
-TEST(MountTest, MountFuseFilesystemNoDevice) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
- SKIP_IF(IsRunningOnGvisor() && !IsFUSEEnabled());
-
- auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
-
- // Before kernel version 4.16-rc6, FUSE mount is protected by
- // capable(CAP_SYS_ADMIN). After this version, it uses
- // ns_capable(CAP_SYS_ADMIN) to protect. Before the 4.16 kernel, it was not
- // allowed to mount fuse file systems without the global CAP_SYS_ADMIN.
- int res = mount("", dir.path().c_str(), "fuse", 0, "");
- SKIP_IF(!IsRunningOnGvisor() && res == -1 && errno == EPERM);
-
- EXPECT_THAT(mount("", dir.path().c_str(), "fuse", 0, ""),
- SyscallFailsWithErrno(EINVAL));
-}
-
-TEST(MountTest, MountFuseFilesystem) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
- SKIP_IF(IsRunningOnGvisor() && !IsFUSEEnabled());
-
- const FileDescriptor fd =
- ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_WRONLY));
- std::string mopts = "fd=" + std::to_string(fd.get());
-
- auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
-
- // See comments in MountFuseFilesystemNoDevice for the reason why we skip
- // EPERM when running on Linux.
- int res = mount("", dir.path().c_str(), "fuse", 0, "");
- SKIP_IF(!IsRunningOnGvisor() && res == -1 && errno == EPERM);
-
- auto const mount =
- ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "fuse", 0, mopts, 0));
-}
-
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc
index 77f390f3c..fcd162ca2 100644
--- a/test/syscalls/linux/open.cc
+++ b/test/syscalls/linux/open.cc
@@ -505,6 +505,18 @@ TEST_F(OpenTest, OpenNonDirectoryWithTrailingSlash) {
EXPECT_THAT(open(bad_path.c_str(), O_RDONLY), SyscallFailsWithErrno(ENOTDIR));
}
+TEST_F(OpenTest, OpenWithStrangeFlags) {
+ // VFS1 incorrectly allows read/write operations on such file descriptors.
+ SKIP_IF(IsRunningWithVFS1());
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY | O_RDWR));
+ EXPECT_THAT(write(fd.get(), "x", 1), SyscallFailsWithErrno(EBADF));
+ char c;
+ EXPECT_THAT(read(fd.get(), &c, 1), SyscallFailsWithErrno(EBADF));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc
index 78c36f98f..9d63782fb 100644
--- a/test/syscalls/linux/open_create.cc
+++ b/test/syscalls/linux/open_create.cc
@@ -112,14 +112,6 @@ TEST(CreateTest, CreatFileWithOTruncAndReadOnly) {
ASSERT_THAT(close(dirfd), SyscallSucceeds());
}
-TEST(CreateTest, CreateFailsOnUnpermittedDir) {
- // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to
- // always override directory permissions.
- ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
- ASSERT_THAT(open("/foo", O_CREAT | O_RDWR, 0644),
- SyscallFailsWithErrno(EACCES));
-}
-
TEST(CreateTest, CreateFailsOnDirWithoutWritePerms) {
// Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to
// always override directory permissions.
diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc
index 06d9dbf65..01ccbdcd2 100644
--- a/test/syscalls/linux/pipe.cc
+++ b/test/syscalls/linux/pipe.cc
@@ -71,13 +71,13 @@ class PipeTest : public ::testing::TestWithParam<PipeCreator> {
// Returns true iff the pipe represents a named pipe.
bool IsNamedPipe() const { return named_pipe_; }
- int Size() const {
+ size_t Size() const {
int s1 = fcntl(rfd_.get(), F_GETPIPE_SZ);
int s2 = fcntl(wfd_.get(), F_GETPIPE_SZ);
EXPECT_GT(s1, 0);
EXPECT_GT(s2, 0);
EXPECT_EQ(s1, s2);
- return s1;
+ return static_cast<size_t>(s1);
}
static void TearDownTestSuite() {
@@ -568,7 +568,7 @@ TEST_P(PipeTest, Streaming) {
DisableSave ds;
// Size() requires 2 syscalls, call it once and remember the value.
- const int pipe_size = Size();
+ const size_t pipe_size = Size();
const size_t streamed_bytes = 4 * pipe_size;
absl::Notification notify;
@@ -576,7 +576,7 @@ TEST_P(PipeTest, Streaming) {
std::vector<char> buf(1024);
// Don't start until it's full.
notify.WaitForNotification();
- ssize_t total = 0;
+ size_t total = 0;
while (total < streamed_bytes) {
ASSERT_THAT(read(rfd_.get(), buf.data(), buf.size()),
SyscallSucceedsWithValue(buf.size()));
@@ -593,7 +593,7 @@ TEST_P(PipeTest, Streaming) {
// page) for the check for notify.Notify() below to be correct.
std::vector<char> buf(1024);
RandomizeBuffer(buf.data(), buf.size());
- ssize_t total = 0;
+ size_t total = 0;
while (total < streamed_bytes) {
ASSERT_THAT(write(wfd_.get(), buf.data(), buf.size()),
SyscallSucceedsWithValue(buf.size()));
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index f43a41891..e508ce27f 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -17,6 +17,7 @@
#include <fcntl.h>
#include <limits.h>
#include <linux/magic.h>
+#include <linux/sem.h>
#include <sched.h>
#include <signal.h>
#include <stddef.h>
@@ -1801,6 +1802,33 @@ TEST(ProcPidCmdline, SubprocessForkSameCmdline) {
}
}
+TEST(ProcPidCmdline, SubprocessSeekCmdline) {
+ FileDescriptor fd;
+ ASSERT_NO_ERRNO(WithSubprocess(
+ [&](int pid) -> PosixError {
+ // Running. Open /proc/pid/cmdline.
+ ASSIGN_OR_RETURN_ERRNO(
+ fd, Open(absl::StrCat("/proc/", pid, "/cmdline"), O_RDONLY));
+ return NoError();
+ },
+ [&](int pid) -> PosixError {
+ // Zombie, but seek should still succeed.
+ int ret = lseek(fd.get(), 0x801, 0);
+ if (ret < 0) {
+ return PosixError(errno);
+ }
+ return NoError();
+ },
+ [&](int pid) -> PosixError {
+ // Exited.
+ int ret = lseek(fd.get(), 0x801, 0);
+ if (ret < 0) {
+ return PosixError(errno);
+ }
+ return NoError();
+ }));
+}
+
// Test whether /proc/PID/ symlinks can be read for a running process.
TEST(ProcPidSymlink, SubprocessRunning) {
char buf[1];
@@ -2409,6 +2437,28 @@ TEST(ProcFilesystems, PresenceOfShmMaxMniAll) {
ASSERT_LE(shmall, ULONG_MAX - (1UL << 24));
}
+TEST(ProcFilesystems, PresenceOfSem) {
+ uint32_t semmsl = 0;
+ uint32_t semmns = 0;
+ uint32_t semopm = 0;
+ uint32_t semmni = 0;
+ std::string proc_file;
+ proc_file = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/sem"));
+ ASSERT_FALSE(proc_file.empty());
+ std::vector<absl::string_view> sem_limits =
+ absl::StrSplit(proc_file, absl::ByAnyChar("\t"), absl::SkipWhitespace());
+ ASSERT_EQ(sem_limits.size(), 4);
+ ASSERT_TRUE(absl::SimpleAtoi(sem_limits[0], &semmsl));
+ ASSERT_TRUE(absl::SimpleAtoi(sem_limits[1], &semmns));
+ ASSERT_TRUE(absl::SimpleAtoi(sem_limits[2], &semopm));
+ ASSERT_TRUE(absl::SimpleAtoi(sem_limits[3], &semmni));
+
+ ASSERT_EQ(semmsl, SEMMSL);
+ ASSERT_EQ(semmns, SEMMNS);
+ ASSERT_EQ(semopm, SEMOPM);
+ ASSERT_EQ(semmni, SEMMNI);
+}
+
// Check that /proc/mounts is a symlink to self/mounts.
TEST(ProcMounts, IsSymlink) {
auto link = ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/mounts"));
diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc
index 23677e296..1cc700fe7 100644
--- a/test/syscalls/linux/proc_net.cc
+++ b/test/syscalls/linux/proc_net.cc
@@ -420,14 +420,14 @@ TEST(ProcNetSnmp, CheckNetStat) {
int name_count = 0;
int value_count = 0;
std::vector<absl::string_view> lines = absl::StrSplit(contents, '\n');
- for (int i = 0; i + 1 < lines.size(); i += 2) {
+ for (long unsigned int i = 0; i + 1 < lines.size(); i += 2) {
std::vector<absl::string_view> names =
absl::StrSplit(lines[i], absl::ByAnyChar("\t "));
std::vector<absl::string_view> values =
absl::StrSplit(lines[i + 1], absl::ByAnyChar("\t "));
EXPECT_EQ(names.size(), values.size()) << " mismatch in lines '" << lines[i]
<< "' and '" << lines[i + 1] << "'";
- for (int j = 0; j < names.size() && j < values.size(); ++j) {
+ for (long unsigned int j = 0; j < names.size() && j < values.size(); ++j) {
if (names[j] == "TCPOrigDataSent" || names[j] == "TCPSynRetrans" ||
names[j] == "TCPDSACKRecv" || names[j] == "TCPDSACKOfoRecv") {
++name_count;
@@ -457,14 +457,14 @@ TEST(ProcNetSnmp, CheckSnmp) {
int name_count = 0;
int value_count = 0;
std::vector<absl::string_view> lines = absl::StrSplit(contents, '\n');
- for (int i = 0; i + 1 < lines.size(); i += 2) {
+ for (long unsigned int i = 0; i + 1 < lines.size(); i += 2) {
std::vector<absl::string_view> names =
absl::StrSplit(lines[i], absl::ByAnyChar("\t "));
std::vector<absl::string_view> values =
absl::StrSplit(lines[i + 1], absl::ByAnyChar("\t "));
EXPECT_EQ(names.size(), values.size()) << " mismatch in lines '" << lines[i]
<< "' and '" << lines[i + 1] << "'";
- for (int j = 0; j < names.size() && j < values.size(); ++j) {
+ for (long unsigned int j = 0; j < names.size() && j < values.size(); ++j) {
if (names[j] == "RetransSegs") {
++name_count;
int64_t val;
diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc
index a63067586..662c6feb2 100644
--- a/test/syscalls/linux/proc_net_unix.cc
+++ b/test/syscalls/linux/proc_net_unix.cc
@@ -181,7 +181,7 @@ PosixErrorOr<std::vector<UnixEntry>> ProcNetUnixEntries() {
// Returns true on match, and sets 'match' to point to the matching entry.
bool FindBy(std::vector<UnixEntry> entries, UnixEntry* match,
std::function<bool(const UnixEntry&)> predicate) {
- for (int i = 0; i < entries.size(); ++i) {
+ for (long unsigned int i = 0; i < entries.size(); ++i) {
if (predicate(entries[i])) {
*match = entries[i];
return true;
diff --git a/test/syscalls/linux/proc_pid_uid_gid_map.cc b/test/syscalls/linux/proc_pid_uid_gid_map.cc
index 748f7be58..af052a63c 100644
--- a/test/syscalls/linux/proc_pid_uid_gid_map.cc
+++ b/test/syscalls/linux/proc_pid_uid_gid_map.cc
@@ -203,7 +203,8 @@ TEST_P(ProcSelfUidGidMapTest, IdentityMapOwnID) {
EXPECT_THAT(
InNewUserNamespaceWithMapFD([&](int fd) {
DenySelfSetgroups();
- TEST_PCHECK(write(fd, line.c_str(), line.size()) == line.size());
+ TEST_PCHECK(static_cast<long unsigned int>(
+ write(fd, line.c_str(), line.size())) == line.size());
}),
IsPosixErrorOkAndHolds(0));
}
@@ -220,7 +221,8 @@ TEST_P(ProcSelfUidGidMapTest, TrailingNewlineAndNULIgnored) {
DenySelfSetgroups();
// The write should return the full size of the write, even though
// characters after the NUL were ignored.
- TEST_PCHECK(write(fd, line.c_str(), line.size()) == line.size());
+ TEST_PCHECK(static_cast<long unsigned int>(
+ write(fd, line.c_str(), line.size())) == line.size());
}),
IsPosixErrorOkAndHolds(0));
}
@@ -233,7 +235,8 @@ TEST_P(ProcSelfUidGidMapTest, NonIdentityMapOwnID) {
EXPECT_THAT(
InNewUserNamespaceWithMapFD([&](int fd) {
DenySelfSetgroups();
- TEST_PCHECK(write(fd, line.c_str(), line.size()) == line.size());
+ TEST_PCHECK(static_cast<long unsigned int>(
+ write(fd, line.c_str(), line.size())) == line.size());
}),
IsPosixErrorOkAndHolds(0));
}
diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc
index 54709371c..955bcee4b 100644
--- a/test/syscalls/linux/raw_socket.cc
+++ b/test/syscalls/linux/raw_socket.cc
@@ -852,6 +852,51 @@ TEST(RawSocketTest, IPv6ProtoRaw) {
SyscallFailsWithErrno(EINVAL));
}
+TEST(RawSocketTest, IPv6SendMsg) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int sock;
+ ASSERT_THAT(sock = socket(AF_INET6, SOCK_RAW, IPPROTO_TCP),
+ SyscallSucceeds());
+
+ char kBuf[] = "hello";
+ struct iovec iov = {};
+ iov.iov_base = static_cast<void*>(const_cast<char*>(kBuf));
+ iov.iov_len = static_cast<size_t>(sizeof(kBuf));
+
+ struct sockaddr_storage addr = {};
+ struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&addr);
+ sin->sin_family = AF_INET;
+ sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+
+ struct msghdr msg = {};
+ msg.msg_name = static_cast<void*>(&addr);
+ msg.msg_namelen = sizeof(sockaddr_in);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = NULL;
+ msg.msg_controllen = 0;
+ msg.msg_flags = 0;
+ ASSERT_THAT(sendmsg(sock, &msg, 0), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(RawSocketTest, ConnectOnIPv6Socket) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int sock;
+ ASSERT_THAT(sock = socket(AF_INET6, SOCK_RAW, IPPROTO_TCP),
+ SyscallSucceeds());
+
+ struct sockaddr_storage addr = {};
+ struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&addr);
+ sin->sin_family = AF_INET;
+ sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+
+ ASSERT_THAT(connect(sock, reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(sockaddr_in6)),
+ SyscallFailsWithErrno(EAFNOSUPPORT));
+}
+
INSTANTIATE_TEST_SUITE_P(
AllInetTests, RawSocketTest,
::testing::Combine(::testing::Values(IPPROTO_TCP, IPPROTO_UDP),
diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc
index 890f4a246..0530fce44 100644
--- a/test/syscalls/linux/semaphore.cc
+++ b/test/syscalls/linux/semaphore.cc
@@ -20,6 +20,7 @@
#include <atomic>
#include <cerrno>
#include <ctime>
+#include <set>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -35,6 +36,17 @@ namespace gvisor {
namespace testing {
namespace {
+constexpr int kSemMap = 1024000000;
+constexpr int kSemMni = 32000;
+constexpr int kSemMns = 1024000000;
+constexpr int kSemMnu = 1024000000;
+constexpr int kSemMsl = 32000;
+constexpr int kSemOpm = 500;
+constexpr int kSemUme = 500;
+constexpr int kSemUsz = 20;
+constexpr int kSemVmx = 32767;
+constexpr int kSemAem = 32767;
+
class AutoSem {
public:
explicit AutoSem(int id) : id_(id) {}
@@ -586,7 +598,7 @@ TEST(SemaphoreTest, SemopGetzcnt) {
buf.sem_num = 0;
buf.sem_op = 0;
constexpr size_t kLoops = 10;
- for (auto i = 0; i < kLoops; i++) {
+ for (size_t i = 0; i < kLoops; i++) {
auto child_pid = fork();
if (child_pid == 0) {
TEST_PCHECK(RetryEINTR(semop)(sem.get(), &buf, 1) == 0);
@@ -693,7 +705,7 @@ TEST(SemaphoreTest, SemopGetncnt) {
buf.sem_num = 0;
buf.sem_op = -1;
constexpr size_t kLoops = 10;
- for (auto i = 0; i < kLoops; i++) {
+ for (size_t i = 0; i < kLoops; i++) {
auto child_pid = fork();
if (child_pid == 0) {
TEST_PCHECK(RetryEINTR(semop)(sem.get(), &buf, 1) == 0);
@@ -773,6 +785,151 @@ TEST(SemaphoreTest, SemopGetncntOnSignal_NoRandomSave) {
EXPECT_EQ(semctl(sem.get(), 0, GETNCNT), 0);
}
+TEST(SemaphoreTest, IpcInfo) {
+ constexpr int kLoops = 5;
+ std::set<int> sem_ids;
+ struct seminfo info;
+ // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false));
+ for (int i = 0; i < kLoops; i++) {
+ AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+ sem_ids.insert(sem.release());
+ }
+ ASSERT_EQ(sem_ids.size(), kLoops);
+
+ int max_used_index = 0;
+ EXPECT_THAT(max_used_index = semctl(0, 0, IPC_INFO, &info),
+ SyscallSucceeds());
+
+ std::set<int> sem_ids_before_max_index;
+ for (int i = 0; i <= max_used_index; i++) {
+ struct semid_ds ds = {};
+ int sem_id = semctl(i, 0, SEM_STAT, &ds);
+ // Only if index i is used within the registry.
+ if (sem_ids.find(sem_id) != sem_ids.end()) {
+ struct semid_ds ipc_stat_ds;
+ ASSERT_THAT(semctl(sem_id, 0, IPC_STAT, &ipc_stat_ds), SyscallSucceeds());
+ EXPECT_EQ(ds.sem_perm.__key, ipc_stat_ds.sem_perm.__key);
+ EXPECT_EQ(ds.sem_perm.uid, ipc_stat_ds.sem_perm.uid);
+ EXPECT_EQ(ds.sem_perm.gid, ipc_stat_ds.sem_perm.gid);
+ EXPECT_EQ(ds.sem_perm.cuid, ipc_stat_ds.sem_perm.cuid);
+ EXPECT_EQ(ds.sem_perm.cgid, ipc_stat_ds.sem_perm.cgid);
+ EXPECT_EQ(ds.sem_perm.mode, ipc_stat_ds.sem_perm.mode);
+ EXPECT_EQ(ds.sem_otime, ipc_stat_ds.sem_otime);
+ EXPECT_EQ(ds.sem_ctime, ipc_stat_ds.sem_ctime);
+ EXPECT_EQ(ds.sem_nsems, ipc_stat_ds.sem_nsems);
+
+ // Remove the semaphore set's read permission.
+ struct semid_ds ipc_set_ds;
+ ipc_set_ds.sem_perm.uid = getuid();
+ ipc_set_ds.sem_perm.gid = getgid();
+ // Keep the semaphore set's write permission so that it could be removed.
+ ipc_set_ds.sem_perm.mode = 0200;
+ ASSERT_THAT(semctl(sem_id, 0, IPC_SET, &ipc_set_ds), SyscallSucceeds());
+ ASSERT_THAT(semctl(i, 0, SEM_STAT, &ds), SyscallFailsWithErrno(EACCES));
+
+ sem_ids_before_max_index.insert(sem_id);
+ }
+ }
+ EXPECT_EQ(sem_ids_before_max_index.size(), kLoops);
+ for (const int sem_id : sem_ids) {
+ ASSERT_THAT(semctl(sem_id, 0, IPC_RMID), SyscallSucceeds());
+ }
+
+ ASSERT_THAT(semctl(0, 0, IPC_INFO, &info), SyscallSucceeds());
+ EXPECT_EQ(info.semmap, kSemMap);
+ EXPECT_EQ(info.semmni, kSemMni);
+ EXPECT_EQ(info.semmns, kSemMns);
+ EXPECT_EQ(info.semmnu, kSemMnu);
+ EXPECT_EQ(info.semmsl, kSemMsl);
+ EXPECT_EQ(info.semopm, kSemOpm);
+ EXPECT_EQ(info.semume, kSemUme);
+ EXPECT_EQ(info.semusz, kSemUsz);
+ EXPECT_EQ(info.semvmx, kSemVmx);
+ EXPECT_EQ(info.semaem, kSemAem);
+}
+
+TEST(SemaphoreTest, SemInfo) {
+ constexpr int kLoops = 5;
+ constexpr int kSemSetSize = 3;
+ std::set<int> sem_ids;
+ struct seminfo info;
+ // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false));
+ for (int i = 0; i < kLoops; i++) {
+ AutoSem sem(semget(IPC_PRIVATE, kSemSetSize, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+ sem_ids.insert(sem.release());
+ }
+ ASSERT_EQ(sem_ids.size(), kLoops);
+ int max_used_index = 0;
+ EXPECT_THAT(max_used_index = semctl(0, 0, SEM_INFO, &info),
+ SyscallSucceeds());
+ EXPECT_EQ(info.semmap, kSemMap);
+ EXPECT_EQ(info.semmni, kSemMni);
+ EXPECT_EQ(info.semmns, kSemMns);
+ EXPECT_EQ(info.semmnu, kSemMnu);
+ EXPECT_EQ(info.semmsl, kSemMsl);
+ EXPECT_EQ(info.semopm, kSemOpm);
+ EXPECT_EQ(info.semume, kSemUme);
+ // There could be semaphores existing in the system during the test, which
+ // prevents the test from getting an exact number, but the test can expect at
+ // least the number of semaphores it creates at the beginning of the test.
+ EXPECT_GE(info.semusz, sem_ids.size());
+ EXPECT_EQ(info.semvmx, kSemVmx);
+ EXPECT_GE(info.semaem, sem_ids.size() * kSemSetSize);
+
+ std::set<int> sem_ids_before_max_index;
+ for (int i = 0; i <= max_used_index; i++) {
+ struct semid_ds ds = {};
+ int sem_id = semctl(i, 0, SEM_STAT, &ds);
+ // Only if index i is used within the registry.
+ if (sem_ids.find(sem_id) != sem_ids.end()) {
+ struct semid_ds ipc_stat_ds;
+ ASSERT_THAT(semctl(sem_id, 0, IPC_STAT, &ipc_stat_ds), SyscallSucceeds());
+ EXPECT_EQ(ds.sem_perm.__key, ipc_stat_ds.sem_perm.__key);
+ EXPECT_EQ(ds.sem_perm.uid, ipc_stat_ds.sem_perm.uid);
+ EXPECT_EQ(ds.sem_perm.gid, ipc_stat_ds.sem_perm.gid);
+ EXPECT_EQ(ds.sem_perm.cuid, ipc_stat_ds.sem_perm.cuid);
+ EXPECT_EQ(ds.sem_perm.cgid, ipc_stat_ds.sem_perm.cgid);
+ EXPECT_EQ(ds.sem_perm.mode, ipc_stat_ds.sem_perm.mode);
+ EXPECT_EQ(ds.sem_otime, ipc_stat_ds.sem_otime);
+ EXPECT_EQ(ds.sem_ctime, ipc_stat_ds.sem_ctime);
+ EXPECT_EQ(ds.sem_nsems, ipc_stat_ds.sem_nsems);
+
+ // Remove the semaphore set's read permission.
+ struct semid_ds ipc_set_ds;
+ ipc_set_ds.sem_perm.uid = getuid();
+ ipc_set_ds.sem_perm.gid = getgid();
+ // Keep the semaphore set's write permission so that it could be removed.
+ ipc_set_ds.sem_perm.mode = 0200;
+ ASSERT_THAT(semctl(sem_id, 0, IPC_SET, &ipc_set_ds), SyscallSucceeds());
+ ASSERT_THAT(semctl(i, 0, SEM_STAT, &ds), SyscallFailsWithErrno(EACCES));
+
+ sem_ids_before_max_index.insert(sem_id);
+ }
+ }
+ EXPECT_EQ(sem_ids_before_max_index.size(), kLoops);
+ for (const int sem_id : sem_ids) {
+ ASSERT_THAT(semctl(sem_id, 0, IPC_RMID), SyscallSucceeds());
+ }
+
+ ASSERT_THAT(semctl(0, 0, SEM_INFO, &info), SyscallSucceeds());
+ EXPECT_EQ(info.semmap, kSemMap);
+ EXPECT_EQ(info.semmni, kSemMni);
+ EXPECT_EQ(info.semmns, kSemMns);
+ EXPECT_EQ(info.semmnu, kSemMnu);
+ EXPECT_EQ(info.semmsl, kSemMsl);
+ EXPECT_EQ(info.semopm, kSemOpm);
+ EXPECT_EQ(info.semume, kSemUme);
+ // Because semaphores not created by the test may also exist, we can't know
+ // the exact number of semaphore sets and semaphores; as a result, semusz and
+ // semaem range from 0 to an arbitrary number. Since the numbers are always
+ // non-negative, the test will not check the results of semusz and semaem.
+ EXPECT_EQ(info.semvmx, kSemVmx);
+}
+
} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/signalfd.cc b/test/syscalls/linux/signalfd.cc
index 389e5fca2..c86cd2755 100644
--- a/test/syscalls/linux/signalfd.cc
+++ b/test/syscalls/linux/signalfd.cc
@@ -126,7 +126,7 @@ TEST_P(SignalfdTest, Blocking) {
// Shared tid variable.
absl::Mutex mu;
- bool has_tid;
+ bool has_tid = false;
pid_t tid;
// Start a thread reading.
diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc
index e680d3dd7..32f583581 100644
--- a/test/syscalls/linux/socket.cc
+++ b/test/syscalls/linux/socket.cc
@@ -46,7 +46,7 @@ TEST(SocketTest, ProtocolUnix) {
{AF_UNIX, SOCK_SEQPACKET, PF_UNIX},
{AF_UNIX, SOCK_DGRAM, PF_UNIX},
};
- for (int i = 0; i < ABSL_ARRAYSIZE(tests); i++) {
+ for (long unsigned int i = 0; i < ABSL_ARRAYSIZE(tests); i++) {
ASSERT_NO_ERRNO_AND_VALUE(
Socket(tests[i].domain, tests[i].type, tests[i].protocol));
}
@@ -59,7 +59,7 @@ TEST(SocketTest, ProtocolInet) {
{AF_INET, SOCK_DGRAM, IPPROTO_UDP},
{AF_INET, SOCK_STREAM, IPPROTO_TCP},
};
- for (int i = 0; i < ABSL_ARRAYSIZE(tests); i++) {
+ for (long unsigned int i = 0; i < ABSL_ARRAYSIZE(tests); i++) {
ASSERT_NO_ERRNO_AND_VALUE(
Socket(tests[i].domain, tests[i].type, tests[i].protocol));
}
@@ -87,7 +87,7 @@ TEST(SocketTest, UnixSocketStat) {
ASSERT_THAT(stat(addr.sun_path, &statbuf), SyscallSucceeds());
// Mode should be S_IFSOCK.
- EXPECT_EQ(statbuf.st_mode, S_IFSOCK | sock_perm & ~mask);
+ EXPECT_EQ(statbuf.st_mode, S_IFSOCK | (sock_perm & ~mask));
// Timestamps should be equal and non-zero.
// TODO(b/158882152): Sockets currently don't implement timestamps.
diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc
index 5ed57625c..f8a0a80f2 100644
--- a/test/syscalls/linux/socket_bind_to_device_distribution.cc
+++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc
@@ -168,7 +168,7 @@ TEST_P(BindToDeviceDistributionTest, Tcp) {
std::vector<std::unique_ptr<ScopedThread>> listen_threads(
listener_fds.size());
- for (int i = 0; i < listener_fds.size(); i++) {
+ for (long unsigned int i = 0; i < listener_fds.size(); i++) {
listen_threads[i] = absl::make_unique<ScopedThread>(
[&listener_fds, &accept_counts, &connects_received, i,
kConnectAttempts]() {
@@ -204,7 +204,7 @@ TEST_P(BindToDeviceDistributionTest, Tcp) {
});
}
- for (int i = 0; i < kConnectAttempts; i++) {
+ for (int32_t i = 0; i < kConnectAttempts; i++) {
const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
ASSERT_THAT(
@@ -212,22 +212,8 @@ TEST_P(BindToDeviceDistributionTest, Tcp) {
connector.addr_len),
SyscallSucceeds());
- // Do two separate sends to ensure two segments are received. This is
- // required for netstack where read is incorrectly assuming a whole
- // segment is read when endpoint.Read() is called which is technically
- // incorrect as the syscall that invoked endpoint.Read() may only
- // consume it partially. This results in a case where a close() of
- // such a socket does not trigger a RST in netstack due to the
- // endpoint assuming that the endpoint has no unread data.
EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
SyscallSucceedsWithValue(sizeof(i)));
-
- // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly
- // generates a RST.
- if (IsRunningOnGvisor()) {
- EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
- SyscallSucceedsWithValue(sizeof(i)));
- }
}
// Join threads to be sure that all connections have been counted.
@@ -235,7 +221,7 @@ TEST_P(BindToDeviceDistributionTest, Tcp) {
listen_thread->Join();
}
// Check that connections are distributed correctly among listening sockets.
- for (int i = 0; i < accept_counts.size(); i++) {
+ for (long unsigned int i = 0; i < accept_counts.size(); i++) {
EXPECT_THAT(
accept_counts[i],
EquivalentWithin(static_cast<int>(kConnectAttempts *
@@ -308,7 +294,7 @@ TEST_P(BindToDeviceDistributionTest, Udp) {
std::vector<std::unique_ptr<ScopedThread>> receiver_threads(
listener_fds.size());
- for (int i = 0; i < listener_fds.size(); i++) {
+ for (long unsigned int i = 0; i < listener_fds.size(); i++) {
receiver_threads[i] = absl::make_unique<ScopedThread>(
[&listener_fds, &packets_per_socket, &packets_received, i]() {
do {
@@ -366,7 +352,7 @@ TEST_P(BindToDeviceDistributionTest, Udp) {
receiver_thread->Join();
}
// Check that packets are distributed correctly among listening sockets.
- for (int i = 0; i < packets_per_socket.size(); i++) {
+ for (long unsigned int i = 0; i < packets_per_socket.size(); i++) {
EXPECT_THAT(
packets_per_socket[i],
EquivalentWithin(static_cast<int>(kConnectAttempts *
diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc
index 70cc86b16..de0b8bb11 100644
--- a/test/syscalls/linux/socket_generic.cc
+++ b/test/syscalls/linux/socket_generic.cc
@@ -43,6 +43,15 @@ TEST_P(AllSocketPairTest, BasicReadWrite) {
EXPECT_EQ(data, absl::string_view(buf, 3));
}
+TEST_P(AllSocketPairTest, BasicReadWriteBadBuffer) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ const std::string data = "abc";
+ ASSERT_THAT(WriteFd(sockets->first_fd(), data.c_str(), 3),
+ SyscallSucceedsWithValue(3));
+ ASSERT_THAT(ReadFd(sockets->second_fd(), nullptr, 3),
+ SyscallFailsWithErrno(EFAULT));
+}
+
TEST_P(AllSocketPairTest, BasicSendRecv) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
char sent_data[512];
@@ -853,5 +862,21 @@ TEST_P(AllSocketPairTest, SetAndGetBooleanSocketOptions) {
}
}
+TEST_P(AllSocketPairTest, GetSocketOutOfBandInlineOption) {
+ // We do not support disabling this option. It is always enabled.
+ SKIP_IF(!IsRunningOnGvisor());
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ int enable = -1;
+ socklen_t enableLen = sizeof(enable);
+
+ int want = 1;
+ ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_OOBINLINE, &enable,
+ &enableLen),
+ SyscallSucceeds());
+ ASSERT_EQ(enableLen, sizeof(enable));
+ EXPECT_EQ(enable, want);
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 51b77ad85..a11147085 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -1507,7 +1507,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
}
ScopedThread connecting_thread([&connector, &conn_addr]() {
- for (int i = 0; i < kConnectAttempts; i++) {
+ for (int32_t i = 0; i < kConnectAttempts; i++) {
const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
ASSERT_THAT(
@@ -1515,22 +1515,8 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
connector.addr_len),
SyscallSucceeds());
- // Do two separate sends to ensure two segments are received. This is
- // required for netstack where read is incorrectly assuming a whole
- // segment is read when endpoint.Read() is called which is technically
- // incorrect as the syscall that invoked endpoint.Read() may only
- // consume it partially. This results in a case where a close() of
- // such a socket does not trigger a RST in netstack due to the
- // endpoint assuming that the endpoint has no unread data.
EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
SyscallSucceedsWithValue(sizeof(i)));
-
- // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly
- // generates a RST.
- if (IsRunningOnGvisor()) {
- EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
- SyscallSucceedsWithValue(sizeof(i)));
- }
}
});
diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc
index f69f8f99f..1694e188a 100644
--- a/test/syscalls/linux/socket_ip_udp_generic.cc
+++ b/test/syscalls/linux/socket_ip_udp_generic.cc
@@ -15,6 +15,9 @@
#include "test/syscalls/linux/socket_ip_udp_generic.h"
#include <errno.h>
+#ifdef __linux__
+#include <linux/in6.h>
+#endif // __linux__
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <poll.h>
@@ -356,6 +359,58 @@ TEST_P(UDPSocketPairTest, SetAndGetIPPKTINFO) {
EXPECT_EQ(get_len, sizeof(get));
}
+// Test getsockopt for a socket which is not set with IP_RECVORIGDSTADDR option.
+TEST_P(UDPSocketPairTest, ReceiveOrigDstAddrDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ int level = SOL_IP;
+ int type = IP_RECVORIGDSTADDR;
+ if (sockets->first_addr()->sa_family == AF_INET6) {
+ level = SOL_IPV6;
+ type = IPV6_RECVORIGDSTADDR;
+ }
+ ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+// Test setsockopt and getsockopt for a socket with IP_RECVORIGDSTADDR option.
+TEST_P(UDPSocketPairTest, SetAndGetReceiveOrigDstAddr) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int level = SOL_IP;
+ int type = IP_RECVORIGDSTADDR;
+ if (sockets->first_addr()->sa_family == AF_INET6) {
+ level = SOL_IPV6;
+ type = IPV6_RECVORIGDSTADDR;
+ }
+
+ // Storage for the option value read back with getsockopt below.
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOn);
+ EXPECT_EQ(get_len, sizeof(get));
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOff,
+ sizeof(kSockOptOff)),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOff);
+ EXPECT_EQ(get_len, sizeof(get));
+}
+
// Holds TOS or TClass information for IPv4 or IPv6 respectively.
struct RecvTosOption {
int level;
@@ -438,7 +493,7 @@ TEST_P(UDPSocketPairTest, TClassRecvMismatch) {
// This should only test AF_INET6 sockets for the mismatch behavior.
SKIP_IF(GetParam().domain != AF_INET6);
// IPV6_RECVTCLASS is only valid for SOCK_DGRAM and SOCK_RAW.
- SKIP_IF(GetParam().type != SOCK_DGRAM | GetParam().type != SOCK_RAW);
+ SKIP_IF((GetParam().type != SOCK_DGRAM) | (GetParam().type != SOCK_RAW));
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
diff --git a/test/syscalls/linux/socket_ip_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ip_udp_unbound_external_networking.cc
new file mode 100644
index 000000000..fdbb2216b
--- /dev/null
+++ b/test/syscalls/linux/socket_ip_udp_unbound_external_networking.cc
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_ip_udp_unbound_external_networking.h"
+
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+void IPUDPUnboundExternalNetworkingSocketTest::SetUp() {
+ // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its
+ // IPv4 address on eth0.
+ found_net_interfaces_ = false;
+
+ // Get interface list.
+ ASSERT_NO_ERRNO(if_helper_.Load());
+ std::vector<std::string> if_names = if_helper_.InterfaceList(AF_INET);
+ if (if_names.size() != 2) {
+ return;
+ }
+
+ // Figure out which interface is where.
+ std::string lo = if_names[0];
+ std::string eth = if_names[1];
+ if (lo != "lo") std::swap(lo, eth);
+ if (lo != "lo") return;
+
+ lo_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(lo));
+ auto lo_if_addr = if_helper_.GetAddr(AF_INET, lo);
+ if (lo_if_addr == nullptr) {
+ return;
+ }
+ lo_if_addr_ = *reinterpret_cast<const sockaddr_in*>(lo_if_addr);
+
+ eth_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(eth));
+ auto eth_if_addr = if_helper_.GetAddr(AF_INET, eth);
+ if (eth_if_addr == nullptr) {
+ return;
+ }
+ eth_if_addr_ = *reinterpret_cast<const sockaddr_in*>(eth_if_addr);
+
+ found_net_interfaces_ = true;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_udp_unbound_external_networking.h b/test/syscalls/linux/socket_ip_udp_unbound_external_networking.h
new file mode 100644
index 000000000..e5287addb
--- /dev/null
+++ b/test/syscalls/linux/socket_ip_udp_unbound_external_networking.h
@@ -0,0 +1,46 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IP_UDP_UNBOUND_EXTERNAL_NETWORKING_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IP_UDP_UNBOUND_EXTERNAL_NETWORKING_H_
+
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to unbound IP UDP sockets in a sandbox
+// with external networking support.
+class IPUDPUnboundExternalNetworkingSocketTest : public SimpleSocketTest {
+ protected:
+ void SetUp() override;
+
+ IfAddrHelper if_helper_;
+
+ // found_net_interfaces_ is true only if SetUp() obtained all of the
+ // interface information that the tests need.
+ bool found_net_interfaces_;
+
+ // Indices and IPv4 addresses of the two interfaces found by SetUp().
+ int lo_if_idx_;
+ int eth_if_idx_;
+ sockaddr_in lo_if_addr_;
+ sockaddr_in eth_if_addr_;
+};
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IP_UDP_UNBOUND_EXTERNAL_NETWORKING_H_
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
index b3f54e7f6..e557572a7 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
@@ -2222,6 +2222,90 @@ TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) {
EXPECT_EQ(received_pktinfo.ipi_addr.s_addr, htonl(INADDR_LOOPBACK));
}
+// Test that socket will receive IP_RECVORIGDSTADDR control message.
+TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPReceiveOrigDstAddr) {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver_addr = V4Loopback();
+ int level = SOL_IP;
+ int type = IP_RECVORIGDSTADDR;
+
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+
+ // Retrieve the port bound by the receiver.
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ ASSERT_THAT(
+ connect(sender->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+
+ // Get address and port bound by the sender.
+ sockaddr_storage sender_addr_storage;
+ socklen_t sender_addr_len = sizeof(sender_addr_storage);
+ ASSERT_THAT(getsockname(sender->get(),
+ reinterpret_cast<sockaddr*>(&sender_addr_storage),
+ &sender_addr_len),
+ SyscallSucceeds());
+ ASSERT_EQ(sender_addr_len, sizeof(struct sockaddr_in));
+
+ // Enable IP_RECVORIGDSTADDR on socket so that we get the original destination
+ // address of the datagram as auxiliary information in the control message.
+ ASSERT_THAT(
+ setsockopt(receiver->get(), level, type, &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Prepare message to send.
+ constexpr size_t kDataLength = 1024;
+ msghdr sent_msg = {};
+ iovec sent_iov = {};
+ char sent_data[kDataLength];
+ sent_iov.iov_base = sent_data;
+ sent_iov.iov_len = kDataLength;
+ sent_msg.msg_iov = &sent_iov;
+ sent_msg.msg_iovlen = 1;
+ sent_msg.msg_flags = 0;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(sender->get(), &sent_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ msghdr received_msg = {};
+ iovec received_iov = {};
+ char received_data[kDataLength];
+ char received_cmsg_buf[CMSG_SPACE(sizeof(sockaddr_in))] = {};
+ size_t cmsg_data_len = sizeof(sockaddr_in);
+ received_iov.iov_base = received_data;
+ received_iov.iov_len = kDataLength;
+ received_msg.msg_iov = &received_iov;
+ received_msg.msg_iovlen = 1;
+ received_msg.msg_controllen = CMSG_LEN(cmsg_data_len);
+ received_msg.msg_control = received_cmsg_buf;
+
+ ASSERT_THAT(RecvMsgTimeout(receiver->get(), &received_msg, 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(kDataLength));
+
+ cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len));
+ EXPECT_EQ(cmsg->cmsg_level, level);
+ EXPECT_EQ(cmsg->cmsg_type, type);
+
+ // Check the data
+ sockaddr_in received_addr = {};
+ memcpy(&received_addr, CMSG_DATA(cmsg), sizeof(received_addr));
+ auto orig_receiver_addr = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr);
+ EXPECT_EQ(received_addr.sin_addr.s_addr, orig_receiver_addr->sin_addr.s_addr);
+ EXPECT_EQ(received_addr.sin_port, orig_receiver_addr->sin_port);
+}
+
// Check that setting SO_RCVBUF below min is clamped to the minimum
// receive buffer size.
TEST_P(IPv4UDPUnboundSocketTest, SetSocketRecvBufBelowMin) {
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
index 2eecb0866..940289d15 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
@@ -14,23 +14,6 @@
#include "test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h"
-#include <arpa/inet.h>
-#include <ifaddrs.h>
-#include <netinet/in.h>
-#include <sys/socket.h>
-#include <sys/types.h>
-#include <sys/un.h>
-
-#include <cstdint>
-#include <cstdio>
-#include <cstring>
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "test/syscalls/linux/ip_socket_test_util.h"
-#include "test/syscalls/linux/socket_test_util.h"
-#include "test/util/test_util.h"
-
namespace gvisor {
namespace testing {
@@ -41,41 +24,6 @@ TestAddress V4EmptyAddress() {
return t;
}
-void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() {
- // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its
- // IPv4 address on eth0.
- found_net_interfaces_ = false;
-
- // Get interface list.
- ASSERT_NO_ERRNO(if_helper_.Load());
- std::vector<std::string> if_names = if_helper_.InterfaceList(AF_INET);
- if (if_names.size() != 2) {
- return;
- }
-
- // Figure out which interface is where.
- std::string lo = if_names[0];
- std::string eth = if_names[1];
- if (lo != "lo") std::swap(lo, eth);
- if (lo != "lo") return;
-
- lo_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(lo));
- auto lo_if_addr = if_helper_.GetAddr(AF_INET, lo);
- if (lo_if_addr == nullptr) {
- return;
- }
- lo_if_addr_ = *reinterpret_cast<const sockaddr_in*>(lo_if_addr);
-
- eth_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(eth));
- auto eth_if_addr = if_helper_.GetAddr(AF_INET, eth);
- if (eth_if_addr == nullptr) {
- return;
- }
- eth_if_addr_ = *reinterpret_cast<const sockaddr_in*>(eth_if_addr);
-
- found_net_interfaces_ = true;
-}
-
// Verifies that a broadcast UDP packet will arrive at all UDP sockets with
// the destination port number.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h
index 0e9e70e8e..20922ac1f 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h
@@ -15,30 +15,15 @@
#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_EXTERNAL_NETWORKING_H_
#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_EXTERNAL_NETWORKING_H_
-#include "test/syscalls/linux/ip_socket_test_util.h"
-#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/socket_ip_udp_unbound_external_networking.h"
namespace gvisor {
namespace testing {
// Test fixture for tests that apply to unbound IPv4 UDP sockets in a sandbox
// with external networking support.
-class IPv4UDPUnboundExternalNetworkingSocketTest : public SimpleSocketTest {
- protected:
- void SetUp();
-
- IfAddrHelper if_helper_;
-
- // found_net_interfaces_ is set to false if SetUp() could not obtain
- // all interface infos that we need.
- bool found_net_interfaces_;
-
- // Interface infos.
- int lo_if_idx_;
- int eth_if_idx_;
- sockaddr_in lo_if_addr_;
- sockaddr_in eth_if_addr_;
-};
+using IPv4UDPUnboundExternalNetworkingSocketTest =
+ IPUDPUnboundExternalNetworkingSocketTest;
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc
index 875016812..9a9ddc297 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc
@@ -177,7 +177,7 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, ReuseAddrSubnetDirectedBroadcast) {
// Broadcasts from each socket should be received by every socket (including
// the sending socket).
- for (int w = 0; w < socks.size(); w++) {
+ for (long unsigned int w = 0; w < socks.size(); w++) {
auto& w_sock = socks[w];
ASSERT_THAT(
RetryEINTR(sendto)(w_sock->get(), send_buf, kSendBufSize, 0,
@@ -187,7 +187,7 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, ReuseAddrSubnetDirectedBroadcast) {
<< "write socks[" << w << "]";
// Check that we received the packet on all sockets.
- for (int r = 0; r < socks.size(); r++) {
+ for (long unsigned int r = 0; r < socks.size(); r++) {
auto& r_sock = socks[r];
struct pollfd poll_fd = {r_sock->get(), POLLIN, 0};
diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound.cc b/test/syscalls/linux/socket_ipv6_udp_unbound.cc
new file mode 100644
index 000000000..08526468e
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv6_udp_unbound.cc
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_ipv6_udp_unbound.h"
+
+#include <arpa/inet.h>
+#include <netinet/in.h>
+#ifdef __linux__
+#include <linux/in6.h>
+#endif // __linux__
+#include <net/if.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <cstdio>
+#include <cstring>
+
+#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test that socket will receive IPV6_RECVORIGDSTADDR control message.
+TEST_P(IPv6UDPUnboundSocketTest, SetAndReceiveIPReceiveOrigDstAddr) {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver_addr = V6Loopback();
+ int level = SOL_IPV6;
+ int type = IPV6_RECVORIGDSTADDR;
+
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+
+ // Retrieve the port bound by the receiver.
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ ASSERT_THAT(
+ connect(sender->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+
+ // Get address and port bound by the sender.
+ sockaddr_storage sender_addr_storage;
+ socklen_t sender_addr_len = sizeof(sender_addr_storage);
+ ASSERT_THAT(getsockname(sender->get(),
+ reinterpret_cast<sockaddr*>(&sender_addr_storage),
+ &sender_addr_len),
+ SyscallSucceeds());
+ ASSERT_EQ(sender_addr_len, sizeof(struct sockaddr_in6));
+
+ // Enable IPV6_RECVORIGDSTADDR so that we get the original destination
+ // address of the datagram as auxiliary information in the control message.
+ ASSERT_THAT(
+ setsockopt(receiver->get(), level, type, &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Prepare message to send.
+ constexpr size_t kDataLength = 1024;
+ msghdr sent_msg = {};
+ iovec sent_iov = {};
+ char sent_data[kDataLength];
+ sent_iov.iov_base = sent_data;
+ sent_iov.iov_len = kDataLength;
+ sent_msg.msg_iov = &sent_iov;
+ sent_msg.msg_iovlen = 1;
+ sent_msg.msg_flags = 0;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(sender->get(), &sent_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ msghdr received_msg = {};
+ iovec received_iov = {};
+ char received_data[kDataLength];
+ char received_cmsg_buf[CMSG_SPACE(sizeof(sockaddr_in6))] = {};
+ size_t cmsg_data_len = sizeof(sockaddr_in6);
+ received_iov.iov_base = received_data;
+ received_iov.iov_len = kDataLength;
+ received_msg.msg_iov = &received_iov;
+ received_msg.msg_iovlen = 1;
+ received_msg.msg_controllen = CMSG_LEN(cmsg_data_len);
+ received_msg.msg_control = received_cmsg_buf;
+
+ ASSERT_THAT(RecvMsgTimeout(receiver->get(), &received_msg, 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(kDataLength));
+
+ cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len));
+ EXPECT_EQ(cmsg->cmsg_level, level);
+ EXPECT_EQ(cmsg->cmsg_type, type);
+
+ // Check that the received address in the control message matches the expected
+ // receiver's address.
+ sockaddr_in6 received_addr = {};
+ memcpy(&received_addr, CMSG_DATA(cmsg), sizeof(received_addr));
+ auto orig_receiver_addr =
+ reinterpret_cast<sockaddr_in6*>(&receiver_addr.addr);
+ EXPECT_EQ(memcmp(&received_addr.sin6_addr, &orig_receiver_addr->sin6_addr,
+ sizeof(in6_addr)),
+ 0);
+ EXPECT_EQ(received_addr.sin6_port, orig_receiver_addr->sin6_port);
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound.h b/test/syscalls/linux/socket_ipv6_udp_unbound.h
new file mode 100644
index 000000000..71e160f73
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv6_udp_unbound.h
@@ -0,0 +1,29 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_H_
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to IPv6 UDP sockets.
+using IPv6UDPUnboundSocketTest = SimpleSocketTest;
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_H_
diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc
new file mode 100644
index 000000000..7364a1ea5
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc
@@ -0,0 +1,90 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.h"
+
+namespace gvisor {
+namespace testing {
+
+TEST_P(IPv6UDPUnboundExternalNetworkingSocketTest, TestJoinLeaveMulticast) {
+ SKIP_IF(!found_net_interfaces_);
+
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ auto receiver_addr = V6Any();
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register to receive multicast packets.
+ auto multicast_addr = V6Multicast();
+ ipv6_mreq group_req = {
+ .ipv6mr_multiaddr =
+ reinterpret_cast<sockaddr_in6*>(&multicast_addr.addr)->sin6_addr,
+ .ipv6mr_interface =
+ (unsigned int)ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")),
+ };
+ ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP,
+ &group_req, sizeof(group_req)),
+ SyscallSucceeds());
+
+ // Set the sender to the loopback interface.
+ auto sender_addr = V6Loopback();
+ ASSERT_THAT(
+ bind(sender->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ sender_addr.addr_len),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto send_addr = multicast_addr;
+ reinterpret_cast<sockaddr_in6*>(&send_addr.addr)->sin6_port =
+ reinterpret_cast<sockaddr_in6*>(&receiver_addr.addr)->sin6_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet.
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+
+ // Leave the group and make sure we don't receive its multicast traffic.
+ ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IPV6, IPV6_DROP_MEMBERSHIP,
+ &group_req, sizeof(group_req)),
+ SyscallSucceeds());
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.h b/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.h
new file mode 100644
index 000000000..731ae0a1f
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.h
@@ -0,0 +1,31 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_EXTERNAL_NETWORKING_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_EXTERNAL_NETWORKING_H_
+
+#include "test/syscalls/linux/socket_ip_udp_unbound_external_networking.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to unbound IPv6 UDP sockets in a sandbox
+// with external networking support.
+using IPv6UDPUnboundExternalNetworkingSocketTest =
+ IPUDPUnboundExternalNetworkingSocketTest;
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_EXTERNAL_NETWORKING_H_
diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking_test.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking_test.cc
new file mode 100644
index 000000000..5c764b8fd
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking_test.cc
@@ -0,0 +1,39 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.h"
+
+#include <vector>
+
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+std::vector<SocketKind> GetSockets() {
+ return ApplyVec<SocketKind>(
+ IPv6UDPUnboundSocket,
+ AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK}));
+}
+
+INSTANTIATE_TEST_SUITE_P(IPv6UDPUnboundSockets,
+ IPv6UDPUnboundExternalNetworkingSocketTest,
+ ::testing::ValuesIn(GetSockets()));
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_loopback.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_loopback.cc
new file mode 100644
index 000000000..058336ecc
--- /dev/null
+++ b/test/syscalls/linux/socket_ipv6_udp_unbound_loopback.cc
@@ -0,0 +1,32 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_ipv6_udp_unbound.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+INSTANTIATE_TEST_SUITE_P(
+ IPv6UDPSockets, IPv6UDPUnboundSocketTest,
+ ::testing::ValuesIn(ApplyVec<SocketKind>(IPv6UDPUnboundSocket,
+ AllBitwiseCombinations(List<int>{
+ 0, SOCK_NONBLOCK}))));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc
index a760581b5..26dacc95e 100644
--- a/test/syscalls/linux/socket_test_util.cc
+++ b/test/syscalls/linux/socket_test_util.cc
@@ -860,6 +860,17 @@ TestAddress V6Loopback() {
return t;
}
+TestAddress V6Multicast() {
+ TestAddress t("V6Multicast");
+ t.addr.ss_family = AF_INET6;
+ t.addr_len = sizeof(sockaddr_in6);
+ EXPECT_EQ(
+ 1,
+ inet_pton(AF_INET6, "ff05::1234",
+ reinterpret_cast<sockaddr_in6*>(&t.addr)->sin6_addr.s6_addr));
+ return t;
+}
+
// Checksum computes the internet checksum of a buffer.
uint16_t Checksum(uint16_t* buf, ssize_t buf_size) {
// Add up the 16-bit values in the buffer.
diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h
index 5e205339f..75c0d4735 100644
--- a/test/syscalls/linux/socket_test_util.h
+++ b/test/syscalls/linux/socket_test_util.h
@@ -502,6 +502,7 @@ TestAddress V4MappedLoopback();
TestAddress V4Multicast();
TestAddress V6Any();
TestAddress V6Loopback();
+TestAddress V6Multicast();
// Compute the internet checksum of an IP header.
uint16_t IPChecksum(struct iphdr ip);
diff --git a/test/syscalls/linux/socket_unix_unbound_filesystem.cc b/test/syscalls/linux/socket_unix_unbound_filesystem.cc
index cab912152..a035fb095 100644
--- a/test/syscalls/linux/socket_unix_unbound_filesystem.cc
+++ b/test/syscalls/linux/socket_unix_unbound_filesystem.cc
@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <fcntl.h>
#include <stdio.h>
#include <sys/un.h>
#include "gtest/gtest.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/file_descriptor.h"
#include "test/util/test_util.h"
namespace gvisor {
@@ -70,6 +72,20 @@ TEST_P(UnboundFilesystemUnixSocketPairTest, GetSockNameLength) {
strlen(want_addr.sun_path) + 1 + sizeof(want_addr.sun_family));
}
+TEST_P(UnboundFilesystemUnixSocketPairTest, OpenSocketWithTruncate) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ const struct sockaddr_un *addr =
+ reinterpret_cast<const struct sockaddr_un *>(sockets->first_addr());
+ EXPECT_THAT(chmod(addr->sun_path, 0777), SyscallSucceeds());
+ EXPECT_THAT(open(addr->sun_path, O_RDONLY | O_TRUNC),
+ SyscallFailsWithErrno(ENXIO));
+}
+
INSTANTIATE_TEST_SUITE_P(
AllUnixDomainSockets, UnboundFilesystemUnixSocketPairTest,
::testing::ValuesIn(ApplyVec<SocketPairKind>(
diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc
index 97d554e72..538652183 100644
--- a/test/syscalls/linux/tuntap.cc
+++ b/test/syscalls/linux/tuntap.cc
@@ -162,12 +162,19 @@ TEST(TuntapStaticTest, NetTunExists) {
class TuntapTest : public ::testing::Test {
protected:
+ void SetUp() override {
+ have_net_admin_cap_ =
+ ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN));
+ }
+
void TearDown() override {
- if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))) {
+ if (have_net_admin_cap_) {
// Bring back capability if we had dropped it in test case.
ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, true));
}
}
+
+ bool have_net_admin_cap_ = false;  // Defined even if SetUp() bails out early.
};
TEST_F(TuntapTest, CreateInterfaceNoCap) {
@@ -324,8 +331,9 @@ TEST_F(TuntapTest, PingKernel) {
};
while (1) {
inpkt r = {};
- int n = read(fd.get(), &r, sizeof(r));
- EXPECT_THAT(n, SyscallSucceeds());
+ int nread = read(fd.get(), &r, sizeof(r));
+ EXPECT_THAT(nread, SyscallSucceeds());
+ long unsigned int n = static_cast<long unsigned int>(nread);
if (n < sizeof(pihdr)) {
std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol
@@ -383,8 +391,9 @@ TEST_F(TuntapTest, SendUdpTriggersArpResolution) {
};
while (1) {
inpkt r = {};
- int n = read(fd.get(), &r, sizeof(r));
- EXPECT_THAT(n, SyscallSucceeds());
+ int nread = read(fd.get(), &r, sizeof(r));
+ EXPECT_THAT(nread, SyscallSucceeds());
+ long unsigned int n = static_cast<long unsigned int>(nread);
if (n < sizeof(pihdr)) {
std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol
diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc
index 34255bfb8..650f12350 100644
--- a/test/syscalls/linux/udp_socket.cc
+++ b/test/syscalls/linux/udp_socket.cc
@@ -14,6 +14,8 @@
#include <arpa/inet.h>
#include <fcntl.h>
+#include <netinet/icmp6.h>
+#include <netinet/ip_icmp.h>
#include <ctime>
@@ -375,8 +377,6 @@ TEST_P(UdpSocketTest, BindInUse) {
}
TEST_P(UdpSocketTest, ConnectWriteToInvalidPort) {
- ASSERT_NO_ERRNO(BindLoopback());
-
// Discover a free unused port by creating a new UDP socket, binding it
// recording the just bound port and closing it. This is not guaranteed as it
// can still race with other port UDP sockets trying to bind a port at the
@@ -410,6 +410,35 @@ TEST_P(UdpSocketTest, ConnectWriteToInvalidPort) {
ASSERT_EQ(optlen, sizeof(err));
}
+TEST_P(UdpSocketTest, ConnectSimultaneousWriteToInvalidPort) {
+ // Discover a free unused port by creating a new UDP socket, binding it,
+ // recording the just-bound port, and closing it. This is not guaranteed, as
+ // it can still race with other UDP sockets trying to bind a port at the same
+ // time.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ socklen_t addrlen = sizeof(addr_storage);
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP));
+ ASSERT_THAT(bind(s.get(), addr, addrlen), SyscallSucceeds());
+ ASSERT_THAT(getsockname(s.get(), addr, &addrlen), SyscallSucceeds());
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_NE(*Port(&addr_storage), 0);
+ ASSERT_THAT(close(s.release()), SyscallSucceeds());
+
+ // Now connect to the port that we just released.
+ ScopedThread t([&] {
+ ASSERT_THAT(connect(sock_.get(), addr, addrlen_), SyscallSucceeds());
+ });
+
+ char buf[512];
+ RandomizeBuffer(buf, sizeof(buf));
+ // Send from sock_ to an unbound port.
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, addr, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ t.Join();
+}
+
TEST_P(UdpSocketTest, ReceiveAfterConnect) {
ASSERT_NO_ERRNO(BindLoopback());
ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
@@ -752,6 +781,99 @@ TEST_P(UdpSocketTest, ConnectAndSendNoReceiver) {
SyscallFailsWithErrno(ECONNREFUSED));
}
+#ifdef __linux__
+TEST_P(UdpSocketTest, RecvErrorConnRefused) {
+ // We will simulate an ICMP error and verify that we do receive that error via
+ // recvmsg(MSG_ERRQUEUE).
+ ASSERT_NO_ERRNO(BindLoopback());
+ // Close the socket to release the port so that we get an ICMP error.
+ ASSERT_THAT(close(bind_.release()), SyscallSucceeds());
+
+ // Set IP_RECVERR socket option to enable error queueing.
+ int v = kSockOptOn;
+ socklen_t optlen = sizeof(v);
+ int opt_level = SOL_IP;
+ int opt_type = IP_RECVERR;
+ if (GetParam() != AddressFamily::kIpv4) {
+ opt_level = SOL_IPV6;
+ opt_type = IPV6_RECVERR;
+ }
+ ASSERT_THAT(setsockopt(sock_.get(), opt_level, opt_type, &v, optlen),
+ SyscallSucceeds());
+
+ // Connect to loopback:bind_addr_ which should *hopefully* not be bound by a
+ // UDP socket. There is no easy way to ensure that the UDP port is not bound
+ // by another concurrently running test. *This is potentially flaky*.
+ const int kBufLen = 300;
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+ char buf[kBufLen];
+ RandomizeBuffer(buf, sizeof(buf));
+ // Send from sock_ to an unbound port. This should cause ECONNREFUSED.
+ EXPECT_THAT(send(sock_.get(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Dequeue error using recvmsg(MSG_ERRQUEUE).
+ char got[kBufLen];
+ struct iovec iov;
+ iov.iov_base = reinterpret_cast<void*>(got);
+ iov.iov_len = kBufLen;
+
+ size_t control_buf_len = CMSG_SPACE(sizeof(sock_extended_err) + addrlen_);
+ char* control_buf = static_cast<char*>(calloc(1, control_buf_len));
+ struct sockaddr_storage remote;
+ memset(&remote, 0, sizeof(remote));
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_flags = 0;
+ msg.msg_control = control_buf;
+ msg.msg_controllen = control_buf_len;
+ msg.msg_name = reinterpret_cast<void*>(&remote);
+ msg.msg_namelen = addrlen_;
+ ASSERT_THAT(recvmsg(sock_.get(), &msg, MSG_ERRQUEUE),
+ SyscallSucceedsWithValue(kBufLen));
+
+ // Check the contents of msg.
+ EXPECT_EQ(memcmp(got, buf, sizeof(buf)), 0); // iovec check
+ // TODO(b/176251997): The next check fails on the gvisor platform due to a
+ // kernel bug.
+ if (!IsRunningWithHostinet() || GvisorPlatform() == Platform::kPtrace ||
+ GvisorPlatform() == Platform::kKVM ||
+ GvisorPlatform() == Platform::kNative)
+ EXPECT_NE(msg.msg_flags & MSG_ERRQUEUE, 0);
+ EXPECT_EQ(memcmp(&remote, bind_addr_, addrlen_), 0);
+
+ // Check the contents of the control message.
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(CMSG_NXTHDR(&msg, cmsg), nullptr);
+ EXPECT_EQ(cmsg->cmsg_level, opt_level);
+ EXPECT_EQ(cmsg->cmsg_type, opt_type);
+
+ // Check the contents of socket error.
+ struct sock_extended_err* sock_err =
+ (struct sock_extended_err*)CMSG_DATA(cmsg);
+ EXPECT_EQ(sock_err->ee_errno, ECONNREFUSED);
+ if (GetParam() == AddressFamily::kIpv4) {
+ EXPECT_EQ(sock_err->ee_origin, SO_EE_ORIGIN_ICMP);
+ EXPECT_EQ(sock_err->ee_type, ICMP_DEST_UNREACH);
+ EXPECT_EQ(sock_err->ee_code, ICMP_PORT_UNREACH);
+ } else {
+ EXPECT_EQ(sock_err->ee_origin, SO_EE_ORIGIN_ICMP6);
+ EXPECT_EQ(sock_err->ee_type, ICMP6_DST_UNREACH);
+ EXPECT_EQ(sock_err->ee_code, ICMP6_DST_UNREACH_NOPORT);
+ }
+
+ // Now verify that the socket error was cleared by recvmsg(MSG_ERRQUEUE).
+ int err;
+ optlen = sizeof(err);
+ ASSERT_THAT(getsockopt(sock_.get(), SOL_SOCKET, SO_ERROR, &err, &optlen),
+ SyscallSucceeds());
+ ASSERT_EQ(err, 0);
+ ASSERT_EQ(optlen, sizeof(err));
+}
+#endif // __linux__
+
TEST_P(UdpSocketTest, ZerolengthWriteAllowed) {
// TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes.
SKIP_IF(IsRunningWithHostinet());
diff --git a/tools/bazel.mk b/tools/bazel.mk
index 9063eebaf..7e06d09be 100644
--- a/tools/bazel.mk
+++ b/tools/bazel.mk
@@ -14,49 +14,77 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# Make hacks.
-EMPTY :=
-SPACE := $(EMPTY) $(EMPTY)
+##
+## Docker options.
+##
+## This file supports targets that wrap bazel in a running Docker
+## container to simplify development. Some options are available to
+## control the behavior of this container:
+##
+## USER - The in-container user.
+## DOCKER_RUN_OPTIONS - Options for the container (default: --privileged, required for tests).
+## DOCKER_NAME - The container name (default: gvisor-bazel-HASH).
+## DOCKER_PRIVILEGED - Docker privileged flags (default: --privileged).
+## BAZEL_CACHE - The bazel cache directory (default: detected).
+## GCLOUD_CONFIG - The gcloud config directory (default: detected).
+## DOCKER_SOCKET - The Docker socket (default: detected).
+##
+## To opt out of these wrappers, set DOCKER_BUILD=false.
+DOCKER_BUILD := true
+ifeq ($(DOCKER_BUILD),true)
+-include bazel-server
+endif
# See base Makefile.
-SHELL=/bin/bash -o pipefail
BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \
- git rev-parse --abbrev-ref HEAD 2>/dev/null) | \
- xargs -n 1 basename 2>/dev/null)
+ git rev-parse --abbrev-ref HEAD 2>/dev/null) | \
+ xargs -n 1 basename 2>/dev/null)
BUILD_ROOTS := bazel-bin/ bazel-out/
# Bazel container configuration (see below).
USER := $(shell whoami)
HASH := $(shell readlink -m $(CURDIR) | md5sum | cut -c1-8)
-BUILDER_BASE := gvisor.dev/images/default
-BUILDER_IMAGE := gvisor.dev/images/builder
-BUILDER_NAME := gvisor-builder-$(HASH)
-DOCKER_NAME := gvisor-bazel-$(HASH)
+BUILDER_NAME := gvisor-builder-$(HASH)-$(ARCH)
+DOCKER_NAME := gvisor-bazel-$(HASH)-$(ARCH)
DOCKER_PRIVILEGED := --privileged
-BAZEL_CACHE := $(shell readlink -m ~/.cache/bazel/)
-GCLOUD_CONFIG := $(shell readlink -m ~/.config/gcloud/)
+BAZEL_CACHE := $(HOME)/.cache/bazel/
+GCLOUD_CONFIG := $(HOME)/.config/gcloud/
DOCKER_SOCKET := /var/run/docker.sock
DOCKER_CONFIG := /etc/docker
-# Bazel flags.
-BAZEL := bazel $(STARTUP_OPTIONS)
-BASE_OPTIONS := --color=no --curses=no
+##
+## Bazel helpers.
+##
+## Bazel will be run with standard flags. You can specify the following flags
+## to control which flags are passed:
+##
+## STARTUP_OPTIONS - Startup options passed to Bazel.
+##
+STARTUP_OPTIONS :=
+BAZEL_OPTIONS :=
+BAZEL := bazel $(STARTUP_OPTIONS)
+BASE_OPTIONS := --color=no --curses=no
+TEST_OPTIONS := $(BASE_OPTIONS) \
+ --test_output=errors \
+ --keep_going \
+ --verbose_failures=true \
+ --build_event_json_file=.build_events.json
# Basic options.
UID := $(shell id -u ${USER})
GID := $(shell id -g ${USER})
USERADD_OPTIONS :=
-FULL_DOCKER_RUN_OPTIONS := $(DOCKER_RUN_OPTIONS)
-FULL_DOCKER_RUN_OPTIONS += --user $(UID):$(GID)
-FULL_DOCKER_RUN_OPTIONS += --entrypoint ""
-FULL_DOCKER_RUN_OPTIONS += --init
-FULL_DOCKER_RUN_OPTIONS += -v "$(BAZEL_CACHE):$(BAZEL_CACHE)"
-FULL_DOCKER_RUN_OPTIONS += -v "$(GCLOUD_CONFIG):$(GCLOUD_CONFIG)"
-FULL_DOCKER_RUN_OPTIONS += -v "/tmp:/tmp"
-FULL_DOCKER_EXEC_OPTIONS := --user $(UID):$(GID)
-FULL_DOCKER_EXEC_OPTIONS += --interactive
-ifeq (true,$(shell [[ -t 0 ]] && echo true))
-FULL_DOCKER_EXEC_OPTIONS += --tty
+DOCKER_RUN_OPTIONS :=
+DOCKER_RUN_OPTIONS += --user $(UID):$(GID)
+DOCKER_RUN_OPTIONS += --entrypoint ""
+DOCKER_RUN_OPTIONS += --init
+DOCKER_RUN_OPTIONS += -v "$(shell readlink -m $(BAZEL_CACHE)):$(BAZEL_CACHE)"
+DOCKER_RUN_OPTIONS += -v "$(shell readlink -m $(GCLOUD_CONFIG)):$(GCLOUD_CONFIG)"
+DOCKER_RUN_OPTIONS += -v "/tmp:/tmp"
+DOCKER_EXEC_OPTIONS := --user $(UID):$(GID)
+DOCKER_EXEC_OPTIONS += --interactive
+ifeq (true,$(shell test -t 0 && echo true))
+DOCKER_EXEC_OPTIONS += --tty
endif
# Add basic UID/GID options.
@@ -80,91 +108,75 @@ endif
# Add docker passthrough options.
ifneq ($(DOCKER_PRIVILEGED),)
-FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_SOCKET):$(DOCKER_SOCKET)"
-FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_CONFIG):$(DOCKER_CONFIG)"
-FULL_DOCKER_RUN_OPTIONS += $(DOCKER_PRIVILEGED)
-FULL_DOCKER_EXEC_OPTIONS += $(DOCKER_PRIVILEGED)
+DOCKER_RUN_OPTIONS += -v "$(DOCKER_SOCKET):$(DOCKER_SOCKET)"
+DOCKER_RUN_OPTIONS += -v "$(DOCKER_CONFIG):$(DOCKER_CONFIG)"
+DOCKER_RUN_OPTIONS += $(DOCKER_PRIVILEGED)
+DOCKER_EXEC_OPTIONS += $(DOCKER_PRIVILEGED)
DOCKER_GROUP := $(shell stat -c '%g' $(DOCKER_SOCKET))
ifneq ($(GID),$(DOCKER_GROUP))
USERADD_OPTIONS += --groups $(DOCKER_GROUP)
GROUPADD_DOCKER += groupadd --gid $(DOCKER_GROUP) --non-unique docker-$(HASH) &&
-FULL_DOCKER_RUN_OPTIONS += --group-add $(DOCKER_GROUP)
+DOCKER_RUN_OPTIONS += --group-add $(DOCKER_GROUP)
endif
endif
# Add KVM passthrough options.
ifneq (,$(wildcard /dev/kvm))
-FULL_DOCKER_RUN_OPTIONS += --device=/dev/kvm
+DOCKER_RUN_OPTIONS += --device=/dev/kvm
KVM_GROUP := $(shell stat -c '%g' /dev/kvm)
ifneq ($(GID),$(KVM_GROUP))
USERADD_OPTIONS += --groups $(KVM_GROUP)
GROUPADD_DOCKER += groupadd --gid $(KVM_GROUP) --non-unique kvm-$(HASH) &&
-FULL_DOCKER_RUN_OPTIONS += --group-add $(KVM_GROUP)
+DOCKER_RUN_OPTIONS += --group-add $(KVM_GROUP)
endif
endif
-# Load the appropriate config.
-ifneq (,$(BAZEL_CONFIG))
-OPTIONS += --config=$(BAZEL_CONFIG)
+# Top-level functions.
+#
+# This command runs a bazel server, and the container sticks around
+# until the bazel server exits. This should ensure that it does not
+# exit in the middle of running a build, but also it won't stick around
+# forever. The build commands wrap around an appropriate exec into the
+# container in order to perform work via the bazel client.
+ifeq ($(DOCKER_BUILD),true)
+wrapper = docker exec $(DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(1)
+else
+wrapper = $(1)
endif
-bazel-image: load-default
- @if docker ps --all | grep $(BUILDER_NAME); then docker rm -f $(BUILDER_NAME); fi
- docker run --user 0:0 --entrypoint "" --name $(BUILDER_NAME) \
- $(BUILDER_BASE) \
- sh -c "$(GROUPADD_DOCKER) \
- $(USERADD_DOCKER) \
- if [[ -e /dev/kvm ]]; then chmod a+rw /dev/kvm; fi"
- docker commit $(BUILDER_NAME) $(BUILDER_IMAGE)
- @docker rm -f $(BUILDER_NAME)
-.PHONY: bazel-image
-
-##
-## Bazel helpers.
-##
-## This file supports targets that wrap bazel in a running Docker
-## container to simplify development. Some options are available to
-## control the behavior of this container:
-## USER - The in-container user.
-## DOCKER_RUN_OPTIONS - Options for the container (default: --privileged, required for tests).
-## DOCKER_NAME - The container name (default: gvisor-bazel-HASH).
-## BAZEL_CACHE - The bazel cache directory (default: detected).
-## GCLOUD_CONFIG - The gcloud config directory (detect: detected).
-## DOCKER_SOCKET - The Docker socket (default: detected).
-##
-bazel-server-start: bazel-image ## Starts the bazel server.
- @mkdir -p $(BAZEL_CACHE)
- @mkdir -p $(GCLOUD_CONFIG)
- @if docker ps --all | grep $(DOCKER_NAME); then docker rm -f $(DOCKER_NAME); fi
- # This command runs a bazel server, and the container sticks around
- # until the bazel server exits. This should ensure that it does not
- # exit in the middle of running a build, but also it won't stick around
- # forever. The build commands wrap around an appropriate exec into the
- # container in order to perform work via the bazel client.
- docker run -d --rm --name $(DOCKER_NAME) \
- -v "$(CURDIR):$(CURDIR)" \
- --workdir "$(CURDIR)" \
- $(FULL_DOCKER_RUN_OPTIONS) \
- $(BUILDER_IMAGE) \
- sh -c "tail -f --pid=\$$($(BAZEL) info server_pid) /dev/null"
-.PHONY: bazel-server-start
-
bazel-shutdown: ## Shuts down a running bazel server.
- @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(BAZEL) shutdown; \
- rc=$$?; docker kill $(DOCKER_NAME) || [[ $$rc -ne 0 ]]
+ @$(call wrapper,$(BAZEL) shutdown)
.PHONY: bazel-shutdown
bazel-alias: ## Emits an alias that can be used within the shell.
- @echo "alias bazel='docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) bazel'"
+ @echo "alias bazel='$(call wrapper,$(BAZEL))'"
.PHONY: bazel-alias
-bazel-server: ## Ensures that the server exists. Used as an internal target.
- @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) true >&2 || $(MAKE) bazel-server-start >&2
-.PHONY: bazel-server
+bazel-image: load-default ## Ensures that the local builder exists.
+ @$(call header,DOCKER BUILD)
+ @docker rm -f $(BUILDER_NAME) 2>/dev/null || true
+ @docker run --user 0:0 --entrypoint "" --name $(BUILDER_NAME) gvisor.dev/images/default \
+ bash -c "$(GROUPADD_DOCKER) $(USERADD_DOCKER) if test -e /dev/kvm; then chmod a+rw /dev/kvm; fi" >&2
+ @docker commit $(BUILDER_NAME) gvisor.dev/images/builder >&2
+.PHONY: bazel-image
-# build_cmd builds the given targets in the bazel-server container.
-build_cmd = docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) sh -o pipefail -c \
- '$(BAZEL) build $(BASE_OPTIONS) $(OPTIONS) "$(TARGETS)"'
+ifneq (true,$(shell $(wrapper echo true)))
+bazel-server: bazel-image ## Ensures that the server exists.
+ @$(call header,DOCKER RUN)
+ @docker rm -f $(DOCKER_NAME) 2>/dev/null || true
+ @mkdir -p $(GCLOUD_CONFIG)
+ @mkdir -p $(BAZEL_CACHE)
+ @docker run -d --rm --name $(DOCKER_NAME) \
+ -v "$(CURDIR):$(CURDIR)" \
+ --workdir "$(CURDIR)" \
+ $(DOCKER_RUN_OPTIONS) \
+ gvisor.dev/images/builder \
+ bash -c "set -x; tail -f --pid=\$$($(BAZEL) info server_pid) /dev/null" >&2
+else
+bazel-server:
+ @
+endif
+.PHONY: bazel-server
# build_paths extracts the built binary from the bazel stderr output.
#
@@ -174,49 +186,33 @@ build_cmd = docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) sh -o pipefai
# command here? Yikes, let's just stick with the ugly shell pipeline.
#
# The last line is used to prevent terminal shenanigans.
-build_paths = command_line=$$( $(build_cmd) 2>&1 \
- | grep -A1 -E '^Target' \
- | grep -E '^ ($(subst $(SPACE),|,$(BUILD_ROOTS)))' \
- | sed "s/ /\n/g" \
- | strings -n 10 \
- | awk '{$$1=$$1};1' \
- | xargs -n 1 -I {} readlink -f "{}" \
- | xargs -n 1 -I {} echo "$(1)" ) && \
- (set -xeuo pipefail; eval $${command_line})
-
-build: bazel-server
- @$(call build_cmd)
-.PHONY: build
-
-copy: bazel-server
-ifeq (,$(DESTINATION))
- $(error Destination not provided.)
-endif
- @$(call build_paths,cp -fa {} $(DESTINATION))
-
-run: bazel-server
- @$(call build_paths,{} $(ARGS))
-.PHONY: run
-
-sudo: bazel-server
- @$(call build_paths,sudo -E {} $(ARGS))
-.PHONY: sudo
-
-test: bazel-server
- @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) \
- $(BAZEL) test $(BASE_OPTIONS) \
- --test_output=errors --keep_going --verbose_failures=true \
- --build_event_json_file=.build_events.json \
- $(OPTIONS) $(TARGETS)
-.PHONY: test
-
-testlogs:
- @cat .build_events.json | jq -r \
- 'select(.testSummary?.overallStatus? | tostring | test("(FAILED|FLAKY|TIMEOUT)")) | .testSummary.failed | .[] | .uri' | \
- awk -Ffile:// '{print $$2;}'
+build_paths = \
+ (set -euo pipefail; \
+ $(call wrapper,$(BAZEL) build $(BASE_OPTIONS) $(BAZEL_OPTIONS) $(1)) 2>&1 \
+ | tee /proc/self/fd/2 \
+ | sed -n -e '/^Target/,$$p' \
+ | sed -n -e '/^ \($(subst /,\/,$(subst $(SPACE),\|,$(BUILD_ROOTS)))\)/p' \
+ | sed -e 's/ /\n/g' \
+ | awk '{$$1=$$1};1' \
+ | strings \
+ | xargs -r -n 1 -I {} readlink -f "{}" \
+ | xargs -r -n 1 -I {} bash -c 'set -xeuo pipefail; $(2)')
+
+clean = $(call header,CLEAN) && $(call wrapper,$(BAZEL) clean)
+build = $(call header,BUILD $(1)) && $(call build_paths,$(1),echo {})
+copy = $(call header,COPY $(1) $(2)) && $(call build_paths,$(1),cp -fa {} $(2))
+run = $(call header,RUN $(1) $(2)) && $(call build_paths,$(1),{} $(2))
+sudo = $(call header,SUDO $(1) $(2)) && $(call build_paths,$(1),sudo -E {} $(2))
+test = $(call header,TEST $(1)) && $(call wrapper,$(BAZEL) test $(TEST_OPTIONS) $(1))
+
+clean: ## Cleans the bazel cache.
+ @$(call clean)
+.PHONY: clean
+
+testlogs: ## Returns the most recent set of test logs.
+ @if test -f .build_events.json; then \
+ cat .build_events.json | jq -r \
+ 'select(.testSummary?.overallStatus? | tostring | test("(FAILED|FLAKY|TIMEOUT)")) | "\(.id.testSummary.label) \(.testSummary.failed[].uri)"' | \
+ sed -e 's|file://||'; \
+ fi
.PHONY: testlogs
-
-query: bazel-server
- @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) sh -o pipefail -c \
- '$(BAZEL) query $(BASE_OPTIONS) $(OPTIONS) "$(TARGETS)" 2>/dev/null'
-.PHONY: query
diff --git a/tools/bazeldefs/BUILD b/tools/bazeldefs/BUILD
index 27e85a75e..a4a605346 100644
--- a/tools/bazeldefs/BUILD
+++ b/tools/bazeldefs/BUILD
@@ -1,46 +1,7 @@
-load("//tools:defs.bzl", "bzl_library", "rbe_platform", "rbe_toolchain")
+load("//tools:defs.bzl", "bzl_library")
package(licenses = ["notice"])
-# We need to define a bazel platform and toolchain to specify dockerPrivileged
-# and dockerRunAsRoot options, they are required to run tests on the RBE
-# cluster in Kokoro.
-rbe_platform(
- name = "rbe_ubuntu1604",
- constraint_values = [
- "@bazel_tools//platforms:x86_64",
- "@bazel_tools//platforms:linux",
- "@bazel_tools//tools/cpp:clang",
- "@bazel_toolchains//constraints:xenial",
- "@bazel_toolchains//constraints/sanitizers:support_msan",
- ],
- remote_execution_properties = """
- properties: {
- name: "container-image"
- value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:b516a2d69537cb40a7c6a7d92d0008abb29fba8725243772bdaf2c83f1be2272"
- }
- properties: {
- name: "dockerAddCapabilities"
- value: "SYS_ADMIN"
- }
- properties: {
- name: "dockerPrivileged"
- value: "true"
- }
- """,
-)
-
-rbe_toolchain(
- name = "cc-toolchain-clang-x86_64-default",
- exec_compatible_with = [],
- tags = [
- "manual",
- ],
- target_compatible_with = [],
- toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/11.0.0/bazel_3.1.0/cc:cc-compiler-k8",
- toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
-)
-
bzl_library(
name = "platforms_bzl",
srcs = ["platforms.bzl"],
@@ -58,3 +19,21 @@ bzl_library(
srcs = ["defs.bzl"],
visibility = ["//visibility:private"],
)
+
+config_setting(
+ name = "linux_arm64_cross",
+ values = {
+ "cpu": "aarch64",
+ "host_cpu": "k8",
+ },
+ visibility = ["//visibility:private"],
+)
+
+config_setting(
+ name = "linux_amd64_cross",
+ values = {
+ "cpu": "k8",
+ "host_cpu": "aarch64",
+ },
+ visibility = ["//visibility:private"],
+)
diff --git a/tools/bazeldefs/cc.bzl b/tools/bazeldefs/cc.bzl
index 7f41a0142..2831eac5f 100644
--- a/tools/bazeldefs/cc.bzl
+++ b/tools/bazeldefs/cc.bzl
@@ -1,11 +1,9 @@
"""C++ rules."""
-load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", _cc_flags_supplier = "cc_flags_supplier")
load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test")
load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", _cc_grpc_library = "cc_grpc_library")
cc_library = _cc_library
-cc_flags_supplier = _cc_flags_supplier
cc_proto_library = _cc_proto_library
cc_test = _cc_test
cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain"
@@ -14,6 +12,16 @@ gbenchmark = "@com_google_benchmark//:benchmark"
grpcpp = "@com_github_grpc_grpc//:grpc++"
vdso_linker_option = "-fuse-ld=gold "
+def _cc_flags_supplier_impl(ctx):
+ variables = platform_common.TemplateVariableInfo({
+ "CC_FLAGS": "",
+ })
+ return [variables]
+
+cc_flags_supplier = rule(
+ implementation = _cc_flags_supplier_impl,
+)
+
def cc_grpc_library(name, **kwargs):
_cc_grpc_library(name = name, grpc_only = True, **kwargs)
diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl
index c2f94bb9c..58ced5167 100644
--- a/tools/bazeldefs/defs.bzl
+++ b/tools/bazeldefs/defs.bzl
@@ -5,8 +5,8 @@ load("@bazel_skylib//:bzl_library.bzl", _bzl_library = "bzl_library")
build_test = _build_test
bzl_library = _bzl_library
-rbe_platform = native.platform
-rbe_toolchain = native.toolchain
+more_shards = 4
+most_shards = 8
def short_path(path):
return path
@@ -37,3 +37,44 @@ def default_net_util():
def coreutil():
return [] # Nothing needed.
+
+def select_native_vs_cross(native = [], amd64 = [], arm64 = [], cross = []):
+ values = {
+ "//tools/bazeldefs:linux_arm64_cross": arm64 + cross,
+ "//tools/bazeldefs:linux_amd64_cross": amd64 + cross,
+ "//conditions:default": native,
+ }
+ return select(values)
+
+def arch_genrule(name, srcs, outs, cmd, tools):
+ """Runs a gen command on the target architecture.
+
+ If the target architecture doesn't match the host architecture, it will build
+ a command for the target architecture and run it via qemu.
+
+ The native genrule runs the command on the host architecture.
+
+ Args:
+ name: name of generated target.
+ srcs: A list of inputs for this rule.
+ cmd: The command to run. It has to contain " QEMU " before executed binaries.
+ outs: A list of files generated by this rule.
+ tools: A list of tool dependencies for this rule.
+ """
+ qemu_arm64 = "qemu-aarch64-static"
+ qemu_amd64 = "qemu-x86_64-static"
+ srcs = select_native_vs_cross(
+ cross = srcs + tools,
+ native = srcs,
+ )
+ tools = select_native_vs_cross(
+ cross = [],
+ native = tools,
+ )
+ cmd = select_native_vs_cross(
+ arm64 = cmd.replace("QEMU", qemu_arm64),
+ amd64 = cmd.replace("QEMU", qemu_amd64),
+ native = cmd.replace("QEMU", ""),
+ cross = "",
+ )
+ native.genrule(name = name, srcs = srcs, outs = outs, cmd = cmd, tools = tools)
diff --git a/tools/bazeldefs/go.bzl b/tools/bazeldefs/go.bzl
index 661c9727e..bcd8cffe7 100644
--- a/tools/bazeldefs/go.bzl
+++ b/tools/bazeldefs/go.bzl
@@ -28,7 +28,7 @@ def go_proto_library(name, **kwargs):
def go_grpc_and_proto_libraries(name, **kwargs):
_go_proto_or_grpc_library(_go_grpc_library, name, **kwargs)
-def go_binary(name, static = False, pure = False, x_defs = None, **kwargs):
+def go_binary(name, static = False, pure = False, x_defs = None, system_malloc = False, **kwargs):
"""Build a go binary.
Args:
@@ -52,7 +52,7 @@ def go_importpath(target):
"""Returns the importpath for the target."""
return target[GoLibrary].importpath
-def go_library(name, **kwargs):
+def go_library(name, arch_deps = [], **kwargs):
_go_library(
name = name,
importpath = "gvisor.dev/gvisor/" + native.package_name(),
diff --git a/tools/bigquery/bigquery.go b/tools/bigquery/bigquery.go
index 544af3876..a4ca93ec2 100644
--- a/tools/bigquery/bigquery.go
+++ b/tools/bigquery/bigquery.go
@@ -21,6 +21,7 @@ package bigquery
import (
"context"
"fmt"
+ "strconv"
"strings"
"time"
@@ -109,6 +110,12 @@ func NewBenchmark(name string, iters int) *Benchmark {
return &Benchmark{
Name: name,
Metric: make([]*Metric, 0),
+ Condition: []*Condition{
+ {
+ Name: "iterations",
+ Value: strconv.Itoa(iters),
+ },
+ },
}
}
diff --git a/tools/checkescape/checkescape.go b/tools/checkescape/checkescape.go
index e5a7e23c7..011b8fee8 100644
--- a/tools/checkescape/checkescape.go
+++ b/tools/checkescape/checkescape.go
@@ -27,7 +27,7 @@
// heap: A direct allocation is made on the heap (hard).
// builtin: A call is made to a built-in allocation function (hard).
// stack: A stack split as part of a function preamble (soft).
-// interface: A call is made via an interface whicy *may* escape (soft).
+// interface: A call is made via an interface which *may* escape (soft).
// dynamic: A dynamic function is dispatched which *may* escape (soft).
//
// To the use the package, annotate a function-level comment with either the
@@ -618,12 +618,12 @@ func findReasons(pass *analysis.Pass, fdecl *ast.FuncDecl) ([]EscapeReason, bool
// run performs the analysis.
func run(pass *analysis.Pass, localEscapes bool) (interface{}, error) {
- calls, err := loadObjdump()
- if err != nil {
+ calls, callsErr := loadObjdump()
+ if callsErr != nil {
// Note that if this analysis fails, then we don't actually
// fail the analyzer itself. We simply report every possible
// escape. In most cases this will work just fine.
- log.Printf("WARNING: unable to load objdump: %v", err)
+ log.Printf("WARNING: unable to load objdump: %v", callsErr)
}
allEscapes := make(map[string][]Escapes)
mergedEscapes := make(map[string]Escapes)
@@ -645,10 +645,10 @@ func run(pass *analysis.Pass, localEscapes bool) (interface{}, error) {
}
hasCall := func(inst poser) (string, bool) {
p := linePosition(inst, nil)
- if calls == nil {
+ if callsErr != nil {
// See above: we don't have access to the binary
// itself, so need to include every possible call.
- return "(possible)", true
+ return fmt.Sprintf("(possible, unable to load objdump: %v)", callsErr), true
}
s, ok := calls[p.Simplified()]
if !ok {
diff --git a/tools/checkescape/test1/test1.go b/tools/checkescape/test1/test1.go
index 27991649f..f46eba39b 100644
--- a/tools/checkescape/test1/test1.go
+++ b/tools/checkescape/test1/test1.go
@@ -36,17 +36,20 @@ func (t Type) Foo() {
fmt.Printf("%v", t) // Never executed.
}
+// InterfaceFunction is passed an interface argument.
// +checkescape:all,hard
//go:nosplit
func InterfaceFunction(i Interface) {
// Do nothing; exported for tests.
}
+// TypeFunction is passed a concrete pointer argument.
// +checkesacape:all,hard
//go:nosplit
func TypeFunction(t *Type) {
}
+// BuiltinMap creates a new map.
// +mustescape:local,builtin
//go:noinline
//go:nosplit
@@ -61,7 +64,8 @@ func builtinMapRec(x int) map[string]bool {
return BuiltinMap(x)
}
-// +temustescapestescape:local,builtin
+// BuiltinClosure returns a closure around x.
+// +mustescape:local,builtin
//go:noinline
//go:nosplit
func BuiltinClosure(x int) func() {
@@ -77,6 +81,7 @@ func builtinClosureRec(x int) func() {
return BuiltinClosure(x)
}
+// BuiltinMakeSlice makes a new slice.
// +mustescape:local,builtin
//go:noinline
//go:nosplit
@@ -91,6 +96,7 @@ func builtinMakeSliceRec(x int) []byte {
return BuiltinMakeSlice(x)
}
+// BuiltinAppend calls append on a slice.
// +mustescape:local,builtin
//go:noinline
//go:nosplit
@@ -105,6 +111,7 @@ func builtinAppendRec() []byte {
return BuiltinAppend(nil)
}
+// BuiltinChan makes a channel.
// +mustescape:local,builtin
//go:noinline
//go:nosplit
@@ -119,6 +126,7 @@ func builtinChanRec() chan int {
return BuiltinChan()
}
+// Heap performs an explicit heap allocation.
// +mustescape:local,heap
//go:noinline
//go:nosplit
@@ -134,6 +142,7 @@ func heapRec() *Type {
return Heap()
}
+// Dispatch dispatches via an interface.
// +mustescape:local,interface
//go:noinline
//go:nosplit
@@ -148,6 +157,7 @@ func dispatchRec(i Interface) {
Dispatch(i)
}
+// Dynamic invokes a dynamic function.
// +mustescape:local,dynamic
//go:noinline
//go:nosplit
@@ -167,6 +177,7 @@ func dynamicRec(f func()) {
func internalFunc() {
}
+// Split includes a guaranteed stack split.
// +mustescape:local,stack
//go:noinline
func Split() {
diff --git a/tools/defs.bzl b/tools/defs.bzl
index 2c8129e7e..56c481f44 100644
--- a/tools/defs.bzl
+++ b/tools/defs.bzl
@@ -8,7 +8,7 @@ change for Google-internal and bazel-compatible rules.
load("//tools/go_stateify:defs.bzl", "go_stateify")
load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps")
load("//tools/nogo:defs.bzl", "nogo_test")
-load("//tools/bazeldefs:defs.bzl", _build_test = "build_test", _bzl_library = "bzl_library", _coreutil = "coreutil", _default_installer = "default_installer", _default_net_util = "default_net_util", _proto_library = "proto_library", _rbe_platform = "rbe_platform", _rbe_toolchain = "rbe_toolchain", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path")
+load("//tools/bazeldefs:defs.bzl", _arch_genrule = "arch_genrule", _build_test = "build_test", _bzl_library = "bzl_library", _coreutil = "coreutil", _default_installer = "default_installer", _default_net_util = "default_net_util", _more_shards = "more_shards", _most_shards = "most_shards", _proto_library = "proto_library", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path")
load("//tools/bazeldefs:cc.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _gbenchmark = "gbenchmark", _grpcpp = "grpcpp", _gtest = "gtest", _vdso_linker_option = "vdso_linker_option")
load("//tools/bazeldefs:go.bzl", _gazelle = "gazelle", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_test = "go_test", _select_goarch = "select_goarch", _select_goos = "select_goos")
load("//tools/bazeldefs:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar")
@@ -16,6 +16,7 @@ load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform",
load("//tools/bazeldefs:tags.bzl", "go_suffixes")
# Core rules.
+arch_genrule = _arch_genrule
build_test = _build_test
bzl_library = _bzl_library
default_installer = _default_installer
@@ -23,9 +24,9 @@ default_net_util = _default_net_util
select_arch = _select_arch
select_system = _select_system
short_path = _short_path
-rbe_platform = _rbe_platform
-rbe_toolchain = _rbe_toolchain
coreutil = _coreutil
+more_shards = _more_shards
+most_shards = _most_shards
# C++ rules.
cc_binary = _cc_binary
@@ -182,6 +183,7 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F
name + suffix + "_state_autogen.go"
for suffix in state_sets.keys()
]
+
if "//pkg/state" not in all_deps:
all_deps = all_deps + ["//pkg/state"]
diff --git a/tools/go_branch.sh b/tools/go_branch.sh
index 768a37b9a..7ef4ddf83 100755
--- a/tools/go_branch.sh
+++ b/tools/go_branch.sh
@@ -89,8 +89,14 @@ git merge --no-commit --strategy ours "${head}" || \
find . -type f -exec chmod 0644 {} \;
find . -type d -exec chmod 0755 {} \;
-# Sync the entire gopath_dir.
-rsync --recursive --verbose --delete --exclude .git -L "${gopath_dir}/" .
+# Sync the entire gopath_dir. Note that we exclude auto-generated source
+# files that will change here. Otherwise, it adds a tremendous amount of noise
+# to commits. If this file disappears in the future, then presumably we will
+# still delete the underlying directory.
+rsync --recursive --delete \
+ --exclude .git \
+ --exclude webhook/pkg/injector/certs.go \
+ -L "${gopath_dir}/" .
# Add additional files.
for file in "${othersrc[@]}"; do
diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl
index ad97208a8..50e2546bf 100644
--- a/tools/go_generics/defs.bzl
+++ b/tools/go_generics/defs.bzl
@@ -67,7 +67,7 @@ def _go_template_instance_impl(ctx):
# Check that all defined types are expected by the template.
for t in ctx.attr.types:
if (t not in info.types) and (t not in info.opt_types):
- fail("Type %s it not a parameter to %s" % (t, ctx.attr.template.label))
+ fail("Type %s is not a parameter to %s" % (t, ctx.attr.template.label))
# Check that all required consts are defined.
for t in info.consts:
@@ -77,7 +77,7 @@ def _go_template_instance_impl(ctx):
# Check that all defined consts are expected by the template.
for t in ctx.attr.consts:
if (t not in info.consts) and (t not in info.opt_consts):
- fail("Const %s it not a parameter to %s" % (t, ctx.attr.template.label))
+ fail("Const %s is not a parameter to %s" % (t, ctx.attr.template.label))
# Build the argument list.
args = ["-i=%s" % info.template.path, "-o=%s" % output.path]
diff --git a/tools/go_generics/generics.go b/tools/go_generics/generics.go
index 0860ca9db..30584006c 100644
--- a/tools/go_generics/generics.go
+++ b/tools/go_generics/generics.go
@@ -223,7 +223,7 @@ func main() {
} else {
switch kind {
case globals.KindType, globals.KindVar, globals.KindConst, globals.KindFunction:
- if ident.Name != "_" {
+ if ident.Name != "_" && !(ident.Name == "init" && kind == globals.KindFunction) {
ident.Name = *prefix + ident.Name + *suffix
}
case globals.KindTag:
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
index 4a53d25be..28ae6c4ef 100644
--- a/tools/go_marshal/gomarshal/generator.go
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -213,10 +213,11 @@ type sliceAPI struct {
type marshallableType struct {
spec *ast.TypeSpec
slice *sliceAPI
+ recv string
}
-func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) marshallableType {
- mt := marshallableType{
+func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) *marshallableType {
+ mt := &marshallableType{
spec: spec,
slice: nil,
}
@@ -261,12 +262,31 @@ func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.Ty
// collectMarshallableTypes walks the parsed AST and collects a list of type
// declarations for which we need to generate the Marshallable interface.
-func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []marshallableType {
- var types []marshallableType
+func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) map[*ast.TypeSpec]*marshallableType {
+ recv := make(map[string]string) // Type name to receiver name.
+ types := make(map[*ast.TypeSpec]*marshallableType)
for _, decl := range a.Decls {
gdecl, ok := decl.(*ast.GenDecl)
// Type declaration?
if !ok || gdecl.Tok != token.TYPE {
+ // Is this a function declaration? We remember receiver names.
+ d, ok := decl.(*ast.FuncDecl)
+ if ok && d.Recv != nil && len(d.Recv.List) == 1 {
+ // Accept concrete methods & pointer methods.
+ ident, ok := d.Recv.List[0].Type.(*ast.Ident)
+ if !ok {
+ var st *ast.StarExpr
+ st, ok = d.Recv.List[0].Type.(*ast.StarExpr)
+ if ok {
+ ident, ok = st.X.(*ast.Ident)
+ }
+ }
+ // The receiver name may not be present.
+ if ok && len(d.Recv.List[0].Names) == 1 {
+ // Recover the type receiver name in this case.
+ recv[ident.Name] = d.Recv.List[0].Names[0].Name
+ }
+ }
debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n")
continue
}
@@ -305,10 +325,20 @@ func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []ma
// don't support it.
abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name))
}
- types = append(types, newMarshallableType(f, tagLine, t))
-
+ types[t] = newMarshallableType(f, tagLine, t)
}
}
+ // Update the types with the last seen receiver. As long as the
+ // receiver name is consistent for the type, then we will generate
+ // code that is still consistent with itself.
+ for t, mt := range types {
+ r, ok := recv[t.Name.Name]
+ if !ok {
+ mt.recv = receiverName(t) // Default.
+ continue
+ }
+ mt.recv = r // Last seen.
+ }
return types
}
@@ -345,8 +375,8 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp
}
-func (g *Generator) generateOne(t marshallableType, fset *token.FileSet) *interfaceGenerator {
- i := newInterfaceGenerator(t.spec, fset)
+func (g *Generator) generateOne(t *marshallableType, fset *token.FileSet) *interfaceGenerator {
+ i := newInterfaceGenerator(t.spec, t.recv, fset)
switch ty := t.spec.Type.(type) {
case *ast.StructType:
i.validateStruct(t.spec, ty)
@@ -376,8 +406,8 @@ func (g *Generator) generateOne(t marshallableType, fset *token.FileSet) *interf
// generateOneTestSuite generates a test suite for the automatically generated
// implementations type t.
-func (g *Generator) generateOneTestSuite(t marshallableType) *testGenerator {
- i := newTestGenerator(t.spec)
+func (g *Generator) generateOneTestSuite(t *marshallableType) *testGenerator {
+ i := newTestGenerator(t.spec, t.recv)
i.emitTests(t.slice)
return i
}
@@ -417,7 +447,15 @@ func (g *Generator) Run() error {
for i, a := range asts {
// Collect type declarations marked for code generation and generate
// Marshallable interfaces.
+ var sortedTypes []*marshallableType
for _, t := range g.collectMarshallableTypes(a, fsets[i]) {
+ sortedTypes = append(sortedTypes, t)
+ }
+ sort.Slice(sortedTypes, func(x, y int) bool {
+ // Sort by type name, which should be unique within a package.
+ return sortedTypes[x].spec.Name.String() < sortedTypes[y].spec.Name.String()
+ })
+ for _, t := range sortedTypes {
impl := g.generateOne(t, fsets[i])
// Collect Marshallable types referenced by the generated code.
for ref := range impl.ms {
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
index 36447b86b..65f5ea34d 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -54,10 +54,10 @@ func (g *interfaceGenerator) typeName() string {
}
// newinterfaceGenerator creates a new interface generator.
-func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
+func newInterfaceGenerator(t *ast.TypeSpec, r string, fset *token.FileSet) *interfaceGenerator {
g := &interfaceGenerator{
t: t,
- r: receiverName(t),
+ r: r,
f: fset,
is: make(map[string]struct{}),
ms: make(map[string]struct{}),
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
index 631295373..6cf00843f 100644
--- a/tools/go_marshal/gomarshal/generator_tests.go
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -53,10 +53,10 @@ type testGenerator struct {
decl *importStmt
}
-func newTestGenerator(t *ast.TypeSpec) *testGenerator {
+func newTestGenerator(t *ast.TypeSpec, r string) *testGenerator {
g := &testGenerator{
t: t,
- r: receiverName(t),
+ r: r,
imports: newImportTable(),
}
diff --git a/tools/images.mk b/tools/images.mk
new file mode 100644
index 000000000..2003da5bd
--- /dev/null
+++ b/tools/images.mk
@@ -0,0 +1,169 @@
+#!/usr/bin/make -f
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+##
+## Docker image targets.
+##
+## Images used by the tests must also be built and available locally.
+## The canonical test targets defined below will automatically load
+## relevant images. These can be loaded or built manually via these
+## targets.
+##
+## (*) Note that you may provide an ARCH parameter in order to build
+## and load images from an alternate architecture (using qemu). When
+## bazel is run as a server, this has the effect of running a full
+## cross-architecture chain, and can produce cross-compiled binaries.
+##
+
+# ARCH is the architecture used for the build. This may be overridden at the
+# command line in order to perform a cross-build (in a limited capacity).
+ARCH := $(shell uname -m)
+ifneq ($(ARCH),$(shell uname -m))
+DOCKER_PLATFORM_ARGS := --platform=$(ARCH)
+else
+DOCKER_PLATFORM_ARGS :=
+endif
+
+# Note that the image prefixes used here must match the image mangling in
+# runsc/testutil.MangleImage. Names are mangled in this way to ensure that all
+# tests are using locally-defined images (that are consistent and idempotent).
+REMOTE_IMAGE_PREFIX ?= gcr.io/gvisor-presubmit
+LOCAL_IMAGE_PREFIX ?= gvisor.dev/images
+ALL_IMAGES := $(subst /,_,$(subst images/,,$(shell find images/ -name Dockerfile -o -name Dockerfile.$(ARCH) | xargs -n 1 dirname | uniq)))
+SUB_IMAGES := $(foreach image,$(ALL_IMAGES),$(if $(findstring _,$(image)),$(image),))
+IMAGE_GROUPS := $(sort $(foreach image,$(SUB_IMAGES),$(firstword $(subst _, ,$(image)))))
+
+define expand_group =
+load-$(1): $$(patsubst $(1)_%, load-$(1)_%, $$(filter $(1)_%,$$(ALL_IMAGES)))
+ @
+.PHONY: load-$(1)
+push-$(1): $$(patsubst $(1)_%, push-$(1)_%, $$(filter $(1)_%,$$(ALL_IMAGES)))
+ @
+.PHONY: push-$(1)
+endef
+$(foreach group,$(IMAGE_GROUPS),$(eval $(call expand_group,$(group))))
+
+list-all-images: ## List all images.
+ @for image in $(ALL_IMAGES); do echo $${image}; done
+.PHONY: list-all-images
+
+load-all-images: ## Load all images.
+load-all-images: $(patsubst %,load-%,$(ALL_IMAGES))
+.PHONY: load-all-images
+
+push-all-images: ## Push all images.
+push-all-images: $(patsubst %,push-%,$(ALL_IMAGES))
+.PHONY: push-all-images
+
+# path and dockerfile are used to extract the relevant path and dockerfile
+# (depending on what's available for the given architecture).
+path = images/$(subst _,/,$(1))
+dockerfile = $$(if [ -f "$(call path,$(1))/Dockerfile.$(ARCH)" ]; then echo Dockerfile.$(ARCH); else echo Dockerfile; fi)
+
+# The tag construct is used to memoize the image generated (see README.md).
+# This scheme is used to enable aggressive caching in a central repository, but
+# ensuring that images will always be sourced using the local files.
+tag = $(shell cd images && find $(subst _,/,$(1)) -type f | sort | xargs -n 1 sha256sum | sha256sum - | cut -c 1-16)
+remote_image = $(REMOTE_IMAGE_PREFIX)/$(subst _,/,$(1))_$(ARCH)
+local_image = $(LOCAL_IMAGE_PREFIX)/$(subst _,/,$(1))
+
+# Include all existing images as targets here.
+#
+# Note that we use a _ for the tag separator, instead of :, as the latter is
+# interpreted by Make, unfortunately. tag_expand expands the generic rules to
+# tag-specific targets. This is needed to provide sensible targets for load
+# below, with caching. Basically, if there is a rule generated here, then the
+# load will be skipped. If there is no load generated here, then the default
+# rule for load will kick in.
+#
+# Note that if this rule does not successfully run, we will simply have
+# additional Docker pull commands that run for all images that are already
+# pulled. No real harm done.
+EXISTING_IMAGES = $(shell docker images --format '{{.Repository}}_{{.Tag}}' | grep -v '<none>')
+define existing_image_rule =
+loaded0_$(1)=load-$$(1): tag-$$(1) # Already available.
+loaded1_$(1)=.PHONY: load-$$(1)
+endef
+$(foreach image, $(EXISTING_IMAGES), $(eval $(call existing_image_rule,$(image))))
+define tag_expand_rule =
+$(eval $(loaded0_$(call remote_image,$(1))_$(call tag,$(1))))
+$(eval $(loaded1_$(call remote_image,$(1))_$(call tag,$(1))))
+endef
+$(foreach image, $(ALL_IMAGES), $(eval $(call tag_expand_rule,$(image))))
+
+# tag tags a local image. This applies both the hash-based tag from above to
+# ensure that caching works as expected, as well as the "latest" tag that is
+# used by the tests.
+local_tag = \
+ docker tag $(call remote_image,$(1)):$(call tag,$(1)) $(call local_image,$(1)):$(call tag,$(1)) >&2
+latest_tag = \
+ docker tag $(call local_image,$(1)):$(call tag,$(1)) $(call local_image,$(1)) >&2
+tag-%: ## Tag a local image.
+ @$(call header,TAG $*)
+ @$(call local_tag,$*) && $(call latest_tag,$*)
+
+# pull forces the image to be pulled.
+pull = \
+ $(call header,PULL $(1)) && \
+ docker pull $(DOCKER_PLATFORM_ARGS) $(call remote_image,$(1)):$(call tag,$(1)) >&2 && \
+ $(call local_tag,$(1)) && \
+ $(call latest_tag,$(1))
+pull-%: register-cross ## Force a repull of the image.
+ @$(call pull,$*)
+
+# rebuild builds the image locally. Only the "remote" tag will be applied. Note
+# we need to explicitly repull the base layer in order to ensure that the
+# architecture is correct. Note that we use the term "rebuild" here to avoid
+# conflicting with the bazel "build" terminology, which is used elsewhere.
+rebuild = \
+ $(call header,REBUILD $(1)) && \
+ (T=$$(mktemp -d) && cp -a $(call path,$(1))/* $$T && \
+ $(foreach image,$(shell grep FROM "$(call path,$(1))/$(call dockerfile,$(1))" 2>/dev/null | cut -d' ' -f2),docker pull $(DOCKER_PLATFORM_ARGS) $(image) >&2 &&) \
+ docker build $(DOCKER_PLATFORM_ARGS) \
+ -f "$$T/$(call dockerfile,$(1))" \
+ -t "$(call remote_image,$(1)):$(call tag,$(1))" \
+ $$T >&2 && \
+ rm -rf $$T) && \
+ $(call local_tag,$(1)) && \
+ $(call latest_tag,$(1))
+rebuild-%: register-cross ## Force rebuild an image locally.
+ @$(call rebuild,$*)
+
+# load will either pull the "remote" or build it locally. This is the preferred
+# entrypoint, as it should never fail. The local tag should always be set after
+# this returns (either by the pull or the build).
+load-%: register-cross ## Pull or build an image locally.
+ @($(call pull,$*)) || ($(call rebuild,$*))
+
+# push pushes the remote image, after either pulling (to validate that the tag
+# already exists) or building manually. Note that this generic rule will match
+# the fully-expanded remote image tag.
+push-%: load-% ## Push a given image.
+ @docker push $(call remote_image,$*):$(call tag,$*) >&2
+
+# register-cross registers the necessary qemu binaries for cross-compilation.
+# This may be used by any target that may execute containers that are not the
+# native format. Note that this will only apply on the first execution.
+register-cross:
+ifneq ($(ARCH),$(shell uname -m))
+ifeq (,$(wildcard /proc/sys/fs/binfmt_misc/qemu-*))
+ @docker run --rm --privileged multiarch/qemu-user-static --reset --persistent yes >&2
+else
+ @
+endif
+else
+ @
+endif
diff --git a/tools/installers/BUILD b/tools/installers/BUILD
index 13d3cc5e0..bbf3c1f85 100644
--- a/tools/installers/BUILD
+++ b/tools/installers/BUILD
@@ -1,4 +1,4 @@
-# Installers for use by the tools/vm_test rules.
+# Installers for use by top-level scripts.
package(
default_visibility = ["//:sandbox"],
@@ -14,14 +14,6 @@ sh_binary(
)
sh_binary(
- name = "images",
- srcs = ["images.sh"],
- data = [
- "//images",
- ],
-)
-
-sh_binary(
name = "master",
srcs = ["master.sh"],
)
diff --git a/tools/installers/containerd.sh b/tools/installers/containerd.sh
index 6b7bb261c..d28549734 100755
--- a/tools/installers/containerd.sh
+++ b/tools/installers/containerd.sh
@@ -16,7 +16,7 @@
set -xeo pipefail
-declare -r CONTAINERD_VERSION=${CONTAINERD_VERSION:-1.3.0}
+declare -r CONTAINERD_VERSION=${1:-1.3.0}
declare -r CONTAINERD_MAJOR="$(echo ${CONTAINERD_VERSION} | awk -F '.' '{ print $1; }')"
declare -r CONTAINERD_MINOR="$(echo ${CONTAINERD_VERSION} | awk -F '.' '{ print $2; }')"
@@ -43,10 +43,23 @@ install_helper() {
make install)
}
+# Figure out where btrfs headers are.
+#
+# Ubuntu 16.04 has only btrfs-tools, while 18.04 has a transitional package,
+# and later versions no longer have the transitional package.
+source /etc/os-release
+declare BTRFS_DEV
+if [[ "${VERSION_ID%.*}" -le "18" ]]; then
+ BTRFS_DEV="btrfs-tools"
+else
+ BTRFS_DEV="libbtrfs-dev"
+fi
+readonly BTRFS_DEV
+
# Install dependencies for the crictl tests.
while true; do
if (apt-get update && apt-get install -y \
- btrfs-tools \
+ "${BTRFS_DEV}" \
libseccomp-dev); then
break
fi
diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD
index 12b8b597c..566e0889e 100644
--- a/tools/nogo/BUILD
+++ b/tools/nogo/BUILD
@@ -3,6 +3,8 @@ load("//tools/nogo:defs.bzl", "nogo_objdump_tool", "nogo_stdlib", "nogo_target")
package(licenses = ["notice"])
+exports_files(["config-schema.json"])
+
nogo_target(
name = "target",
goarch = select_goarch(),
diff --git a/tools/nogo/config-schema.json b/tools/nogo/config-schema.json
new file mode 100644
index 000000000..3c25fe221
--- /dev/null
+++ b/tools/nogo/config-schema.json
@@ -0,0 +1,97 @@
+{
+ "$schema": "http://json-schema.org/draft-07/schema",
+ "definitions": {
+ "group": {
+ "type": "object",
+ "properties": {
+ "name": {
+ "description": "The name of the group.",
+ "type": "string"
+ },
+ "regex": {
+ "description": "A regular expression for matching paths.",
+ "type": "string"
+ },
+ "default": {
+ "description": "Whether the group is enabled by default.",
+ "type": "boolean"
+ }
+ },
+ "required": [
+ "name",
+ "regex",
+ "default"
+ ],
+ "additionalProperties": false
+ },
+ "regexlist": {
+ "description": "A list of regular expressions.",
+ "oneOf": [
+ {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ {
+ "type": "null"
+ }
+ ]
+ },
+ "rule": {
+ "type": "object",
+ "properties": {
+ "exclude": {
+ "description": "A regular expression for paths to exclude.",
+ "$ref": "#/definitions/regexlist"
+ },
+ "suppress": {
+ "description": "A regular expression for messages to suppress.",
+ "$ref": "#/definitions/regexlist"
+ }
+ },
+ "additionalProperties": false
+ },
+ "ruleList": {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "$ref": "#/definitions/rule"
+ },
+ {
+ "type": "null"
+ }
+ ]
+ }
+ }
+ },
+ "properties": {
+ "groups": {
+ "description": "A definition of all groups.",
+ "type": "array",
+ "items": {
+ "$ref": "#/definitions/group"
+ },
+ "minItems": 1
+ },
+ "global": {
+ "description": "A global set of rules.",
+ "type": "object",
+ "additionalProperties": {
+ "$ref": "#/definitions/rule"
+ }
+ },
+ "analyzers": {
+ "description": "A definition of all analyzers.",
+ "type": "object",
+ "additionalProperties": {
+ "$ref": "#/definitions/ruleList"
+ }
+ }
+ },
+ "required": [
+ "groups"
+ ],
+ "additionalProperties": false
+}
diff --git a/tools/nogo/filter/main.go b/tools/nogo/filter/main.go
index 9cf41b3b0..8be38ca6d 100644
--- a/tools/nogo/filter/main.go
+++ b/tools/nogo/filter/main.go
@@ -16,6 +16,7 @@
package main
import (
+ "bytes"
"flag"
"fmt"
"io/ioutil"
@@ -76,12 +77,14 @@ func main() {
log.Fatalf("unable to read %s: %v", filename, err)
}
var newConfig nogo.Config // For current file.
- if err := yaml.Unmarshal(content, &newConfig); err != nil {
+ dec := yaml.NewDecoder(bytes.NewBuffer(content))
+ dec.SetStrict(true)
+ if err := dec.Decode(&newConfig); err != nil {
log.Fatalf("unable to decode %s: %v", filename, err)
}
config.Merge(&newConfig)
if showConfig {
- bytes, err := yaml.Marshal(&newConfig)
+ content, err := yaml.Marshal(&newConfig)
if err != nil {
log.Fatalf("error marshalling config: %v", err)
}
@@ -89,7 +92,7 @@ func main() {
if err != nil {
log.Fatalf("error marshalling config: %v", err)
}
- fmt.Fprintf(os.Stdout, "Loaded configuration from %s:\n%s\n", filename, string(bytes))
+ fmt.Fprintf(os.Stdout, "Loaded configuration from %s:\n%s\n", filename, string(content))
fmt.Fprintf(os.Stdout, "Merged configuration:\n%s\n", string(mergedBytes))
}
}
diff --git a/tools/parsers/go_parser_test.go b/tools/parsers/go_parser_test.go
index f0737d46b..39a13b4af 100644
--- a/tools/parsers/go_parser_test.go
+++ b/tools/parsers/go_parser_test.go
@@ -34,6 +34,10 @@ func TestParseLine(t *testing.T) {
Name: "BenchmarkIperf",
Condition: []*bigquery.Condition{
{
+ Name: "iterations",
+ Value: "1",
+ },
+ {
Name: "GOMAXPROCS",
Value: "6",
},
@@ -63,6 +67,10 @@ func TestParseLine(t *testing.T) {
Name: "BenchmarkRuby",
Condition: []*bigquery.Condition{
{
+ Name: "iterations",
+ Value: "1",
+ },
+ {
Name: "GOMAXPROCS",
Value: "6",
},
@@ -100,12 +108,14 @@ func TestParseLine(t *testing.T) {
}
if !cmp.Equal(tc.want, got, nil) {
- for _, c := range got.Condition {
- t.Logf("Cond: %+v", c)
+ for i := range got.Condition {
+ t.Logf("Metric: want: %+v got:%+v", got.Condition[i], tc.want.Condition[i])
}
- for _, m := range got.Metric {
- t.Logf("Metric: %+v", m)
+
+ for i := range got.Metric {
+ t.Logf("Metric: want: %+v got:%+v", got.Metric[i], tc.want.Metric[i])
}
+
t.Fatalf("Compare failed want: %+v got: %+v", tc.want, got)
}
})
@@ -131,7 +141,7 @@ func TestParseOutput(t *testing.T) {
`,
numBenchmarks: 2,
numMetrics: 1,
- numConditions: 1,
+ numConditions: 2,
},
{
name: "Ruby",
@@ -142,7 +152,7 @@ BenchmarkRuby/server_threads.5
BenchmarkRuby/server_threads.5-6 1 1416003331 ns/op 0.00950 average_latency.s 465 requests_per_second.QPS`,
numBenchmarks: 2,
numMetrics: 3,
- numConditions: 2,
+ numConditions: 3,
},
}
diff --git a/tools/vm/BUILD b/tools/vm/BUILD
deleted file mode 100644
index d95ca6c63..000000000
--- a/tools/vm/BUILD
+++ /dev/null
@@ -1,63 +0,0 @@
-load("//tools:defs.bzl", "bzl_library", "cc_binary", "gtest")
-load("//tools/vm:defs.bzl", "vm_image", "vm_test")
-
-package(
- default_visibility = ["//:sandbox"],
- licenses = ["notice"],
-)
-
-sh_binary(
- name = "zone",
- srcs = ["zone.sh"],
-)
-
-sh_binary(
- name = "builder",
- srcs = ["build.sh"],
-)
-
-sh_binary(
- name = "executer",
- srcs = ["execute.sh"],
-)
-
-cc_binary(
- name = "test",
- testonly = 1,
- srcs = ["test.cc"],
- linkstatic = 1,
- deps = [
- gtest,
- "//test/util:test_main",
- ],
-)
-
-vm_image(
- name = "ubuntu1604",
- family = "ubuntu-1604-lts",
- project = "ubuntu-os-cloud",
- scripts = [
- "//tools/vm/ubuntu1604",
- ],
-)
-
-vm_image(
- name = "ubuntu1804",
- family = "ubuntu-1804-lts",
- project = "ubuntu-os-cloud",
- scripts = [
- "//tools/vm/ubuntu1804",
- ],
-)
-
-vm_test(
- name = "vm_test",
- shard_count = 2,
- targets = [":test"],
-)
-
-bzl_library(
- name = "defs_bzl",
- srcs = ["defs.bzl"],
- visibility = ["//visibility:private"],
-)
diff --git a/tools/vm/README.md b/tools/vm/README.md
deleted file mode 100644
index 1e9859e66..000000000
--- a/tools/vm/README.md
+++ /dev/null
@@ -1,48 +0,0 @@
-# VM Images & Tests
-
-All commands in this directory require the `gcloud` project to be set.
-
-For example: `gcloud config set project gvisor-kokoro-testing`.
-
-Images can be generated by using the `vm_image` rule. This rule will generate a
-binary target that builds an image in an idempotent way, and can be referenced
-from other rules.
-
-For example:
-
-```
-vm_image(
- name = "ubuntu",
- project = "ubuntu-1604-lts",
- family = "ubuntu-os-cloud",
- scripts = [
- "script.sh",
- "other.sh",
- ],
-)
-```
-
-These images can be built manually by executing the target. The output on
-`stdout` will be the image id (in the current project).
-
-For example:
-
-```
-$ bazel build :ubuntu
-```
-
-Images are always named per the hash of all the hermetic input scripts. This
-allows images to be memoized quickly and easily.
-
-The `vm_test` rule can be used to execute a command remotely. This is still
-under development however, and will likely change over time.
-
-For example:
-
-```
-vm_test(
- name = "mycommand",
- image = ":ubuntu",
- targets = [":test"],
-)
-```
diff --git a/tools/vm/build.sh b/tools/vm/build.sh
deleted file mode 100755
index 752b2b77b..000000000
--- a/tools/vm/build.sh
+++ /dev/null
@@ -1,117 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# This script is responsible for building a new GCP image that: 1) has nested
-# virtualization enabled, and 2) has been completely set up with the
-# image_setup.sh script. This script should be idempotent, as we memoize the
-# setup script with a hash and check for that name.
-
-set -eou pipefail
-
-# Parameters.
-declare -r USERNAME=${USERNAME:-test}
-declare -r IMAGE_PROJECT=${IMAGE_PROJECT:-ubuntu-os-cloud}
-declare -r IMAGE_FAMILY=${IMAGE_FAMILY:-ubuntu-1604-lts}
-declare -r ZONE=${ZONE:-us-central1-f}
-
-# Random names.
-declare -r DISK_NAME=$(mktemp -u disk-XXXXXX | tr A-Z a-z)
-declare -r SNAPSHOT_NAME=$(mktemp -u snapshot-XXXXXX | tr A-Z a-z)
-declare -r INSTANCE_NAME=$(mktemp -u build-XXXXXX | tr A-Z a-z)
-
-# Hash inputs in order to memoize the produced image.
-declare -r SETUP_HASH=$( (echo ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && cat "$@") | sha256sum - | cut -d' ' -f1 | cut -c 1-16)
-declare -r IMAGE_NAME=${IMAGE_FAMILY:-image}-${SETUP_HASH}
-
-# Does the image already exist? Skip the build.
-declare -r existing=$(set -x; gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)")
-if ! [[ -z "${existing}" ]]; then
- echo "${existing}"
- exit 0
-fi
-
-# Standard arguments (applies only on script execution).
-declare -ar SSH_ARGS=("-o" "ConnectTimeout=60" "--")
-
-# gcloud has path errors; is this a result of being a genrule?
-export PATH=${PATH:-/bin:/usr/bin:/usr/local/bin}
-
-# Start a unique instance. Note that this instance will have a unique persistent
-# disk as it's boot disk with the same name as the instance.
-(set -x; gcloud compute instances create \
- --quiet \
- --image-project "${IMAGE_PROJECT}" \
- --image-family "${IMAGE_FAMILY}" \
- --boot-disk-size "200GB" \
- --zone "${ZONE}" \
- "${INSTANCE_NAME}" >/dev/null)
-function cleanup {
- (set -x; gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}")
-}
-trap cleanup EXIT
-
-# Wait for the instance to become available (up to 5 minutes).
-echo -n "Waiting for ${INSTANCE_NAME}" >&2
-declare timeout=300
-declare success=0
-declare internal=""
-declare -r start=$(date +%s)
-declare -r end=$((${start}+${timeout}))
-while [[ "$(date +%s)" -lt "${end}" ]] && [[ "${success}" -lt 3 ]]; do
- echo -n "." >&2
- if gcloud compute ssh --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then
- success=$((${success}+1))
- elif gcloud compute ssh --internal-ip --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then
- success=$((${success}+1))
- internal="--internal-ip"
- fi
-done
-
-if [[ "${success}" -eq "0" ]]; then
- echo "connect timed out after ${timeout} seconds." >&2
- exit 1
-else
- echo "done." >&2
-fi
-
-# Run the install scripts provided.
-for arg; do
- (set -x; gcloud compute ssh ${internal} \
- --zone "${ZONE}" \
- "${USERNAME}"@"${INSTANCE_NAME}" -- \
- "${SSH_ARGS[@]}" \
- sudo bash - <"${arg}" >/dev/null)
-done
-
-# Stop the instance; required before creating an image.
-(set -x; gcloud compute instances stop --quiet --zone "${ZONE}" "${INSTANCE_NAME}" >/dev/null)
-
-# Create a snapshot of the instance disk.
-(set -x; gcloud compute disks snapshot \
- --quiet \
- --zone "${ZONE}" \
- --snapshot-names="${SNAPSHOT_NAME}" \
- "${INSTANCE_NAME}" >/dev/null)
-
-# Create the disk image.
-(set -x; gcloud compute images create \
- --quiet \
- --source-snapshot="${SNAPSHOT_NAME}" \
- --licenses="https://www.googleapis.com/compute/v1/projects/vm-options/global/licenses/enable-vmx" \
- "${IMAGE_NAME}" >/dev/null)
-
-# Finish up.
-echo "${IMAGE_NAME}"
diff --git a/tools/vm/defs.bzl b/tools/vm/defs.bzl
deleted file mode 100644
index 9af5ad3b4..000000000
--- a/tools/vm/defs.bzl
+++ /dev/null
@@ -1,202 +0,0 @@
-"""Image configuration. See README.md."""
-
-load("//tools:defs.bzl", "default_installer")
-
-# vm_image_builder is a rule that will construct a shell script that actually
-# generates a given VM image. Note that this does not _run_ the shell script
-# (although it can be run manually). It will be run manually during generation
-# of the vm_image target itself. This level of indirection is used so that the
-# build system itself only runs the builder once when multiple targets depend
-# on it, avoiding a set of races and conflicts.
-def _vm_image_builder_impl(ctx):
- # Generate a binary that actually builds the image.
- builder = ctx.actions.declare_file(ctx.label.name)
- script_paths = []
- for script in ctx.files.scripts:
- script_paths.append(script.short_path)
- builder_content = "\n".join([
- "#!/bin/bash",
- "export ZONE=$(%s)" % ctx.files.zone[0].short_path,
- "export USERNAME=%s" % ctx.attr.username,
- "export IMAGE_PROJECT=%s" % ctx.attr.project,
- "export IMAGE_FAMILY=%s" % ctx.attr.family,
- "%s %s" % (ctx.files._builder[0].short_path, " ".join(script_paths)),
- "",
- ])
- ctx.actions.write(builder, builder_content, is_executable = True)
-
- # Note that the scripts should only be files, and should not include any
- # indirect transitive dependencies. The build script wouldn't work.
- return [DefaultInfo(
- executable = builder,
- runfiles = ctx.runfiles(
- files = ctx.files.scripts + ctx.files._builder + ctx.files.zone,
- ),
- )]
-
-vm_image_builder = rule(
- attrs = {
- "_builder": attr.label(
- executable = True,
- default = "//tools/vm:builder",
- cfg = "host",
- ),
- "username": attr.string(default = "$(whoami)"),
- "zone": attr.label(
- executable = True,
- default = "//tools/vm:zone",
- cfg = "host",
- ),
- "family": attr.string(mandatory = True),
- "project": attr.string(mandatory = True),
- "scripts": attr.label_list(allow_files = True),
- },
- executable = True,
- implementation = _vm_image_builder_impl,
-)
-
-# See vm_image_builder above.
-def _vm_image_impl(ctx):
- # Run the builder to generate our output.
- echo = ctx.actions.declare_file(ctx.label.name)
- resolved_inputs, argv, runfiles_manifests = ctx.resolve_command(
- command = "\n".join([
- "set -e",
- "image=$(%s)" % ctx.files.builder[0].path,
- "echo -ne \"#!/bin/bash\\necho ${image}\\n\" > %s" % echo.path,
- "chmod 0755 %s" % echo.path,
- ]),
- tools = [ctx.attr.builder],
- )
- ctx.actions.run_shell(
- tools = resolved_inputs,
- outputs = [echo],
- progress_message = "Building image...",
- execution_requirements = {"local": "true"},
- command = argv,
- input_manifests = runfiles_manifests,
- )
-
- # Return just the echo command. All of the builder runfiles have been
- # resolved and consumed in the generation of the trivial echo script.
- return [DefaultInfo(executable = echo)]
-
-_vm_image_test = rule(
- attrs = {
- "builder": attr.label(
- executable = True,
- cfg = "host",
- ),
- },
- test = True,
- implementation = _vm_image_impl,
-)
-
-def vm_image(name, **kwargs):
- vm_image_builder(
- name = name + "_builder",
- **kwargs
- )
- _vm_image_test(
- name = name,
- builder = ":" + name + "_builder",
- tags = [
- "local",
- "manual",
- ],
- )
-
-def _vm_test_impl(ctx):
- runner = ctx.actions.declare_file("%s-executer" % ctx.label.name)
-
- # Note that the remote execution case must actually generate an
- # intermediate target in order to collect all the relevant runfiles so that
- # they can be copied over for remote execution.
- runner_content = "\n".join([
- "#!/bin/bash",
- "export ZONE=$(%s)" % ctx.files.zone[0].short_path,
- "export USERNAME=%s" % ctx.attr.username,
- "export IMAGE=$(%s)" % ctx.files.image[0].short_path,
- "export SUDO=%s" % "true" if ctx.attr.sudo else "false",
- "%s %s" % (
- ctx.executable.executer.short_path,
- " ".join([
- target.files_to_run.executable.short_path
- for target in ctx.attr.targets
- ]),
- ),
- "",
- ])
- ctx.actions.write(runner, runner_content, is_executable = True)
-
- # Return with all transitive files.
- runfiles = ctx.runfiles(
- transitive_files = depset(transitive = [
- depset(target.data_runfiles.files)
- for target in ctx.attr.targets
- if hasattr(target, "data_runfiles")
- ]),
- files = ctx.files.executer + ctx.files.zone + ctx.files.image +
- ctx.files.targets,
- collect_default = True,
- collect_data = True,
- )
- return [DefaultInfo(executable = runner, runfiles = runfiles)]
-
-_vm_test = rule(
- attrs = {
- "image": attr.label(
- executable = True,
- default = "//tools/vm:ubuntu1804",
- cfg = "host",
- ),
- "executer": attr.label(
- executable = True,
- default = "//tools/vm:executer",
- cfg = "host",
- ),
- "username": attr.string(default = "$(whoami)"),
- "zone": attr.label(
- executable = True,
- default = "//tools/vm:zone",
- cfg = "host",
- ),
- "sudo": attr.bool(default = True),
- "machine": attr.string(default = "n1-standard-1"),
- "targets": attr.label_list(
- mandatory = True,
- allow_empty = False,
- cfg = "target",
- ),
- },
- test = True,
- implementation = _vm_test_impl,
-)
-
-def vm_test(
- installers = None,
- **kwargs):
- """Runs the given targets as a remote test.
-
- Args:
- installer: Script to run before all targets.
- **kwargs: All test arguments. Should include targets and image.
- """
- targets = kwargs.pop("targets", [])
- if installers == None:
- installers = [
- "//tools/installers:head",
- "//tools/installers:images",
- ]
- targets = installers + targets
- if default_installer():
- targets = [default_installer()] + targets
- _vm_test(
- tags = [
- "local",
- "manual",
- ],
- targets = targets,
- local = 1,
- **kwargs
- )
diff --git a/tools/vm/execute.sh b/tools/vm/execute.sh
deleted file mode 100755
index 1f1f3ce01..000000000
--- a/tools/vm/execute.sh
+++ /dev/null
@@ -1,160 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-set -xeo pipefail
-
-# Required input.
-if ! [[ -v IMAGE ]]; then
- echo "no image provided: set IMAGE."
- exit 1
-fi
-
-# Parameters.
-declare -r USERNAME=${USERNAME:-test}
-declare -r KEYNAME=$(mktemp --tmpdir -u key-XXXXXX)
-declare -r SSHKEYS=$(mktemp --tmpdir -u sshkeys-XXXXXX)
-declare -r INSTANCE_NAME=$(mktemp -u test-XXXXXX | tr A-Z a-z)
-declare -r MACHINE=${MACHINE:-n1-standard-1}
-declare -r ZONE=${ZONE:-us-central1-f}
-declare -r SUDO=${SUDO:-false}
-
-# Standard arguments (applies only on script execution).
-declare -ar SSH_ARGS=("-o" "ConnectTimeout=60" "--")
-
-# This script is executed as a test rule, which will reset the value of HOME.
-# Unfortunately, it is needed to load the gconfig credentials. We will reset
-# HOME when we actually execute in the remote environment, defined below.
-export HOME=$(eval echo ~$(whoami))
-
-# Generate unique keys for this test.
-[[ -f "${KEYNAME}" ]] || ssh-keygen -t rsa -N "" -f "${KEYNAME}" -C "${USERNAME}"
-cat > "${SSHKEYS}" <<EOF
-${USERNAME}:$(cat ${KEYNAME}.pub)
-EOF
-
-# Start a unique instance. This means that we first generate a unique set of ssh
-# keys to ensure that only we have access to this instance. Note that we must
-# constrain ourselves to Haswell or greater in order to have nested
-# virtualization available.
-gcloud compute instances create \
- --min-cpu-platform "Intel Haswell" \
- --preemptible \
- --no-scopes \
- --metadata block-project-ssh-keys=TRUE \
- --metadata-from-file ssh-keys="${SSHKEYS}" \
- --machine-type "${MACHINE}" \
- --image "${IMAGE}" \
- --zone "${ZONE}" \
- "${INSTANCE_NAME}"
-function cleanup {
- gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}"
-}
-trap cleanup EXIT
-
-# Wait for the instance to become available (up to 5 minutes).
-declare timeout=300
-declare success=0
-declare -r start=$(date +%s)
-declare -r end=$((${start}+${timeout}))
-while [[ "$(date +%s)" -lt "${end}" ]] && [[ "${success}" -lt 3 ]]; do
- if gcloud compute ssh --ssh-key-file="${KEYNAME}" --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then
- success=$((${success}+1))
- fi
-done
-if [[ "${success}" -eq "0" ]]; then
- echo "connect timed out after ${timeout} seconds."
- exit 1
-fi
-
-# Copy the local directory over.
-tar czf - --dereference --exclude=.git . |
- gcloud compute ssh \
- --ssh-key-file="${KEYNAME}" \
- --zone "${ZONE}" \
- "${USERNAME}"@"${INSTANCE_NAME}" -- \
- "${SSH_ARGS[@]}" \
- tar xzf -
-
-# Execute the command remotely.
-for cmd; do
- # Setup relevant environment.
- #
- # N.B. This is not a complete test environment, but is complete enough to
- # provide rudimentary sharding and test output support.
- declare -a PREFIX=( "env" )
- if [[ -v TEST_SHARD_INDEX ]]; then
- PREFIX+=( "TEST_SHARD_INDEX=${TEST_SHARD_INDEX}" )
- fi
- if [[ -v TEST_SHARD_STATUS_FILE ]]; then
- SHARD_STATUS_FILE=$(mktemp -u test-shard-status-XXXXXX)
- PREFIX+=( "TEST_SHARD_STATUS_FILE=/tmp/${SHARD_STATUS_FILE}" )
- fi
- if [[ -v TEST_TOTAL_SHARDS ]]; then
- PREFIX+=( "TEST_TOTAL_SHARDS=${TEST_TOTAL_SHARDS}" )
- fi
- if [[ -v TEST_TMPDIR ]]; then
- REMOTE_TMPDIR=$(mktemp -u test-XXXXXX)
- PREFIX+=( "TEST_TMPDIR=/tmp/${REMOTE_TMPDIR}" )
- # Create remotely.
- gcloud compute ssh \
- --ssh-key-file="${KEYNAME}" \
- --zone "${ZONE}" \
- "${USERNAME}"@"${INSTANCE_NAME}" -- \
- "${SSH_ARGS[@]}" \
- mkdir -p "/tmp/${REMOTE_TMPDIR}"
- fi
- if [[ -v XML_OUTPUT_FILE ]]; then
- TEST_XML_OUTPUT=$(mktemp -u xml-output-XXXXXX)
- PREFIX+=( "XML_OUTPUT_FILE=/tmp/${TEST_XML_OUTPUT}" )
- fi
- if [[ "${SUDO}" == "true" ]]; then
- PREFIX+=( "sudo" "-E" )
- fi
-
- # Execute the command.
- gcloud compute ssh \
- --ssh-key-file="${KEYNAME}" \
- --zone "${ZONE}" \
- "${USERNAME}"@"${INSTANCE_NAME}" -- \
- "${SSH_ARGS[@]}" \
- "${PREFIX[@]}" "${cmd}"
-
- # Collect relevant results.
- if [[ -v TEST_SHARD_STATUS_FILE ]]; then
- gcloud compute scp \
- --ssh-key-file="${KEYNAME}" \
- --zone "${ZONE}" \
- "${USERNAME}"@"${INSTANCE_NAME}":/tmp/"${SHARD_STATUS_FILE}" \
- "${TEST_SHARD_STATUS_FILE}" 2>/dev/null || true # Allowed to fail.
- fi
- if [[ -v XML_OUTPUT_FILE ]]; then
- gcloud compute scp \
- --ssh-key-file="${KEYNAME}" \
- --zone "${ZONE}" \
- "${USERNAME}"@"${INSTANCE_NAME}":/tmp/"${TEST_XML_OUTPUT}" \
- "${XML_OUTPUT_FILE}" 2>/dev/null || true # Allowed to fail.
- fi
-
- # Clean up the temporary directory.
- if [[ -v TEST_TMPDIR ]]; then
- gcloud compute ssh \
- --ssh-key-file="${KEYNAME}" \
- --zone "${ZONE}" \
- "${USERNAME}"@"${INSTANCE_NAME}" -- \
- "${SSH_ARGS[@]}" \
- rm -rf "/tmp/${REMOTE_TMPDIR}"
- fi
-done
diff --git a/tools/vm/ubuntu1604/10_core.sh b/tools/vm/ubuntu1604/10_core.sh
deleted file mode 100755
index 629f7cf7a..000000000
--- a/tools/vm/ubuntu1604/10_core.sh
+++ /dev/null
@@ -1,43 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-set -xeo pipefail
-
-# Install all essential build tools.
-while true; do
- if (apt-get update && apt-get install -y \
- make \
- git-core \
- build-essential \
- linux-headers-$(uname -r) \
- pkg-config); then
- break
- fi
- result=$?
- if [[ $result -ne 100 ]]; then
- exit $result
- fi
-done
-
-# Install a recent go toolchain.
-if ! [[ -d /usr/local/go ]]; then
- wget https://dl.google.com/go/go1.13.5.linux-amd64.tar.gz
- tar -xvf go1.13.5.linux-amd64.tar.gz
- mv go /usr/local
-fi
-
-# Link the Go binary from /usr/bin; replacing anything there.
-(cd /usr/bin && rm -f go && ln -fs /usr/local/go/bin/go go)
diff --git a/tools/vm/ubuntu1604/15_gcloud.sh b/tools/vm/ubuntu1604/15_gcloud.sh
deleted file mode 100755
index bc2e5eccc..000000000
--- a/tools/vm/ubuntu1604/15_gcloud.sh
+++ /dev/null
@@ -1,50 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-set -xeo pipefail
-
-# Install all essential build tools.
-while true; do
- if (apt-get update && apt-get install -y \
- apt-transport-https \
- ca-certificates \
- gnupg); then
- break
- fi
- result=$?
- if [[ $result -ne 100 ]]; then
- exit $result
- fi
-done
-
-# Add gcloud repositories.
-echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | \
- tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
-
-# Add the appropriate key.
-curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | \
- apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
-
-# Install the gcloud SDK.
-while true; do
- if (apt-get update && apt-get install -y google-cloud-sdk); then
- break
- fi
- result=$?
- if [[ $result -ne 100 ]]; then
- exit $result
- fi
-done
diff --git a/tools/vm/ubuntu1604/20_bazel.sh b/tools/vm/ubuntu1604/20_bazel.sh
deleted file mode 100755
index bb7afa676..000000000
--- a/tools/vm/ubuntu1604/20_bazel.sh
+++ /dev/null
@@ -1,38 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-set -xeo pipefail
-
-declare -r BAZEL_VERSION=2.0.0
-
-# Install bazel dependencies.
-while true; do
- if (apt-get update && apt-get install -y \
- openjdk-8-jdk-headless \
- unzip); then
- break
- fi
- result=$?
- if [[ $result -ne 100 ]]; then
- exit $result
- fi
-done
-
-# Use the release installer.
-curl -L -o bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
-chmod a+x bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
-./bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
-rm -f bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
diff --git a/tools/vm/ubuntu1604/30_docker.sh b/tools/vm/ubuntu1604/30_docker.sh
deleted file mode 100755
index d393133e4..000000000
--- a/tools/vm/ubuntu1604/30_docker.sh
+++ /dev/null
@@ -1,64 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Add dependencies.
-while true; do
- if (apt-get update && apt-get install -y \
- apt-transport-https \
- ca-certificates \
- curl \
- gnupg-agent \
- software-properties-common); then
- break
- fi
- result=$?
- if [[ $result -ne 100 ]]; then
- exit $result
- fi
-done
-
-# Install the key.
-curl -fsSL https://download.docker.com/linux/ubuntu/gpg | apt-key add -
-
-# Add the repository.
-add-apt-repository \
- "deb [arch=amd64] https://download.docker.com/linux/ubuntu \
- $(lsb_release -cs) \
- stable"
-
-# Install docker.
-while true; do
- if (apt-get update && apt-get install -y \
- docker-ce \
- docker-ce-cli \
- containerd.io); then
- break
- fi
- result=$?
- if [[ $result -ne 100 ]]; then
- exit $result
- fi
-done
-
-# Enable experimental features, for cross-building aarch64 images.
-# Enable Docker IPv6.
-cat > /etc/docker/daemon.json <<EOF
-{
- "experimental": true,
- "fixed-cidr-v6": "2001:db8:1::/64",
- "ipv6": true
-}
-EOF
diff --git a/tools/vm/ubuntu1604/40_kokoro.sh b/tools/vm/ubuntu1604/40_kokoro.sh
deleted file mode 100755
index d3b96c9ad..000000000
--- a/tools/vm/ubuntu1604/40_kokoro.sh
+++ /dev/null
@@ -1,72 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-set -xeo pipefail
-
-# Declare kokoro's required public keys.
-declare -r ssh_public_keys=(
- "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDg7L/ZaEauETWrPklUTky3kvxqQfe2Ax/2CsSqhNIGNMnK/8d79CHlmY9+dE1FFQ/RzKNCaltgy7XcN/fCYiCZr5jm2ZtnLuGNOTzupMNhaYiPL419qmL+5rZXt4/dWTrsHbFRACxT8j51PcRMO5wgbL0Bg2XXimbx8kDFaurL2gqduQYqlu4lxWCaJqOL71WogcimeL63Nq/yeH5PJPWpqE4P9VUQSwAzBWFK/hLeds/AiP3MgVS65qHBnhq0JsHy8JQsqjZbG7Iidt/Ll0+gqzEbi62gDIcczG4KC0iOVzDDP/1BxDtt1lKeA23ll769Fcm3rJyoBMYxjvdw1TDx sabujp@trigger.mtv.corp.google.com"
- "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBNgGK/hCdjmulHfRE3hp4rZs38NCR8yAh0eDsztxqGcuXnuSnL7jOlRrbcQpremJ84omD4eKrIpwJUs+YokMdv4= sabujp@trigger.svl.corp.google.com"
-)
-
-# Install dependencies.
-while true; do
- if (apt-get update && apt-get install -y \
- rsync \
- coreutils \
- python-psutil \
- qemu-kvm \
- python-pip \
- python3-pip \
- zip); then
- break
- fi
- result=$?
- if [[ $result -ne 100 ]]; then
- exit $result
- fi
-done
-
-# junitparser is used to merge junit xml files.
-pip install --no-cache-dir junitparser
-
-# We need a kbuilder user, which may already exist.
-useradd -c "kbuilder user" -m -s /bin/bash kbuilder || true
-
-# We need to provision appropriate keys.
-mkdir -p ~kbuilder/.ssh
-(IFS=$'\n'; echo "${ssh_public_keys[*]}") > ~kbuilder/.ssh/authorized_keys
-chmod 0600 ~kbuilder/.ssh/authorized_keys
-chown -R kbuilder ~kbuilder/.ssh
-
-# Give passwordless sudo access.
-cat > /etc/sudoers.d/kokoro <<EOF
-kbuilder ALL=(ALL) NOPASSWD:ALL
-EOF
-
-# Ensure we can run Docker without sudo.
-usermod -aG docker kbuilder
-
-# Ensure that we can access kvm.
-usermod -aG kvm kbuilder
-
-# Ensure that /tmpfs exists and is writable by kokoro.
-#
-# Note that kokoro will typically attach a second disk (sdb) to the instance
-# that is used for the /tmpfs volume. In the future we could setup an init
-# script that formats and mounts this here; however, we don't expect our build
-# artifacts to be that large.
-mkdir -p /tmpfs && chmod 0777 /tmpfs && touch /tmpfs/READY
diff --git a/tools/vm/ubuntu1604/BUILD b/tools/vm/ubuntu1604/BUILD
deleted file mode 100644
index ab1df0c4c..000000000
--- a/tools/vm/ubuntu1604/BUILD
+++ /dev/null
@@ -1,7 +0,0 @@
-package(licenses = ["notice"])
-
-filegroup(
- name = "ubuntu1604",
- srcs = glob(["*.sh"]),
- visibility = ["//:sandbox"],
-)
diff --git a/tools/vm/ubuntu1804/BUILD b/tools/vm/ubuntu1804/BUILD
deleted file mode 100644
index 0c8856dde..000000000
--- a/tools/vm/ubuntu1804/BUILD
+++ /dev/null
@@ -1,7 +0,0 @@
-package(licenses = ["notice"])
-
-alias(
- name = "ubuntu1804",
- actual = "//tools/vm/ubuntu1604",
- visibility = ["//:sandbox"],
-)
diff --git a/tools/workspace_status.sh b/tools/workspace_status.sh
index a22c8c9f2..62d78ed3d 100755
--- a/tools/workspace_status.sh
+++ b/tools/workspace_status.sh
@@ -15,4 +15,4 @@
# limitations under the License.
# The STABLE_ prefix will trigger a re-link if it changes.
-echo STABLE_VERSION $(git describe --always --tags --abbrev=12 --dirty || echo 0.0.0)
+echo STABLE_VERSION "$(git describe --always --tags --abbrev=12 --dirty 2>/dev/null || echo 0.0.0)"
diff --git a/tools/yamltest/BUILD b/tools/yamltest/BUILD
new file mode 100644
index 000000000..475b3badd
--- /dev/null
+++ b/tools/yamltest/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "yamltest",
+ srcs = ["main.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "@com_github_xeipuuv_gojsonschema//:go_default_library",
+ "@in_gopkg_yaml_v2//:go_default_library",
+ ],
+)
diff --git a/tools/yamltest/defs.bzl b/tools/yamltest/defs.bzl
new file mode 100644
index 000000000..fd04f947d
--- /dev/null
+++ b/tools/yamltest/defs.bzl
@@ -0,0 +1,41 @@
+"""Tools for testing yaml files against schemas."""
+
+def _yaml_test_impl(ctx):
+ """Implementation for yaml_test."""
+ runner = ctx.actions.declare_file(ctx.label.name)
+ ctx.actions.write(runner, "\n".join([
+ "#!/bin/bash",
+ "set -euo pipefail",
+ "%s -schema=%s -- %s" % (
+ ctx.files._tool[0].short_path,
+ ctx.files.schema[0].short_path,
+ " ".join([f.short_path for f in ctx.files.srcs]),
+ ),
+ ]), is_executable = True)
+ return [DefaultInfo(
+ runfiles = ctx.runfiles(files = ctx.files._tool + ctx.files.schema + ctx.files.srcs),
+ executable = runner,
+ )]
+
+yaml_test = rule(
+ implementation = _yaml_test_impl,
+ doc = "Tests a yaml file against a schema.",
+ attrs = {
+ "srcs": attr.label_list(
+ doc = "The input yaml files.",
+ mandatory = True,
+ allow_files = True,
+ ),
+ "schema": attr.label(
+ doc = "The schema file in JSON schema format.",
+ allow_single_file = True,
+ mandatory = True,
+ ),
+ "_tool": attr.label(
+ executable = True,
+ cfg = "host",
+ default = Label("//tools/yamltest:yamltest"),
+ ),
+ },
+ test = True,
+)
diff --git a/tools/yamltest/main.go b/tools/yamltest/main.go
new file mode 100644
index 000000000..88271fb66
--- /dev/null
+++ b/tools/yamltest/main.go
@@ -0,0 +1,133 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary yamltest does strict yaml parsing and validation.
+package main
+
+import (
+ "encoding/json"
+ "errors"
+ "flag"
+ "fmt"
+ "os"
+
+ "github.com/xeipuuv/gojsonschema"
+ yaml "gopkg.in/yaml.v2"
+)
+
+func fixup(v interface{}) (interface{}, error) {
+ switch x := v.(type) {
+ case map[interface{}]interface{}:
+ // Coerce into a string-keyed map, required for JSON serialization.
+ strMap := make(map[string]interface{})
+ for k, v := range x {
+ strK, ok := k.(string)
+ if !ok {
+ // This cannot be converted to JSON at all.
+ return nil, fmt.Errorf("invalid key %T in (%#v)", k, x)
+ }
+ fv, err := fixup(v)
+ if err != nil {
+ return nil, fmt.Errorf(".%s%w", strK, err)
+ }
+ strMap[strK] = fv
+ }
+ return strMap, nil
+ case []interface{}:
+ for i := range x {
+ fv, err := fixup(x[i])
+ if err != nil {
+ return nil, fmt.Errorf("[%d]%w", i, err)
+ }
+ x[i] = fv
+ }
+ return x, nil
+ default:
+ return v, nil
+ }
+}
+
+func loadFile(filename string) (gojsonschema.JSONLoader, error) {
+ f, err := os.Open(filename)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+ dec := yaml.NewDecoder(f)
+ dec.SetStrict(true)
+ var object interface{}
+ if err := dec.Decode(&object); err != nil {
+ return nil, err
+ }
+ fixedObject, err := fixup(object) // For serialization.
+ if err != nil {
+ return nil, err
+ }
+ bytes, err := json.Marshal(fixedObject)
+ if err != nil {
+ return nil, err
+ }
+ return gojsonschema.NewStringLoader(string(bytes)), nil
+}
+
+var schema = flag.String("schema", "", "path to JSON schema file.")
+
+func main() {
+ flag.Parse()
+ if *schema == "" || len(flag.Args()) == 0 {
+ flag.Usage()
+ os.Exit(2)
+ }
+
+ // Construct our schema loader.
+ schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", *schema))
+
+ // Parse all documents.
+ allErrors := make(map[string][]error)
+ for _, filename := range flag.Args() {
+ // Record the filename with an empty slice for below, where
+ // we will emit all files (even those without any errors).
+ allErrors[filename] = nil
+ documentLoader, err := loadFile(filename)
+ if err != nil {
+ allErrors[filename] = append(allErrors[filename], err)
+ continue
+ }
+ result, err := gojsonschema.Validate(schemaLoader, documentLoader)
+ if err != nil {
+ allErrors[filename] = append(allErrors[filename], err)
+ continue
+ }
+ for _, desc := range result.Errors() {
+ allErrors[filename] = append(allErrors[filename], errors.New(desc.String()))
+ }
+ }
+
+ // Print errors in yaml format.
+ totalErrors := 0
+ for filename, errs := range allErrors {
+ totalErrors += len(errs)
+ if len(errs) == 0 {
+ fmt.Fprintf(os.Stderr, "%s: ✓\n", filename)
+ continue
+ }
+ fmt.Fprintf(os.Stderr, "%s:\n", filename)
+ for _, err := range errs {
+ fmt.Fprintf(os.Stderr, "- %s\n", err)
+ }
+ }
+ if totalErrors != 0 {
+ os.Exit(1)
+ }
+}
diff --git a/website/BUILD b/website/BUILD
index 676c2b701..d5315abce 100644
--- a/website/BUILD
+++ b/website/BUILD
@@ -38,6 +38,7 @@ genrule(
":syscallmd",
"//website/blog:posts",
"//website/cmd/server",
+ "@google_root_pem//file",
],
outs = ["files.tgz"],
cmd = "set -x; " +
@@ -61,6 +62,8 @@ genrule(
"ruby /checks.rb " +
"/output && " +
"cp $(location //website/cmd/server) $$T/output/server && " +
+ "mkdir -p $$T/output/etc/ssl && " +
+ "cp $(location @google_root_pem//file) $$T/output/etc/ssl/cert.pem && " +
"tar -zcf $@ -C $$T/output . && " +
"rm -rf $$T",
tags = [
diff --git a/website/blog/README.md b/website/blog/README.md
new file mode 100644
index 000000000..e1d685288
--- /dev/null
+++ b/website/blog/README.md
@@ -0,0 +1,62 @@
+# gVisor blog
+
+The gVisor blog is owned and run by the gVisor team.
+
+## Contact
+
+Reach out to us on [gitter](https://gitter.im/gvisor/community) or the
+[mailing list](https://groups.google.com/forum/#!forum/gvisor-users) if you
+would like to write a blog post.
+
+## Submit a Post
+
+Anyone can write a blog post and submit it for review. Purely commercial content
+or vendor pitches are not allowed. Please refer to the
+[blog guidelines](#blog-guidelines) for more guidance about what content is
+allowed.
+
+To submit a blog post, follow the steps below.
+
+1. [Sign the Contributor License Agreements](https://gvisor.dev/contributing/)
+ if you have not yet done so.
+1. Familiarize yourself with the Markdown format for the
+ [existing blog posts](https://github.com/google/gvisor/tree/master/website/blog).
+1. Write your blog post in a text editor of your choice.
+1. (Optional) If you need help with markdown, check out
+ [StackEdit](https://stackedit.io/app#) or read
+ [Jekyll's formatting reference](https://jekyllrb.com/docs/posts/#creating-posts)
+ for more information.
+1. Click **Add file** > **Create new file**.
+1. Paste your content into the editor and save it. Name the file in the
+ following way: *[BLOG] Your proposed title*, but don’t put the date in the
+ file name. The blog reviewers will work with you on the final file name, and
+ the date on which the blog will be published.
+1. When you save the file, GitHub will walk you through the pull request (PR)
+ process.
+1. Send us a message on [gitter](https://gitter.im/gvisor/community) with a
+ link to your recently created PR.
+1. A reviewer will be assigned to the pull request. They check your submission,
+ and work with you on feedback and final details. When the pull request is
+ approved, the blog will be scheduled for publication.
+
+### Blog Guidelines {#blog-guidelines}
+
+#### Suitable Content:
+
+- **Original content only**
+- gVisor features or project updates
+- Tutorials and demos
+- Use cases
+- Content that is specific to a vendor or platform about gVisor installation
+ and use
+
+#### Unsuitable Content:
+
+- Blogs with no content relevant to gVisor
+- Vendor pitches
+
+## Review Process
+
+Each blog post should be approved by at least one person on the team. Once all
+of the review comments have been addressed and approved, a member of the team
+will schedule publication of the blog post.
diff --git a/website/blog/index.html b/website/blog/index.html
index 5c67c95fc..272917fc4 100644
--- a/website/blog/index.html
+++ b/website/blog/index.html
@@ -20,3 +20,8 @@ pagination:
{% if paginator.total_pages > 1 %}
{% include paginator.html %}
{% endif %}
+
+<hr>
+
+If you would like to contribute to the gVisor blog, check out the
+<a href="https://github.com/google/gvisor/tree/master/website/blog">instructions</a>.
diff --git a/website/cmd/server/BUILD b/website/cmd/server/BUILD
index 6b5a08f0d..e4cf91e07 100644
--- a/website/cmd/server/BUILD
+++ b/website/cmd/server/BUILD
@@ -7,4 +7,7 @@ go_binary(
srcs = ["main.go"],
pure = True,
visibility = ["//website:__pkg__"],
+ deps = [
+ "@com_github_google_pprof//driver:go_default_library",
+ ],
)
diff --git a/website/cmd/server/main.go b/website/cmd/server/main.go
index ac09550a9..9f0092ed6 100644
--- a/website/cmd/server/main.go
+++ b/website/cmd/server/main.go
@@ -20,9 +20,13 @@ import (
"fmt"
"log"
"net/http"
+ "net/url"
"os"
+ "path"
"regexp"
"strings"
+
+ "github.com/google/pprof/driver"
)
var redirects = map[string]string{
@@ -58,19 +62,37 @@ var redirects = map[string]string{
// Deprecated, but links continue to work.
"/cl": "https://gvisor-review.googlesource.com",
+
+ // Access package documentation.
+ "/gvisor": "https://pkg.go.dev/gvisor.dev/gvisor",
+
+ // Code search root.
+ "/cs": "https://cs.opensource.google/gvisor/gvisor",
}
-var prefixHelpers = map[string]string{
- "change": "https://github.com/google/gvisor/commit/%s",
- "issue": "https://github.com/google/gvisor/issues/%s",
- "issues": "https://github.com/google/gvisor/issues/%s",
- "pr": "https://github.com/google/gvisor/pull/%s",
+type prefixInfo struct {
+ baseURL string
+ checkValidID bool
+ queryEscape bool
+}
+
+var prefixHelpers = map[string]prefixInfo{
+ "change": {baseURL: "https://github.com/google/gvisor/commit/%s", checkValidID: true},
+ "issue": {baseURL: "https://github.com/google/gvisor/issues/%s", checkValidID: true},
+ "issues": {baseURL: "https://github.com/google/gvisor/issues/%s", checkValidID: true},
+ "pr": {baseURL: "https://github.com/google/gvisor/pull/%s", checkValidID: true},
// Redirects to compatibility docs.
- "c/linux/amd64": "/docs/user_guide/compatibility/linux/amd64/#%s",
+ "c/linux/amd64": {baseURL: "/docs/user_guide/compatibility/linux/amd64/#%s", checkValidID: true},
// Deprecated, but links continue to work.
- "cl": "https://gvisor-review.googlesource.com/c/gvisor/+/%s",
+ "cl": {baseURL: "https://gvisor-review.googlesource.com/c/gvisor/+/%s", checkValidID: true},
+
+ // Redirect to source documentation.
+ "gvisor": {baseURL: "https://pkg.go.dev/gvisor.dev/gvisor/%s"},
+
+ // Redirect to code search, with the path as the query.
+ "cs": {baseURL: "https://cs.opensource.google/search?q=%s&ss=gvisor", queryEscape: true},
}
var (
@@ -144,7 +166,7 @@ func hostRedirectHandler(h http.Handler) http.Handler {
}
// prefixRedirectHandler returns a handler that redirects to the given formated url.
-func prefixRedirectHandler(prefix, baseURL string) http.Handler {
+func prefixRedirectHandler(prefix string, info prefixInfo) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if p := r.URL.Path; p == prefix {
// Redirect /prefix/ to /prefix.
@@ -152,11 +174,14 @@ func prefixRedirectHandler(prefix, baseURL string) http.Handler {
return
}
id := r.URL.Path[len(prefix):]
- if !validID.MatchString(id) {
+ if info.checkValidID && !validID.MatchString(id) {
http.Error(w, "Not found", http.StatusNotFound)
return
}
- target := fmt.Sprintf(baseURL, id)
+ if info.queryEscape {
+ id = url.QueryEscape(id)
+ }
+ target := fmt.Sprintf(info.baseURL, id)
redirectWithQuery(w, r, target)
})
}
@@ -168,30 +193,178 @@ func redirectHandler(target string) http.Handler {
})
}
-// redirectRedirects registers redirect http handlers.
+// registerRedirects registers redirect http handlers.
func registerRedirects(mux *http.ServeMux) {
- if mux == nil {
- mux = http.DefaultServeMux
- }
-
- for prefix, baseURL := range prefixHelpers {
+ for prefix, info := range prefixHelpers {
p := "/" + prefix + "/"
- mux.Handle(p, hostRedirectHandler(wrappedHandler(prefixRedirectHandler(p, baseURL))))
+ mux.Handle(p, hostRedirectHandler(wrappedHandler(prefixRedirectHandler(p, info))))
}
-
for path, redirect := range redirects {
mux.Handle(path, hostRedirectHandler(wrappedHandler(redirectHandler(redirect))))
}
}
-// registerStatic registers static file handlers
+// registerStatic registers static file handlers.
func registerStatic(mux *http.ServeMux, staticDir string) {
- if mux == nil {
- mux = http.DefaultServeMux
- }
mux.Handle("/", hostRedirectHandler(wrappedHandler(http.FileServer(http.Dir(staticDir)))))
}
+// profileMeta implements synthetic flags for pprof.
+type profileMeta struct {
+ // Mux is the mux to register on.
+ Mux *http.ServeMux
+
+ // SourceURL is the source of the profile.
+ SourceURL string
+}
+
+func (*profileMeta) ExtraUsage() string { return "" }
+func (*profileMeta) AddExtraUsage(string) {}
+func (*profileMeta) Bool(_ string, def bool, _ string) *bool { return &def }
+func (*profileMeta) Int(_ string, def int, _ string) *int { return &def }
+func (*profileMeta) Float64(_ string, def float64, _ string) *float64 { return &def }
+func (*profileMeta) StringList(_ string, def string, _ string) *[]*string { return new([]*string) }
+func (*profileMeta) String(option string, def string, _ string) *string {
+ switch option {
+ case "http":
+ // Only http is specified. Other options may be accessible via
+ // the web interface, so we just need to spoof a valid option
+ // here. The server is actually bound by HTTPServer, below.
+ value := "localhost:80"
+ return &value
+ case "symbolize":
+ // Don't attempt symbolization. Most profiles should come with
+ // mappings built-in to the profile itself.
+ value := "none"
+ return &value
+ default:
+ return &def // Default.
+ }
+}
+
+// Parse implements plugin.FlagSet.Parse.
+func (p *profileMeta) Parse(usage func()) []string {
+ // Just return the SourceURL. This is interpreted as the profile to
+ // download. We validate that the URL corresponds to a Google Cloud
+ // Storage URL below.
+ return []string{p.SourceURL}
+}
+
+// pprofFixedPrefix is used to limit the exposure to SSRF.
+//
+// See registerProfile below.
+const pprofFixedPrefix = "https://storage.googleapis.com/"
+
+// allowedBuckets enforces constraints on the pprof target.
+//
+// If the continuous integration system is changed in the future to use
+// additional buckets, they may be whitelisted here. See registerProfile.
+var allowedBuckets = map[string]bool{
+ "gvisor-buildkite": true,
+}
+
+// Target returns the URL target.
+func (p *profileMeta) Target() string {
+ return fmt.Sprintf("/profile/%s/", p.SourceURL[len(pprofFixedPrefix):])
+}
+
+// HTTPServer is a function passed to driver.PProf.
+func (p *profileMeta) HTTPServer(args *driver.HTTPServerArgs) error {
+ target := p.Target()
+ for subpath, handler := range args.Handlers {
+ handlerPath := path.Join(target, subpath)
+ if len(handlerPath) < len(target) {
+ // Don't clean the target, match only as the literal
+ // directory path in order to keep relative links
+ // working in the profile. E.g. /profile/foo/ is the
+ // base URL for the profile at https://.../foo.
+ //
+ // The base target typically shows the dot-based graph,
+ // which will not work in the image (due to the lack of
+ // a dot binary to execute). Therefore, we redirect to
+ // the flamegraph handler. Everything should otherwise
+ // work the exact same way, except the "Graph" link.
+ handlerPath = target
+ handler = redirectHandler(path.Join(handlerPath, "flamegraph"))
+ }
+ p.Mux.Handle(handlerPath, handler)
+ }
+ return nil
+}
+
+// registerProfile registers the profile handler.
+//
+// Note that this has a security surface worth considering.
+//
+// We are passed effectively a URL, which we fetch and parse,
+// then display the profile output. We limit the possibility of
+// SSRF by interpreting the URL strictly as a path to an object
+// in Google Cloud Storage, and further limit the buckets that
+// may be used. This contains the vast majority of concerns,
+// since objects must at least be uploaded by our CI system.
+//
+// However, we additionally consider the possibility that users
+// craft malicious profile objects (somehow) and pass those URLs
+// here as well. It seems feasible that we could parse a profile
+// that causes a crash (DOS), but this would be automatically
+// handled without a blip. It seems unlikely that we could parse a
+// profile that gives full code execution, but even so there is
+// nothing in this image except this code and CA certs. At worst,
+// code execution would enable someone to serve up content under the
+// web domain. This would be ephemeral with the specific instance,
+// and persisting such an attack would require constantly crashing
+// instances in whatever way gives remote code execution. Even if
+// this were possible, it's unlikely that exploiting such a crash
+// could be done so constantly and consistently.
+//
+// The user can also fill the "disk" of this container instance,
+// causing an OOM and a crash. This has similar semantics to the
+// DOS scenario above, and would just be handled by Cloud Run.
+//
+// Note that all of the above scenarios would require uploading
+// malicious profiles to controlled buckets, and a clear audit
+// trail would exist in those cases.
+func registerProfile(mux *http.ServeMux) {
+ const urlPrefix = "/profile/"
+ mux.Handle(urlPrefix, hostRedirectHandler(wrappedHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Extract the URL; this is everything except the final /.
+ parts := strings.Split(r.URL.Path[len(urlPrefix):], "/")
+ if len(parts) == 0 {
+ http.Error(w, "Invalid URL: no bucket provided.", http.StatusNotFound)
+ return
+ }
+ if !allowedBuckets[parts[0]] {
+ http.Error(w, fmt.Sprintf("Invalid URL: not an allowed bucket (%s).", parts[0]), http.StatusNotFound)
+ return
+ }
+ url := pprofFixedPrefix + strings.Join(parts[:len(parts)-1], "/")
+ if url == pprofFixedPrefix {
+ http.Error(w, "Invalid URL: no path provided.", http.StatusNotFound)
+ return
+ }
+
+ // Set up the meta handler. This will modify the original mux
+ // accordingly, and we ultimately return a redirect that
+ // includes all the original arguments. This means that if we
+ // ever hit a server that does not have this profile loaded, it
+ // will load and redirect again.
+ meta := &profileMeta{
+ Mux: mux,
+ SourceURL: url,
+ }
+ if err := driver.PProf(&driver.Options{
+ Flagset: meta,
+ HTTPServer: meta.HTTPServer,
+ }); err != nil {
+ http.Error(w, fmt.Sprintf("Invalid profile: %v", err), http.StatusNotImplemented)
+ return
+ }
+
+ // Serve the path directly.
+ mux.ServeHTTP(w, r)
+ }))))
+}
+
func envFlagString(name, def string) string {
if val := os.Getenv(name); val != "" {
return val
@@ -211,8 +384,9 @@ var (
func main() {
flag.Parse()
- registerRedirects(nil)
- registerStatic(nil, *staticDir)
+ registerRedirects(http.DefaultServeMux)
+ registerStatic(http.DefaultServeMux, *staticDir)
+ registerProfile(http.DefaultServeMux)
log.Printf("Listening on %s...", *addr)
log.Fatal(http.ListenAndServe(*addr, nil))