summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.bazelrc2
-rw-r--r--BUILD3
-rw-r--r--CONTRIBUTING.md32
-rw-r--r--Makefile2
-rw-r--r--README.md9
-rw-r--r--SECURITY.md11
-rw-r--r--WORKSPACE84
-rw-r--r--cloudbuild/go.Dockerfile2
-rw-r--r--cloudbuild/go.yaml22
-rw-r--r--kokoro/build.cfg23
-rw-r--r--kokoro/build_tests.cfg1
-rw-r--r--kokoro/common.cfg2
-rw-r--r--kokoro/continuous.cfg13
-rw-r--r--kokoro/do_tests.cfg9
-rw-r--r--kokoro/docker_tests.cfg10
-rw-r--r--kokoro/go.cfg20
-rw-r--r--kokoro/go_tests.cfg1
-rw-r--r--kokoro/hostnet_tests.cfg10
-rw-r--r--kokoro/kvm_tests.cfg10
-rw-r--r--kokoro/make_tests.cfg9
-rw-r--r--kokoro/overlay_tests.cfg10
-rw-r--r--kokoro/presubmit.cfg13
-rw-r--r--kokoro/release-nightly.cfg11
-rw-r--r--kokoro/release.cfg1
-rw-r--r--kokoro/root_tests.cfg10
l---------kokoro/run_build.sh1
l---------kokoro/run_tests.sh1
-rw-r--r--kokoro/simple_tests.cfg9
-rw-r--r--kokoro/syscall_tests.cfg9
-rwxr-xr-xkokoro/ubuntu1604/10_core.sh30
-rwxr-xr-xkokoro/ubuntu1604/20_bazel.sh28
-rwxr-xr-xkokoro/ubuntu1604/25_docker.sh35
-rwxr-xr-xkokoro/ubuntu1604/30_containerd.sh76
-rwxr-xr-xkokoro/ubuntu1604/40_kokoro.sh54
-rwxr-xr-xkokoro/ubuntu1604/build.sh20
l---------kokoro/ubuntu1804/10_core.sh1
l---------kokoro/ubuntu1804/20_bazel.sh1
l---------kokoro/ubuntu1804/25_docker.sh1
l---------kokoro/ubuntu1804/30_containerd.sh1
l---------kokoro/ubuntu1804/40_kokoro.sh1
-rwxr-xr-xkokoro/ubuntu1804/build.sh20
-rw-r--r--pkg/abi/linux/BUILD5
-rw-r--r--pkg/abi/linux/file.go2
-rw-r--r--pkg/abi/linux/signalfd.go45
-rw-r--r--pkg/amutex/BUILD3
-rw-r--r--pkg/atomicbitops/BUILD4
-rw-r--r--pkg/atomicbitops/atomic_bitops.go5
-rw-r--r--pkg/atomicbitops/atomic_bitops_arm64.s139
-rw-r--r--pkg/atomicbitops/atomic_bitops_common.go2
-rw-r--r--pkg/binary/BUILD3
-rw-r--r--pkg/bits/BUILD6
-rw-r--r--pkg/bits/uint64_arch.go (renamed from pkg/bits/uint64_arch_amd64.go)2
-rw-r--r--pkg/bits/uint64_arch_arm64_asm.s33
-rw-r--r--pkg/bits/uint64_arch_generic.go2
-rw-r--r--pkg/bpf/BUILD4
-rw-r--r--pkg/compressio/BUILD3
-rw-r--r--pkg/cpuid/BUILD4
-rw-r--r--pkg/cpuid/cpuid.go203
-rw-r--r--pkg/cpuid/cpuid_test.go24
-rw-r--r--pkg/eventchannel/BUILD3
-rw-r--r--pkg/fd/BUILD6
-rw-r--r--pkg/fd/fd.go8
-rw-r--r--pkg/fdchannel/BUILD3
-rw-r--r--pkg/flipcall/BUILD4
-rw-r--r--pkg/flipcall/ctrl_futex.go32
-rw-r--r--pkg/flipcall/flipcall.go72
-rw-r--r--pkg/flipcall/flipcall_example_test.go4
-rw-r--r--pkg/flipcall/flipcall_test.go204
-rw-r--r--pkg/flipcall/flipcall_unsafe.go28
-rw-r--r--pkg/flipcall/futex_linux.go23
-rw-r--r--pkg/fspath/BUILD3
-rw-r--r--pkg/gate/BUILD3
-rw-r--r--pkg/ilist/BUILD3
-rw-r--r--pkg/linewriter/BUILD3
-rw-r--r--pkg/log/BUILD3
-rw-r--r--pkg/metric/BUILD10
-rw-r--r--pkg/p9/BUILD6
-rw-r--r--pkg/p9/client.go288
-rw-r--r--pkg/p9/client_test.go50
-rw-r--r--pkg/p9/handlers.go53
-rw-r--r--pkg/p9/messages.go103
-rw-r--r--pkg/p9/p9.go2
-rw-r--r--pkg/p9/p9test/BUILD6
-rw-r--r--pkg/p9/p9test/client_test.go95
-rw-r--r--pkg/p9/p9test/p9test.go4
-rw-r--r--pkg/p9/server.go180
-rw-r--r--pkg/p9/transport.go5
-rw-r--r--pkg/p9/transport_flipcall.go243
-rw-r--r--pkg/p9/transport_test.go4
-rw-r--r--pkg/p9/version.go9
-rw-r--r--pkg/procid/BUILD3
-rw-r--r--pkg/refs/BUILD4
-rw-r--r--pkg/refs/refcounter.go6
-rw-r--r--pkg/seccomp/BUILD4
-rw-r--r--pkg/seccomp/seccomp_unsafe.go2
-rw-r--r--pkg/secio/BUILD3
-rw-r--r--pkg/segment/test/BUILD3
-rw-r--r--pkg/sentry/BUILD2
-rw-r--r--pkg/sentry/arch/BUILD7
-rw-r--r--pkg/sentry/control/BUILD3
-rw-r--r--pkg/sentry/device/BUILD4
-rw-r--r--pkg/sentry/fs/BUILD4
-rw-r--r--pkg/sentry/fs/dirent.go4
-rw-r--r--pkg/sentry/fs/dirent_refs_test.go2
-rw-r--r--pkg/sentry/fs/fdpipe/BUILD4
-rw-r--r--pkg/sentry/fs/file.go23
-rw-r--r--pkg/sentry/fs/file_operations.go9
-rw-r--r--pkg/sentry/fs/file_overlay.go9
-rw-r--r--pkg/sentry/fs/fsutil/BUILD4
-rw-r--r--pkg/sentry/fs/fsutil/file.go6
-rw-r--r--pkg/sentry/fs/fsutil/host_mappable.go2
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go74
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached_test.go14
-rw-r--r--pkg/sentry/fs/gofer/BUILD4
-rw-r--r--pkg/sentry/fs/gofer/fs.go22
-rw-r--r--pkg/sentry/fs/gofer/inode.go15
-rw-r--r--pkg/sentry/fs/gofer/session.go27
-rw-r--r--pkg/sentry/fs/host/BUILD4
-rw-r--r--pkg/sentry/fs/host/inode.go10
-rw-r--r--pkg/sentry/fs/host/tty.go15
-rw-r--r--pkg/sentry/fs/inode_overlay.go6
-rw-r--r--pkg/sentry/fs/inotify.go5
-rw-r--r--pkg/sentry/fs/lock/BUILD4
-rw-r--r--pkg/sentry/fs/mounts.go3
-rw-r--r--pkg/sentry/fs/proc/BUILD5
-rw-r--r--pkg/sentry/fs/proc/net.go276
-rw-r--r--pkg/sentry/fs/proc/proc.go2
-rw-r--r--pkg/sentry/fs/proc/seqfile/BUILD4
-rw-r--r--pkg/sentry/fs/ramfs/BUILD4
-rw-r--r--pkg/sentry/fs/splice.go162
-rw-r--r--pkg/sentry/fs/timerfd/timerfd.go3
-rw-r--r--pkg/sentry/fs/tmpfs/BUILD4
-rw-r--r--pkg/sentry/fs/tmpfs/tmpfs.go2
-rw-r--r--pkg/sentry/fs/tty/BUILD4
-rw-r--r--pkg/sentry/fsimpl/ext/BUILD6
-rw-r--r--pkg/sentry/fsimpl/ext/benchmark/BUILD2
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go12
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/BUILD4
-rw-r--r--pkg/sentry/fsimpl/ext/ext_test.go4
-rw-r--r--pkg/sentry/fsimpl/ext/file_description.go3
-rw-r--r--pkg/sentry/fsimpl/ext/regular_file.go4
-rw-r--r--pkg/sentry/fsimpl/ext/symlink.go4
-rw-r--r--pkg/sentry/fsimpl/memfs/BUILD3
-rw-r--r--pkg/sentry/fsimpl/memfs/directory.go26
-rw-r--r--pkg/sentry/fsimpl/memfs/memfs.go2
-rw-r--r--pkg/sentry/fsimpl/memfs/regular_file.go2
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD3
-rw-r--r--pkg/sentry/hostcpu/BUILD4
-rw-r--r--pkg/sentry/hostcpu/getcpu_arm64.s28
-rw-r--r--pkg/sentry/kernel/BUILD10
-rw-r--r--pkg/sentry/kernel/epoll/BUILD4
-rw-r--r--pkg/sentry/kernel/eventfd/BUILD4
-rw-r--r--pkg/sentry/kernel/futex/BUILD4
-rw-r--r--pkg/sentry/kernel/kernel.go129
-rw-r--r--pkg/sentry/kernel/memevent/BUILD7
-rw-r--r--pkg/sentry/kernel/pipe/BUILD4
-rw-r--r--pkg/sentry/kernel/pipe/buffer.go25
-rw-r--r--pkg/sentry/kernel/pipe/node_test.go4
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go86
-rw-r--r--pkg/sentry/kernel/pipe/reader_writer.go76
-rw-r--r--pkg/sentry/kernel/posixtimer.go8
-rw-r--r--pkg/sentry/kernel/sched/BUILD3
-rw-r--r--pkg/sentry/kernel/semaphore/BUILD4
-rw-r--r--pkg/sentry/kernel/sessions.go8
-rw-r--r--pkg/sentry/kernel/signalfd/BUILD22
-rw-r--r--pkg/sentry/kernel/signalfd/signalfd.go140
-rw-r--r--pkg/sentry/kernel/task.go8
-rw-r--r--pkg/sentry/kernel/task_block.go12
-rw-r--r--pkg/sentry/kernel/task_identity.go4
-rw-r--r--pkg/sentry/kernel/task_sched.go33
-rw-r--r--pkg/sentry/kernel/task_signals.go18
-rw-r--r--pkg/sentry/kernel/thread_group.go3
-rw-r--r--pkg/sentry/kernel/time/time.go27
-rw-r--r--pkg/sentry/limits/BUILD4
-rw-r--r--pkg/sentry/loader/elf.go20
-rw-r--r--pkg/sentry/loader/loader.go3
-rw-r--r--pkg/sentry/memmap/BUILD4
-rw-r--r--pkg/sentry/mm/BUILD4
-rw-r--r--pkg/sentry/pgalloc/BUILD4
-rw-r--r--pkg/sentry/platform/interrupt/BUILD3
-rw-r--r--pkg/sentry/platform/kvm/BUILD4
-rw-r--r--pkg/sentry/platform/kvm/testutil/BUILD2
-rw-r--r--pkg/sentry/platform/kvm/testutil/testutil.go3
-rw-r--r--pkg/sentry/platform/kvm/testutil/testutil_amd64.go3
-rw-r--r--pkg/sentry/platform/kvm/testutil/testutil_arm64.go59
-rw-r--r--pkg/sentry/platform/kvm/testutil/testutil_arm64.s91
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_unsafe.go16
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go20
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux.go4
-rw-r--r--pkg/sentry/platform/ring0/pagetables/BUILD3
-rw-r--r--pkg/sentry/platform/safecopy/BUILD3
-rw-r--r--pkg/sentry/safemem/BUILD3
-rw-r--r--pkg/sentry/sighandling/sighandling_unsafe.go2
-rw-r--r--pkg/sentry/socket/netlink/BUILD1
-rw-r--r--pkg/sentry/socket/netlink/port/BUILD4
-rw-r--r--pkg/sentry/socket/netlink/socket.go22
-rw-r--r--pkg/sentry/socket/netstack/BUILD (renamed from pkg/sentry/socket/epsocket/BUILD)6
-rw-r--r--pkg/sentry/socket/netstack/device.go (renamed from pkg/sentry/socket/epsocket/device.go)6
-rw-r--r--pkg/sentry/socket/netstack/netstack.go (renamed from pkg/sentry/socket/epsocket/epsocket.go)360
-rw-r--r--pkg/sentry/socket/netstack/provider.go (renamed from pkg/sentry/socket/epsocket/provider.go)4
-rw-r--r--pkg/sentry/socket/netstack/save_restore.go (renamed from pkg/sentry/socket/epsocket/save_restore.go)2
-rw-r--r--pkg/sentry/socket/netstack/stack.go (renamed from pkg/sentry/socket/epsocket/stack.go)2
-rw-r--r--pkg/sentry/socket/rpcinet/BUILD9
-rw-r--r--pkg/sentry/socket/unix/BUILD2
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go3
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go82
-rw-r--r--pkg/sentry/socket/unix/unix.go24
-rw-r--r--pkg/sentry/strace/BUILD9
-rw-r--r--pkg/sentry/strace/linux64.go1
-rw-r--r--pkg/sentry/strace/socket.go4
-rw-r--r--pkg/sentry/syscalls/linux/BUILD3
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go379
-rw-r--r--pkg/sentry/syscalls/linux/linux64_amd64.go386
-rw-r--r--pkg/sentry/syscalls/linux/linux64_arm64.go313
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go5
-rw-r--r--pkg/sentry/syscalls/linux/sys_read.go33
-rw-r--r--pkg/sentry/syscalls/linux/sys_signal.go77
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go129
-rw-r--r--pkg/sentry/syscalls/linux/sys_time.go39
-rw-r--r--pkg/sentry/syscalls/linux/sys_utsname.go6
-rw-r--r--pkg/sentry/time/BUILD3
-rw-r--r--pkg/sentry/unimpl/BUILD7
-rw-r--r--pkg/sentry/usage/memory.go5
-rw-r--r--pkg/sentry/usermem/BUILD4
-rw-r--r--pkg/sentry/usermem/usermem.go8
-rw-r--r--pkg/sentry/vfs/BUILD3
-rw-r--r--pkg/sentry/vfs/file_description.go11
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go4
-rw-r--r--pkg/sentry/vfs/options.go6
-rw-r--r--pkg/sleep/BUILD4
-rw-r--r--pkg/sleep/commit_arm64.s38
-rw-r--r--pkg/sleep/commit_asm.go2
-rw-r--r--pkg/sleep/commit_noasm.go2
-rw-r--r--pkg/state/BUILD3
-rw-r--r--pkg/state/statefile/BUILD3
-rw-r--r--pkg/syserror/BUILD3
-rw-r--r--pkg/tcpip/BUILD5
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD3
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go5
-rw-r--r--pkg/tcpip/buffer/BUILD5
-rw-r--r--pkg/tcpip/checker/checker.go100
-rw-r--r--pkg/tcpip/hash/jenkins/BUILD3
-rw-r--r--pkg/tcpip/header/BUILD5
-rw-r--r--pkg/tcpip/header/icmpv4.go71
-rw-r--r--pkg/tcpip/header/icmpv6.go88
-rw-r--r--pkg/tcpip/header/ipv4.go35
-rw-r--r--pkg/tcpip/header/ipv6.go55
-rw-r--r--pkg/tcpip/header/udp.go5
-rw-r--r--pkg/tcpip/iptables/BUILD4
-rw-r--r--pkg/tcpip/link/channel/channel.go9
-rw-r--r--pkg/tcpip/link/fdbased/BUILD7
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go63
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go3
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go179
-rw-r--r--pkg/tcpip/link/fdbased/mmap_amd64.go194
-rw-r--r--pkg/tcpip/link/fdbased/mmap_stub.go (renamed from test/runtimes/runtimes.go)15
-rw-r--r--pkg/tcpip/link/fdbased/mmap_unsafe.go (renamed from pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go)2
-rw-r--r--pkg/tcpip/link/loopback/loopback.go7
-rw-r--r--pkg/tcpip/link/muxed/BUILD3
-rw-r--r--pkg/tcpip/link/muxed/injectable.go12
-rw-r--r--pkg/tcpip/link/muxed/injectable_test.go4
-rw-r--r--pkg/tcpip/link/rawfile/BUILD5
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_arm64.s42
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go (renamed from pkg/tcpip/link/rawfile/blockingpoll_unsafe.go)4
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go (renamed from pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go)8
-rw-r--r--pkg/tcpip/link/sharedmem/BUILD3
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/BUILD3
-rw-r--r--pkg/tcpip/link/sharedmem/queue/BUILD3
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go11
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go4
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go21
-rw-r--r--pkg/tcpip/link/waitable/BUILD3
-rw-r--r--pkg/tcpip/link/waitable/waitable.go10
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go9
-rw-r--r--pkg/tcpip/network/BUILD2
-rw-r--r--pkg/tcpip/network/arp/BUILD3
-rw-r--r--pkg/tcpip/network/arp/arp.go18
-rw-r--r--pkg/tcpip/network/arp/arp_test.go15
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD8
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go16
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go10
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go10
-rw-r--r--pkg/tcpip/network/ip_test.go29
-rw-r--r--pkg/tcpip/network/ipv4/BUILD3
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go6
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go109
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go152
-rw-r--r--pkg/tcpip/network/ipv6/BUILD10
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go85
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go59
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go74
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go266
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go181
-rw-r--r--pkg/tcpip/ports/BUILD3
-rw-r--r--pkg/tcpip/ports/ports.go161
-rw-r--r--pkg/tcpip/ports/ports_test.go169
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go9
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go9
-rw-r--r--pkg/tcpip/seqnum/BUILD4
-rw-r--r--pkg/tcpip/stack/BUILD32
-rw-r--r--pkg/tcpip/stack/icmp_rate_limit.go41
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go253
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go79
-rw-r--r--pkg/tcpip/stack/nic.go460
-rw-r--r--pkg/tcpip/stack/registration.go90
-rw-r--r--pkg/tcpip/stack/route.go32
-rw-r--r--pkg/tcpip/stack/stack.go294
-rw-r--r--pkg/tcpip/stack/stack_test.go1036
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go268
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go352
-rw-r--r--pkg/tcpip/stack/transport_test.go102
-rw-r--r--pkg/tcpip/tcpip.go315
-rw-r--r--pkg/tcpip/tcpip_test.go31
-rw-r--r--pkg/tcpip/transport/icmp/BUILD4
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go221
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go10
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go38
-rw-r--r--pkg/tcpip/transport/raw/BUILD4
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go369
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go6
-rw-r--r--pkg/tcpip/transport/raw/protocol.go9
-rw-r--r--pkg/tcpip/transport/tcp/BUILD8
-rw-r--r--pkg/tcpip/transport/tcp/accept.go48
-rw-r--r--pkg/tcpip/transport/tcp/connect.go74
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go56
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go586
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go28
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go26
-rw-r--r--pkg/tcpip/transport/tcp/snd.go12
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go18
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go8
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go576
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go52
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD3
-rw-r--r--pkg/tcpip/transport/udp/BUILD8
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go373
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go18
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go8
-rw-r--r--pkg/tcpip/transport/udp/protocol.go113
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go632
-rw-r--r--pkg/tmutex/BUILD3
-rw-r--r--pkg/unet/BUILD3
-rw-r--r--pkg/urpc/BUILD3
-rw-r--r--pkg/waiter/BUILD7
-rw-r--r--runsc/BUILD29
-rw-r--r--runsc/boot/BUILD4
-rw-r--r--runsc/boot/config.go63
-rw-r--r--runsc/boot/controller.go14
-rw-r--r--runsc/boot/filter/config.go14
-rw-r--r--runsc/boot/fs.go256
-rw-r--r--runsc/boot/loader.go127
-rw-r--r--runsc/boot/loader_test.go17
-rw-r--r--runsc/boot/network.go18
-rw-r--r--runsc/boot/user.go28
-rw-r--r--runsc/boot/user_test.go15
-rw-r--r--runsc/cgroup/BUILD4
-rw-r--r--runsc/cmd/BUILD3
-rw-r--r--runsc/cmd/capability_test.go4
-rw-r--r--runsc/cmd/exec.go53
-rw-r--r--runsc/cmd/exec_test.go4
-rw-r--r--runsc/cmd/gofer.go7
-rw-r--r--runsc/cmd/install.go210
-rw-r--r--runsc/container/BUILD3
-rw-r--r--runsc/container/console_test.go2
-rw-r--r--runsc/container/container.go113
-rw-r--r--runsc/container/container_test.go74
-rw-r--r--runsc/container/multi_container_test.go118
-rw-r--r--runsc/container/shared_volume_test.go2
-rw-r--r--runsc/container/test_app/BUILD2
-rw-r--r--runsc/container/test_app/fds.go4
-rw-r--r--runsc/container/test_app/test_app.go67
-rw-r--r--runsc/criutil/BUILD12
-rw-r--r--runsc/criutil/criutil.go (renamed from runsc/test/testutil/crictl.go)80
-rwxr-xr-xrunsc/debian/postinst.sh6
-rw-r--r--runsc/dockerutil/BUILD15
-rw-r--r--runsc/dockerutil/dockerutil.go (renamed from runsc/test/testutil/docker.go)115
-rw-r--r--runsc/fsgofer/filter/BUILD1
-rw-r--r--runsc/fsgofer/filter/config.go50
-rw-r--r--runsc/fsgofer/filter/filter.go13
-rw-r--r--runsc/fsgofer/fsgofer.go55
-rw-r--r--runsc/fsgofer/fsgofer_test.go12
-rw-r--r--runsc/main.go39
-rw-r--r--runsc/sandbox/sandbox.go10
-rw-r--r--runsc/specutils/BUILD1
-rw-r--r--runsc/specutils/namespace.go14
-rw-r--r--runsc/specutils/specutils.go67
-rw-r--r--runsc/test/BUILD0
-rw-r--r--runsc/test/README.md24
-rw-r--r--runsc/test/build_defs.bzl19
-rwxr-xr-xrunsc/test/install.sh93
-rw-r--r--runsc/test/integration/exec_test.go161
-rw-r--r--runsc/testutil/BUILD (renamed from runsc/test/testutil/BUILD)11
-rw-r--r--runsc/testutil/testutil.go (renamed from runsc/test/testutil/testutil.go)96
-rw-r--r--runsc/tools/dockercfg/BUILD10
-rw-r--r--runsc/tools/dockercfg/dockercfg.go193
-rw-r--r--runsc/version.go2
-rwxr-xr-xrunsc/version_test.sh36
-rwxr-xr-xscripts/build.sh79
-rwxr-xr-xscripts/common.sh80
-rwxr-xr-xscripts/common_bazel.sh99
-rwxr-xr-xscripts/dev.sh73
-rwxr-xr-xscripts/do_tests.sh27
-rwxr-xr-xscripts/docker_tests.sh20
-rwxr-xr-xscripts/go.sh43
-rwxr-xr-xscripts/hostnet_tests.sh21
-rwxr-xr-xscripts/kvm_tests.sh28
-rwxr-xr-xscripts/make_tests.sh25
-rwxr-xr-xscripts/overlay_tests.sh21
-rwxr-xr-xscripts/release.sh38
-rwxr-xr-xscripts/root_tests.sh31
-rwxr-xr-xscripts/simple_tests.sh20
-rwxr-xr-xscripts/syscall_tests.sh20
-rw-r--r--test/README.md18
-rw-r--r--test/e2e/BUILD (renamed from runsc/test/integration/BUILD)13
-rw-r--r--test/e2e/exec_test.go275
-rw-r--r--test/e2e/integration.go (renamed from runsc/test/integration/integration.go)0
-rw-r--r--test/e2e/integration_test.go (renamed from runsc/test/integration/integration_test.go)46
-rw-r--r--test/e2e/regression_test.go (renamed from runsc/test/integration/regression_test.go)6
-rw-r--r--test/image/BUILD (renamed from runsc/test/image/BUILD)13
-rw-r--r--test/image/image.go (renamed from runsc/test/image/image.go)0
-rw-r--r--test/image/image_test.go (renamed from runsc/test/image/image_test.go)65
-rw-r--r--test/image/latin10k.txt (renamed from runsc/test/image/latin10k.txt)0
-rw-r--r--test/image/mysql.sql (renamed from runsc/test/image/mysql.sql)0
-rw-r--r--test/image/ruby.rb (renamed from runsc/test/image/ruby.rb)0
-rw-r--r--test/image/ruby.sh (renamed from runsc/test/image/ruby.sh)0
-rw-r--r--test/root/BUILD (renamed from runsc/test/root/BUILD)17
-rw-r--r--test/root/cgroup_test.go (renamed from runsc/test/root/cgroup_test.go)29
-rw-r--r--test/root/chroot_test.go (renamed from runsc/test/root/chroot_test.go)27
-rw-r--r--test/root/crictl_test.go (renamed from runsc/test/root/crictl_test.go)97
-rw-r--r--test/root/main_test.go49
-rw-r--r--test/root/oom_score_adj_test.go376
-rw-r--r--test/root/root.go21
-rw-r--r--test/root/testdata/BUILD (renamed from runsc/test/root/testdata/BUILD)3
-rw-r--r--test/root/testdata/busybox.go (renamed from runsc/test/root/testdata/busybox.go)0
-rw-r--r--test/root/testdata/containerd_config.go (renamed from runsc/test/root/testdata/containerd_config.go)0
-rw-r--r--test/root/testdata/httpd.go (renamed from runsc/test/root/testdata/httpd.go)0
-rw-r--r--test/root/testdata/httpd_mount_paths.go (renamed from runsc/test/root/testdata/httpd_mount_paths.go)0
-rw-r--r--test/root/testdata/sandbox.go (renamed from runsc/test/root/testdata/sandbox.go)0
-rw-r--r--test/root/testdata/simple.go41
-rw-r--r--test/runtimes/BUILD58
-rw-r--r--test/runtimes/README.md5
-rw-r--r--test/runtimes/blacklist_go1.12.csv16
-rw-r--r--test/runtimes/blacklist_java11.csv126
-rw-r--r--test/runtimes/blacklist_nodejs12.4.0.csv47
-rw-r--r--test/runtimes/blacklist_php7.3.6.csv29
-rw-r--r--test/runtimes/blacklist_python3.7.3.csv27
-rw-r--r--test/runtimes/blacklist_test.go37
-rw-r--r--test/runtimes/build_defs.bzl57
-rw-r--r--test/runtimes/common/BUILD20
-rw-r--r--test/runtimes/common/common.go114
-rw-r--r--test/runtimes/go/BUILD9
-rw-r--r--test/runtimes/go/Dockerfile34
-rw-r--r--test/runtimes/images/Dockerfile_go1.1210
-rw-r--r--test/runtimes/images/Dockerfile_java1130
-rw-r--r--test/runtimes/images/Dockerfile_nodejs12.4.028
-rw-r--r--test/runtimes/images/Dockerfile_php7.3.627
-rw-r--r--test/runtimes/images/Dockerfile_python3.7.330
-rw-r--r--test/runtimes/images/proctor/BUILD26
-rw-r--r--test/runtimes/images/proctor/go.go (renamed from test/runtimes/go/proctor-go.go)67
-rw-r--r--test/runtimes/images/proctor/java.go (renamed from test/runtimes/java/proctor-java.go)49
-rw-r--r--test/runtimes/images/proctor/nodejs.go46
-rw-r--r--test/runtimes/images/proctor/php.go (renamed from test/runtimes/php/proctor-php.go)36
-rw-r--r--test/runtimes/images/proctor/proctor.go154
-rw-r--r--test/runtimes/images/proctor/proctor_test.go (renamed from test/runtimes/common/common_test.go)13
-rw-r--r--test/runtimes/images/proctor/python.go (renamed from test/runtimes/python/proctor-python.go)34
-rw-r--r--test/runtimes/java/BUILD9
-rw-r--r--test/runtimes/java/Dockerfile35
-rw-r--r--test/runtimes/nodejs/BUILD9
-rw-r--r--test/runtimes/nodejs/Dockerfile30
-rw-r--r--test/runtimes/nodejs/proctor-nodejs.go60
-rw-r--r--test/runtimes/php/BUILD9
-rw-r--r--test/runtimes/php/Dockerfile30
-rw-r--r--test/runtimes/python/BUILD9
-rw-r--r--test/runtimes/python/Dockerfile32
-rw-r--r--test/runtimes/runner.go199
-rwxr-xr-xtest/runtimes/runner.sh35
-rw-r--r--test/runtimes/runtimes_test.go93
-rw-r--r--test/syscalls/BUILD18
-rw-r--r--test/syscalls/build_defs.bzl6
-rw-r--r--test/syscalls/linux/BUILD189
-rw-r--r--test/syscalls/linux/affinity.cc1
-rw-r--r--test/syscalls/linux/aio.cc155
-rw-r--r--test/syscalls/linux/base_poll_test.h2
-rw-r--r--test/syscalls/linux/chown.cc30
-rw-r--r--test/syscalls/linux/clock_nanosleep.cc86
-rw-r--r--test/syscalls/linux/exec.cc50
-rw-r--r--test/syscalls/linux/exec_binary.cc70
-rw-r--r--test/syscalls/linux/fcntl.cc47
-rw-r--r--test/syscalls/linux/futex.cc16
-rw-r--r--test/syscalls/linux/ip_socket_test_util.cc26
-rw-r--r--test/syscalls/linux/ip_socket_test_util.h33
-rw-r--r--test/syscalls/linux/itimer.cc4
-rw-r--r--test/syscalls/linux/kill.cc13
-rw-r--r--test/syscalls/linux/link.cc6
-rw-r--r--test/syscalls/linux/mlock.cc38
-rw-r--r--test/syscalls/linux/mremap.cc11
-rw-r--r--test/syscalls/linux/open.cc1
-rw-r--r--test/syscalls/linux/packet_socket.cc29
-rw-r--r--test/syscalls/linux/packet_socket_raw.cc21
-rw-r--r--test/syscalls/linux/pipe.cc14
-rw-r--r--test/syscalls/linux/prctl.cc9
-rw-r--r--test/syscalls/linux/prctl_setuid.cc24
-rw-r--r--test/syscalls/linux/proc.cc13
-rw-r--r--test/syscalls/linux/proc_net.cc5
-rw-r--r--test/syscalls/linux/proc_net_tcp.cc306
-rw-r--r--test/syscalls/linux/proc_net_udp.cc309
-rw-r--r--test/syscalls/linux/proc_net_unix.cc54
-rw-r--r--test/syscalls/linux/ptrace.cc11
-rw-r--r--test/syscalls/linux/pty.cc87
-rw-r--r--test/syscalls/linux/pwritev2.cc16
-rw-r--r--test/syscalls/linux/raw_socket_hdrincl.cc43
-rw-r--r--test/syscalls/linux/raw_socket_icmp.cc13
-rw-r--r--test/syscalls/linux/raw_socket_ipv4.cc13
-rw-r--r--test/syscalls/linux/readahead.cc91
-rw-r--r--test/syscalls/linux/semaphore.cc2
-rw-r--r--test/syscalls/linux/sendfile.cc110
-rw-r--r--test/syscalls/linux/signalfd.cc350
-rw-r--r--test/syscalls/linux/sigstop.cc7
-rw-r--r--test/syscalls/linux/socket.cc43
-rw-r--r--test/syscalls/linux/socket_bind_to_device.cc314
-rw-r--r--test/syscalls/linux/socket_bind_to_device_distribution.cc381
-rw-r--r--test/syscalls/linux/socket_bind_to_device_sequence.cc316
-rw-r--r--test/syscalls/linux/socket_bind_to_device_util.cc75
-rw-r--r--test/syscalls/linux/socket_bind_to_device_util.h67
-rw-r--r--test/syscalls/linux/socket_ip_tcp_generic.cc2
-rw-r--r--test/syscalls/linux/socket_ip_unbound.cc379
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc408
-rw-r--r--test/syscalls/linux/socket_netlink_route.cc153
-rw-r--r--test/syscalls/linux/socket_netlink_util.cc7
-rw-r--r--test/syscalls/linux/socket_test_util.cc10
-rw-r--r--test/syscalls/linux/socket_test_util.h7
-rw-r--r--test/syscalls/linux/socket_test_util_impl.cc28
-rw-r--r--test/syscalls/linux/socket_unix_stream.cc12
-rw-r--r--test/syscalls/linux/splice.cc194
-rw-r--r--test/syscalls/linux/sticky.cc31
-rw-r--r--test/syscalls/linux/tcp_socket.cc21
-rw-r--r--test/syscalls/linux/timers.cc7
-rw-r--r--test/syscalls/linux/uidgid.cc48
-rw-r--r--test/syscalls/linux/uname.cc14
-rw-r--r--test/syscalls/linux/unlink.cc2
-rw-r--r--test/syscalls/linux/vfork.cc7
-rw-r--r--test/syscalls/syscall_test_runner.go34
-rw-r--r--test/util/BUILD23
-rw-r--r--test/util/fs_util.cc9
-rw-r--r--test/util/fs_util.h3
-rw-r--r--test/util/memory_util.h12
-rw-r--r--test/util/proc_util.cc2
-rw-r--r--test/util/save_util.cc12
-rw-r--r--test/util/save_util.h5
-rw-r--r--test/util/save_util_linux.cc33
-rw-r--r--test/util/save_util_other.cc (renamed from runsc/test/testutil/testutil_race.go)14
-rw-r--r--test/util/test_util.cc4
-rw-r--r--test/util/test_util.h1
-rw-r--r--test/util/thread_util.h18
-rw-r--r--test/util/uid_util.cc44
-rw-r--r--test/util/uid_util.h29
-rw-r--r--third_party/gvsync/downgradable_rwmutex_unsafe.go3
-rwxr-xr-xtools/go_branch.sh6
-rw-r--r--tools/go_marshal/BUILD14
-rw-r--r--tools/go_marshal/README.md164
-rw-r--r--tools/go_marshal/analysis/BUILD13
-rw-r--r--tools/go_marshal/analysis/analysis_unsafe.go175
-rw-r--r--tools/go_marshal/defs.bzl152
-rw-r--r--tools/go_marshal/gomarshal/BUILD17
-rw-r--r--tools/go_marshal/gomarshal/generator.go382
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go507
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go154
-rw-r--r--tools/go_marshal/gomarshal/util.go387
-rw-r--r--tools/go_marshal/main.go73
-rw-r--r--tools/go_marshal/marshal/BUILD14
-rw-r--r--tools/go_marshal/marshal/marshal.go60
-rw-r--r--tools/go_marshal/test/BUILD31
-rw-r--r--tools/go_marshal/test/benchmark_test.go178
-rw-r--r--tools/go_marshal/test/external/BUILD11
-rw-r--r--tools/go_marshal/test/external/external.go (renamed from runsc/test/root/root.go)13
-rw-r--r--tools/go_marshal/test/test.go105
-rw-r--r--tools/go_stateify/defs.bzl65
-rwxr-xr-xtools/image_build.sh98
-rwxr-xr-xtools/make_repository.sh79
-rwxr-xr-xtools/run_build.sh49
-rwxr-xr-xtools/run_tests.sh302
-rwxr-xr-xtools/workspace_status.sh3
583 files changed, 22953 insertions, 6476 deletions
diff --git a/.bazelrc b/.bazelrc
index eda884473..379fc8328 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -13,7 +13,7 @@
# limitations under the License.
# Display the current git revision in the info block.
-build --workspace_status_command tools/workspace_status.sh
+build --stamp --workspace_status_command tools/workspace_status.sh
# Enable remote execution so actions are performed on the remote systems.
build:remote --remote_executor=grpcs://remotebuildexecution.googleapis.com
diff --git a/BUILD b/BUILD
index 60ed992c4..de410b008 100644
--- a/BUILD
+++ b/BUILD
@@ -21,6 +21,9 @@ go_path(
mode = "link",
deps = [
"//runsc",
+
+ # Packages that are not dependencies of //runsc.
+ "//pkg/tcpip/link/channel",
],
)
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 638942a42..5d46168bc 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -83,6 +83,8 @@ Rules:
### Code reviews
+Before sending code reviews, run `bazel test ...` to ensure tests are passing.
+
Code changes are accepted via [pull request][github].
When approved, the change will be submitted by a team member and automatically
@@ -100,6 +102,36 @@ form `b/1234`. These correspond to bugs in our internal bug tracker. Eventually
these bugs will be moved to the GitHub Issues, but until then they can simply be
ignored.
+### Build and test with Docker
+
+`scripts/dev.sh` is a convenient script that builds and installs `runsc` as a
+new Docker runtime for you. The script tries to extract the runtime name from
+your local environment and will print it at the end. You can also customize it.
+The script creates one regular runtime and another with debug flags enabled.
+Here are a few examples:
+
+```bash
+# Default case (inside branch my-branch)
+$ scripts/dev.sh
+...
+Runtimes my-branch and my-branch-d (debug enabled) setup.
+Use --runtime=my-branch with your Docker command.
+ docker run --rm --runtime=my-branch --rm hello-world
+
+If you rebuild, use scripts/dev.sh --refresh.
+Logs are in: /tmp/my-branch/logs
+
+# --refresh just updates the runtime binary and doesn't restart docker.
+$ git/my_branch> scripts/dev.sh --refresh
+
+# Using a custom runtime name
+$ git/my_branch> scripts/dev.sh my-runtime
+...
+Runtimes my-runtime and my-runtime-d (debug enabled) setup.
+Use --runtime=my-runtime with your Docker command.
+ docker run --rm --runtime=my-runtime --rm hello-world
+```
+
### The small print
Contributions made by corporations are covered by a different agreement than the
diff --git a/Makefile b/Makefile
index 561618478..1735c07df 100644
--- a/Makefile
+++ b/Makefile
@@ -22,7 +22,7 @@ bazel-server-start: docker-build
--privileged \
gvisor-bazel \
sh -c "while :; do sleep 100; done" && \
- docker exec --user 0:0 -i gvisor-bazel sh -c "groupadd --gid $(GID) gvisor && useradd --uid $(UID) --gid $(GID) -d $(HOME) gvisor"
+ docker exec --user 0:0 -i gvisor-bazel sh -c "groupadd --gid $(GID) --non-unique gvisor && useradd --uid $(UID) --gid $(GID) -d $(HOME) gvisor"
bazel-server:
docker exec gvisor-bazel true || \
diff --git a/README.md b/README.md
index d102845ac..5ac6f9046 100644
--- a/README.md
+++ b/README.md
@@ -48,7 +48,7 @@ Make sure the following dependencies are installed:
* Linux 4.14.77+ ([older linux][old-linux])
* [git][git]
-* [Bazel][bazel] 0.23.0+
+* [Bazel][bazel] 0.28.0+
* [Python][python]
* [Docker version 17.09.0 or greater][docker]
* Gold linker (e.g. `binutils-gold` package on Ubuntu)
@@ -133,11 +133,9 @@ The [gvisor-users mailing list][gvisor-users-list] and
[gvisor-dev mailing list][gvisor-dev-list] are good starting points for
questions and discussion.
-## Security
+## Security Policy
-Sensitive security-related questions, comments and disclosures can be sent to
-the [gvisor-security mailing list][gvisor-security-list]. The full security
-disclosure policy is defined in the [community][community] repository.
+See [SECURITY.md](SECURITY.md).
## Contributing
@@ -147,7 +145,6 @@ See [Contributing.md](CONTRIBUTING.md).
[community]: https://gvisor.googlesource.com/community
[docker]: https://www.docker.com
[git]: https://git-scm.com
-[gvisor-security-list]: https://groups.google.com/forum/#!forum/gvisor-security
[gvisor-users-list]: https://groups.google.com/forum/#!forum/gvisor-users
[gvisor-dev-list]: https://groups.google.com/forum/#!forum/gvisor-dev
[oci]: https://www.opencontainers.org
diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 000000000..154d68cb3
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,11 @@
+# Security and Vulnerability Reporting
+
+Sensitive security-related questions, comments, and reports should be sent to
+the [gvisor-security mailing list][gvisor-security-list]. You should receive a
+prompt response, typically within 48 hours.
+
+Policies for security list access, vulnerability embargo, and vulnerability
+disclosure are outlined in the [community][community] repository.
+
+[community]: https://gvisor.googlesource.com/community
+[gvisor-security-list]: https://groups.google.com/forum/#!forum/gvisor-security
diff --git a/WORKSPACE b/WORKSPACE
index e5c5dfa2b..8f50a3e57 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -3,19 +3,19 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
http_archive(
name = "io_bazel_rules_go",
- sha256 = "313f2c7a23fecc33023563f082f381a32b9b7254f727a7dd2d6380ccc6dfe09b",
+ sha256 = "078f2a9569fa9ed846e60805fb5fb167d6f6c4ece48e6d409bf5fb2154eaf0d8",
urls = [
- "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/0.19.3/rules_go-0.19.3.tar.gz",
- "https://github.com/bazelbuild/rules_go/releases/download/0.19.3/rules_go-0.19.3.tar.gz",
+ "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/v0.20.0/rules_go-v0.20.0.tar.gz",
+ "https://github.com/bazelbuild/rules_go/releases/download/v0.20.0/rules_go-v0.20.0.tar.gz",
],
)
http_archive(
name = "bazel_gazelle",
- sha256 = "be9296bfd64882e3c08e3283c58fcb461fa6dd3c171764fcc4cf322f60615a9b",
+ sha256 = "41bff2a0b32b02f20c227d234aa25ef3783998e5453f7eade929704dcff7cd4b",
urls = [
- "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/0.18.1/bazel-gazelle-0.18.1.tar.gz",
- "https://github.com/bazelbuild/bazel-gazelle/releases/download/0.18.1/bazel-gazelle-0.18.1.tar.gz",
+ "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/v0.19.0/bazel-gazelle-v0.19.0.tar.gz",
+ "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.19.0/bazel-gazelle-v0.19.0.tar.gz",
],
)
@@ -24,7 +24,7 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_rules_dependencies", "go_register_to
go_rules_dependencies()
go_register_toolchains(
- go_version = "1.12.9",
+ go_version = "1.13.1",
nogo = "@//:nogo",
)
@@ -32,14 +32,26 @@ load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository")
gazelle_dependencies()
-# Load protobuf dependencies.
-load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
+# Load C++ rules.
+http_archive(
+ name = "rules_cc",
+ sha256 = "67412176974bfce3f4cf8bdaff39784a72ed709fc58def599d1f68710b58d68b",
+ strip_prefix = "rules_cc-b7fe9697c0c76ab2fd431a891dbb9a6a32ed7c3e",
+ urls = [
+ "https://mirror.bazel.build/github.com/bazelbuild/rules_cc/archive/b7fe9697c0c76ab2fd431a891dbb9a6a32ed7c3e.zip",
+ "https://github.com/bazelbuild/rules_cc/archive/b7fe9697c0c76ab2fd431a891dbb9a6a32ed7c3e.zip",
+ ],
+)
-git_repository(
+# Load protobuf dependencies.
+http_archive(
name = "com_google_protobuf",
- commit = "09745575a923640154bcf307fba8aedff47f240a",
- remote = "https://github.com/protocolbuffers/protobuf",
- shallow_since = "1558721209 -0700",
+ sha256 = "532d2575d8c0992065bb19ec5fba13aa3683499726f6055c11b474f91a00bb0c",
+ strip_prefix = "protobuf-7f520092d9050d96fb4b707ad11a51701af4ce49",
+ urls = [
+ "https://mirror.bazel.build/github.com/protocolbuffers/protobuf/archive/7f520092d9050d96fb4b707ad11a51701af4ce49.zip",
+ "https://github.com/protocolbuffers/protobuf/archive/7f520092d9050d96fb4b707ad11a51701af4ce49.zip",
+ ],
)
load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps")
@@ -50,11 +62,11 @@ protobuf_deps()
# See releases at https://releases.bazel.build/bazel-toolchains.html
http_archive(
name = "bazel_toolchains",
- sha256 = "e71eadcfcbdb47b4b740eb48b32ca4226e36aabc425d035a18dd40c2dda808c1",
- strip_prefix = "bazel-toolchains-0.28.4",
+ sha256 = "a019fbd579ce5aed0239de865b2d8281dbb809efd537bf42e0d366783e8dec65",
+ strip_prefix = "bazel-toolchains-0.29.2",
urls = [
- "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/0.28.4.tar.gz",
- "https://github.com/bazelbuild/bazel-toolchains/archive/0.28.4.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/0.29.2.tar.gz",
+ "https://github.com/bazelbuild/bazel-toolchains/archive/0.29.2.tar.gz",
],
)
@@ -63,6 +75,16 @@ load("@bazel_toolchains//rules:rbe_repo.bzl", "rbe_autoconfig")
rbe_autoconfig(name = "rbe_default")
+http_archive(
+ name = "rules_pkg",
+ sha256 = "5bdc04987af79bd27bc5b00fe30f59a858f77ffa0bd2d8143d5b31ad8b1bd71c",
+ url = "https://github.com/bazelbuild/rules_pkg/releases/download/0.2.0/rules_pkg-0.2.0.tar.gz",
+)
+
+load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies")
+
+rules_pkg_dependencies()
+
# External repositories, in sorted order.
go_repository(
name = "com_github_cenkalti_backoff",
@@ -184,7 +206,7 @@ go_repository(
go_repository(
name = "org_golang_x_time",
- commit = "9d24e82272b4f38b78bc8cff74fa936d31ccd8ef",
+ commit = "c4c64cad1fd0a1a8dab2523e04e61d35308e131e",
importpath = "golang.org/x/time",
)
@@ -210,31 +232,21 @@ go_repository(
# System Call test dependencies.
http_archive(
- name = "com_github_gflags_gflags",
- sha256 = "34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf",
- strip_prefix = "gflags-2.2.2",
- urls = [
- "https://mirror.bazel.build/github.com/gflags/gflags/archive/v2.2.2.tar.gz",
- "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz",
- ],
-)
-
-http_archive(
name = "com_google_absl",
- sha256 = "01ba1185a0e6e048e4890f39e383515195bc335f0627cdddc0c325ee68be4434",
- strip_prefix = "abseil-cpp-cd86d0d20ab167c33b23d3875db68d1d4bad3a3b",
+ sha256 = "56775f1283a59e6274c28d99981a9717ff4e0b1161e9129fdb2fcf22531d8d93",
+ strip_prefix = "abseil-cpp-a0d1e098c2f99694fa399b175a7ccf920762030e",
urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/cd86d0d20ab167c33b23d3875db68d1d4bad3a3b.tar.gz",
- "https://github.com/abseil/abseil-cpp/archive/cd86d0d20ab167c33b23d3875db68d1d4bad3a3b.tar.gz",
+ "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz",
+ "https://github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz",
],
)
http_archive(
name = "com_google_googletest",
- sha256 = "db657310d3c5ca2d3f674e3a4b79718d1d39da70604568ee0568ba8e39065ef4",
- strip_prefix = "googletest-31200def0dec8a624c861f919e86e4444e6e6ee7",
+ sha256 = "0a10bea96d8670e5eef948d79d824162b1577bb7889539e49ec786bfc3e48912",
+ strip_prefix = "googletest-565f1b848215b77c3732bca345fe76a0431d8b34",
urls = [
- "https://mirror.bazel.build/github.com/google/googletest/archive/31200def0dec8a624c861f919e86e4444e6e6ee7.tar.gz",
- "https://github.com/google/googletest/archive/31200def0dec8a624c861f919e86e4444e6e6ee7.tar.gz",
+ "https://mirror.bazel.build/github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz",
+ "https://github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz",
],
)
diff --git a/cloudbuild/go.Dockerfile b/cloudbuild/go.Dockerfile
deleted file mode 100644
index 226442fd2..000000000
--- a/cloudbuild/go.Dockerfile
+++ /dev/null
@@ -1,2 +0,0 @@
-FROM ubuntu
-RUN apt-get -q update && apt-get install -qqy git rsync
diff --git a/cloudbuild/go.yaml b/cloudbuild/go.yaml
deleted file mode 100644
index a38ef71fc..000000000
--- a/cloudbuild/go.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-steps:
-- name: 'gcr.io/cloud-builders/git'
- args: ['fetch', '--all', '--unshallow']
-- name: 'gcr.io/cloud-builders/bazel'
- args: ['build', ':gopath']
-- name: 'gcr.io/cloud-builders/docker'
- args: ['build', '-t', 'gcr.io/$PROJECT_ID/go-branch', '-f', 'cloudbuild/go.Dockerfile', '.']
-- name: 'gcr.io/$PROJECT_ID/go-branch'
- args: ['tools/go_branch.sh']
-- name: 'gcr.io/cloud-builders/git'
- args: ['checkout', 'go']
-- name: 'gcr.io/cloud-builders/git'
- args: ['clean', '-f']
-- name: 'golang'
- args: ['go', 'build', './...']
-- name: 'gcr.io/cloud-builders/git'
- entrypoint: 'bash'
- args:
- - '-c'
- - 'if [[ "$BRANCH_NAME" == "master" ]]; then git push "${_ORIGIN}" go:go; fi'
-substitutions:
- _ORIGIN: origin
diff --git a/kokoro/build.cfg b/kokoro/build.cfg
new file mode 100644
index 000000000..6c1d262d4
--- /dev/null
+++ b/kokoro/build.cfg
@@ -0,0 +1,23 @@
+# Script that Kokoro runs for this job.
+# NOTE(review): the leading "repo/" is presumably the directory into which
+# Kokoro checks out the repository - confirm.
+build_file: "repo/scripts/build.sh"
+
+# Before the build starts, fetch the key named "kokoro-repo-key" from keystore
+# config 73898.
+before_action {
+ fetch_keystore {
+ keystore_resource {
+ keystore_config_id: 73898
+ keyname: "kokoro-repo-key"
+ }
+ }
+}
+
+# Hand the build script the "<config id>_<key name>" identifier of the key
+# fetched above.
+# NOTE(review): presumably this is the file name under which Kokoro exposes the
+# fetched key - confirm.
+env_vars {
+ key: "KOKORO_REPO_KEY"
+ value: "73898_kokoro-repo-key"
+}
+
+# Artifacts to collect from the build: the runsc binary, any runsc.* companion
+# files, and everything below a dists/ directory. The leading "**/" wildcard
+# matches them at any depth.
+action {
+ define_artifacts {
+ regex: "**/runsc"
+ regex: "**/runsc.*"
+ regex: "**/dists/**"
+ }
+}
diff --git a/kokoro/build_tests.cfg b/kokoro/build_tests.cfg
new file mode 100644
index 000000000..c64b7e679
--- /dev/null
+++ b/kokoro/build_tests.cfg
@@ -0,0 +1 @@
+build_file: "repo/scripts/build.sh"
diff --git a/kokoro/common.cfg b/kokoro/common.cfg
index cad873fe1..669a2e458 100644
--- a/kokoro/common.cfg
+++ b/kokoro/common.cfg
@@ -10,7 +10,7 @@ before_action {
# Configure bazel to access RBE.
bazel_setting {
- # Our GCP project name
+ # Our GCP project name.
project_id: "gvisor-rbe"
# Use RBE for execution as well as caching.
diff --git a/kokoro/continuous.cfg b/kokoro/continuous.cfg
deleted file mode 100644
index 8da47736a..000000000
--- a/kokoro/continuous.cfg
+++ /dev/null
@@ -1,13 +0,0 @@
-# Location of bash script that runs the test. The first directory in the path
-# is the directory where Kokoro will check out the repo. The rest is the path
-# is the path to the test script.
-build_file: "repo/kokoro/run_tests.sh"
-
-action {
- define_artifacts {
- regex: "**/sponge_log.xml"
- regex: "**/sponge_log.log"
- regex: "**/outputs.zip"
- regex: "**/runsc-logs.tar.gz"
- }
-}
diff --git a/kokoro/do_tests.cfg b/kokoro/do_tests.cfg
new file mode 100644
index 000000000..b45ec0b42
--- /dev/null
+++ b/kokoro/do_tests.cfg
@@ -0,0 +1,9 @@
+build_file: "repo/scripts/do_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ }
+}
diff --git a/kokoro/docker_tests.cfg b/kokoro/docker_tests.cfg
new file mode 100644
index 000000000..0a0ef87ed
--- /dev/null
+++ b/kokoro/docker_tests.cfg
@@ -0,0 +1,10 @@
+build_file: "repo/scripts/docker_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc_logs_*.tar.gz"
+ }
+}
diff --git a/kokoro/go.cfg b/kokoro/go.cfg
new file mode 100644
index 000000000..b9c1fcb12
--- /dev/null
+++ b/kokoro/go.cfg
@@ -0,0 +1,20 @@
+# Script that Kokoro runs for this job.
+# NOTE(review): the leading "repo/" is presumably the directory into which
+# Kokoro checks out the repository - confirm.
+build_file: "repo/scripts/go.sh"
+
+# Before the build starts, fetch the key named "kokoro-github-access-token"
+# from keystore config 73898.
+before_action {
+ fetch_keystore {
+ keystore_resource {
+ keystore_config_id: 73898
+ keyname: "kokoro-github-access-token"
+ }
+ }
+}
+
+# Hand the build script the "<config id>_<key name>" identifier of the token
+# fetched above.
+env_vars {
+ key: "KOKORO_GITHUB_ACCESS_TOKEN"
+ value: "73898_kokoro-github-access-token"
+}
+
+# NOTE(review): presumably read by scripts/go.sh to decide whether the result
+# is pushed - confirm against the script.
+env_vars {
+ key: "KOKORO_GO_PUSH"
+ value: "true"
+}
diff --git a/kokoro/go_tests.cfg b/kokoro/go_tests.cfg
new file mode 100644
index 000000000..5eb51041a
--- /dev/null
+++ b/kokoro/go_tests.cfg
@@ -0,0 +1 @@
+build_file: "repo/scripts/go.sh"
diff --git a/kokoro/hostnet_tests.cfg b/kokoro/hostnet_tests.cfg
new file mode 100644
index 000000000..520dc55a3
--- /dev/null
+++ b/kokoro/hostnet_tests.cfg
@@ -0,0 +1,10 @@
+build_file: "repo/scripts/hostnet_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc_logs_*.tar.gz"
+ }
+}
diff --git a/kokoro/kvm_tests.cfg b/kokoro/kvm_tests.cfg
new file mode 100644
index 000000000..1feb60c8a
--- /dev/null
+++ b/kokoro/kvm_tests.cfg
@@ -0,0 +1,10 @@
+build_file: "repo/scripts/kvm_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc_logs_*.tar.gz"
+ }
+}
diff --git a/kokoro/make_tests.cfg b/kokoro/make_tests.cfg
new file mode 100644
index 000000000..d973130ff
--- /dev/null
+++ b/kokoro/make_tests.cfg
@@ -0,0 +1,9 @@
+build_file: "repo/scripts/make_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ }
+}
diff --git a/kokoro/overlay_tests.cfg b/kokoro/overlay_tests.cfg
new file mode 100644
index 000000000..6a2ddbd03
--- /dev/null
+++ b/kokoro/overlay_tests.cfg
@@ -0,0 +1,10 @@
+build_file: "repo/scripts/overlay_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc_logs_*.tar.gz"
+ }
+}
diff --git a/kokoro/presubmit.cfg b/kokoro/presubmit.cfg
deleted file mode 100644
index 8da47736a..000000000
--- a/kokoro/presubmit.cfg
+++ /dev/null
@@ -1,13 +0,0 @@
-# Location of bash script that runs the test. The first directory in the path
-# is the directory where Kokoro will check out the repo. The rest is the path
-# is the path to the test script.
-build_file: "repo/kokoro/run_tests.sh"
-
-action {
- define_artifacts {
- regex: "**/sponge_log.xml"
- regex: "**/sponge_log.log"
- regex: "**/outputs.zip"
- regex: "**/runsc-logs.tar.gz"
- }
-}
diff --git a/kokoro/release-nightly.cfg b/kokoro/release-nightly.cfg
deleted file mode 100644
index e5087b1cd..000000000
--- a/kokoro/release-nightly.cfg
+++ /dev/null
@@ -1,11 +0,0 @@
-# Location of bash script that builds a release.
-build_file: "repo/kokoro/run_build.sh"
-
-action {
- # Upload runsc binary and its checksum. It may be in multiple paths, so we
- # must use the wildcard.
- define_artifacts {
- regex: "**/runsc"
- regex: "**/runsc.sha512"
- }
-}
diff --git a/kokoro/release.cfg b/kokoro/release.cfg
new file mode 100644
index 000000000..b9d35bc51
--- /dev/null
+++ b/kokoro/release.cfg
@@ -0,0 +1 @@
+build_file: "repo/scripts/release.sh"
diff --git a/kokoro/root_tests.cfg b/kokoro/root_tests.cfg
new file mode 100644
index 000000000..28351695c
--- /dev/null
+++ b/kokoro/root_tests.cfg
@@ -0,0 +1,10 @@
+build_file: "repo/scripts/root_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc_logs_*.tar.gz"
+ }
+}
diff --git a/kokoro/run_build.sh b/kokoro/run_build.sh
deleted file mode 120000
index 9deafe9bb..000000000
--- a/kokoro/run_build.sh
+++ /dev/null
@@ -1 +0,0 @@
-../tools/run_build.sh \ No newline at end of file
diff --git a/kokoro/run_tests.sh b/kokoro/run_tests.sh
deleted file mode 120000
index 931cd2622..000000000
--- a/kokoro/run_tests.sh
+++ /dev/null
@@ -1 +0,0 @@
-../tools/run_tests.sh \ No newline at end of file
diff --git a/kokoro/simple_tests.cfg b/kokoro/simple_tests.cfg
new file mode 100644
index 000000000..32e0a9431
--- /dev/null
+++ b/kokoro/simple_tests.cfg
@@ -0,0 +1,9 @@
+build_file: "repo/scripts/simple_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ }
+}
diff --git a/kokoro/syscall_tests.cfg b/kokoro/syscall_tests.cfg
new file mode 100644
index 000000000..ee6e4a3a4
--- /dev/null
+++ b/kokoro/syscall_tests.cfg
@@ -0,0 +1,9 @@
+build_file: "repo/scripts/syscall_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ }
+}
diff --git a/kokoro/ubuntu1604/10_core.sh b/kokoro/ubuntu1604/10_core.sh
new file mode 100755
index 000000000..e87a6eee8
--- /dev/null
+++ b/kokoro/ubuntu1604/10_core.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+# Install all essential build tools.
+#
+# update and install are separate statements on purpose: in an "a && b" list,
+# "set -e" does not abort the script when "a" fails, so a failed update would
+# silently skip the install and the script would carry on without the tools.
+apt-get update
+apt-get -y install make git-core build-essential "linux-headers-$(uname -r)" pkg-config
+
+# Install a recent go toolchain, unless one is already present.
+if ! [[ -d /usr/local/go ]]; then
+  wget https://dl.google.com/go/go1.12.linux-amd64.tar.gz
+  tar -xvf go1.12.linux-amd64.tar.gz
+  mv go /usr/local
+  # The tarball is not needed once it has been unpacked.
+  rm -f go1.12.linux-amd64.tar.gz
+fi
+
+# Link the Go binary from /usr/bin; replacing anything there.
+(cd /usr/bin && rm -f go && sudo ln -fs /usr/local/go/bin/go go)
diff --git a/kokoro/ubuntu1604/20_bazel.sh b/kokoro/ubuntu1604/20_bazel.sh
new file mode 100755
index 000000000..b9a894024
--- /dev/null
+++ b/kokoro/ubuntu1604/20_bazel.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+declare -r BAZEL_VERSION=0.29.1
+
+# Install bazel dependencies.
+apt-get update && apt-get install -y openjdk-8-jdk-headless unzip
+
+# Use the release installer.
+curl -L -o bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
+chmod a+x bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
+./bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
+rm -f bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
diff --git a/kokoro/ubuntu1604/25_docker.sh b/kokoro/ubuntu1604/25_docker.sh
new file mode 100755
index 000000000..1d3defcd3
--- /dev/null
+++ b/kokoro/ubuntu1604/25_docker.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Fail on the first error, like the other install scripts in this directory.
+# pipefail matters here: without it a failed key download in the
+# "curl | apt-key" pipeline below would go unnoticed.
+set -xeo pipefail
+
+# Add dependencies.
+#
+# update and install are separate statements on purpose: "set -e" does not
+# abort the script when the left-hand side of "&&" fails.
+apt-get update
+apt-get -y install \
+  apt-transport-https \
+  ca-certificates \
+  curl \
+  gnupg-agent \
+  software-properties-common
+
+# Install the key.
+curl -fsSL https://download.docker.com/linux/ubuntu/gpg | apt-key add -
+
+# Add the repository for the running Ubuntu release.
+add-apt-repository \
+  "deb [arch=amd64] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable"
+
+# Install docker.
+apt-get update
+apt-get install -y docker-ce docker-ce-cli containerd.io
diff --git a/kokoro/ubuntu1604/30_containerd.sh b/kokoro/ubuntu1604/30_containerd.sh
new file mode 100755
index 000000000..a7472bd1c
--- /dev/null
+++ b/kokoro/ubuntu1604/30_containerd.sh
@@ -0,0 +1,76 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+# Helper for Go packages below.
+install_helper() {
+ PACKAGE="${1}"
+ TAG="${2}"
+ GOPATH="${3}"
+
+ # Clone the repository.
+ mkdir -p "${GOPATH}"/src/$(dirname "${PACKAGE}") && \
+ git clone https://"${PACKAGE}" "${GOPATH}"/src/"${PACKAGE}"
+
+ # Checkout and build the repository.
+ (cd "${GOPATH}"/src/"${PACKAGE}" && \
+ git checkout "${TAG}" && \
+ GOPATH="${GOPATH}" make && \
+ GOPATH="${GOPATH}" make install)
+}
+
+# Install dependencies for the crictl tests.
+apt-get install -y btrfs-tools libseccomp-dev
+
+# Install containerd & cri-tools.
+GOPATH=$(mktemp -d --tmpdir gopathXXXXX)
+install_helper github.com/containerd/containerd v1.2.2 "${GOPATH}"
+install_helper github.com/kubernetes-sigs/cri-tools v1.11.0 "${GOPATH}"
+
+# Install gvisor-containerd-shim.
+#
+# The mktemp results are assigned before "declare -r" so that a mktemp failure
+# is not masked by the exit status of declare. The binary is installed under a
+# fixed name: moving the temporary file into the directory would keep its
+# random mktemp suffix, leaving it at an unpredictable path.
+declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim"
+latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX)
+shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX)
+declare -r latest shim_path
+wget --no-verbose "${base}/latest" -O "${latest}"
+wget --no-verbose "${base}/gvisor-containerd-shim-$(cat "${latest}")" -O "${shim_path}"
+chmod +x "${shim_path}"
+mv "${shim_path}" /usr/local/bin/gvisor-containerd-shim
+
+# Configure containerd-shim.
+#
+# As above, the configuration is moved to a fixed file name rather than into
+# the directory, so that it does not end up under a random mktemp name.
+declare -r shim_config_path=/etc/containerd/gvisor-containerd-shim.toml
+shim_config_tmp_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX.toml)
+declare -r shim_config_tmp_path
+mkdir -p "$(dirname "${shim_config_path}")"
+cat > "${shim_config_tmp_path}" <<-EOF
+ runc_shim = "/usr/local/bin/containerd-shim"
+
+[runsc_config]
+ debug = "true"
+ debug-log = "/tmp/runsc-logs/"
+ strace = "true"
+ file-access = "shared"
+EOF
+mv "${shim_config_tmp_path}" "${shim_config_path}"
+
+# Configure CNI.
+(cd "${GOPATH}" && GOPATH="${GOPATH}" \
+  src/github.com/containerd/containerd/script/setup/install-cni)
+
+# Cleanup the above. The shim binary and its configuration were moved into
+# place, so only the build tree and the version file remain.
+rm -rf "${GOPATH}"
+rm -rf "${latest}"
diff --git a/kokoro/ubuntu1604/40_kokoro.sh b/kokoro/ubuntu1604/40_kokoro.sh
new file mode 100755
index 000000000..64772d74d
--- /dev/null
+++ b/kokoro/ubuntu1604/40_kokoro.sh
@@ -0,0 +1,54 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+# Declare kokoro's required public keys.
+declare -r ssh_public_keys=(
+ "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDg7L/ZaEauETWrPklUTky3kvxqQfe2Ax/2CsSqhNIGNMnK/8d79CHlmY9+dE1FFQ/RzKNCaltgy7XcN/fCYiCZr5jm2ZtnLuGNOTzupMNhaYiPL419qmL+5rZXt4/dWTrsHbFRACxT8j51PcRMO5wgbL0Bg2XXimbx8kDFaurL2gqduQYqlu4lxWCaJqOL71WogcimeL63Nq/yeH5PJPWpqE4P9VUQSwAzBWFK/hLeds/AiP3MgVS65qHBnhq0JsHy8JQsqjZbG7Iidt/Ll0+gqzEbi62gDIcczG4KC0iOVzDDP/1BxDtt1lKeA23ll769Fcm3rJyoBMYxjvdw1TDx sabujp@trigger.mtv.corp.google.com"
+ "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBNgGK/hCdjmulHfRE3hp4rZs38NCR8yAh0eDsztxqGcuXnuSnL7jOlRrbcQpremJ84omD4eKrIpwJUs+YokMdv4= sabujp@trigger.svl.corp.google.com"
+)
+
+# Install dependencies.
+#
+# update and install are separate statements on purpose: in an "a && b" list,
+# "set -e" does not abort the script when "a" fails, so a failed update would
+# silently skip the install.
+apt-get update
+apt-get install -y rsync coreutils python-psutil qemu-kvm
+
+# We need a kbuilder user.
+if useradd -c "kbuilder user" -m -s /bin/bash kbuilder; then
+  # User was added successfully; we add the relevant SSH keys here, one key
+  # per line.
+  mkdir -p ~kbuilder/.ssh
+  (IFS=$'\n'; echo "${ssh_public_keys[*]}") > ~kbuilder/.ssh/authorized_keys
+  chmod 0600 ~kbuilder/.ssh/authorized_keys
+  chown -R kbuilder ~kbuilder/.ssh
+fi
+
+# Give passwordless sudo access. The drop-in is set to mode 0440, the mode
+# sudoers files are expected to have, rather than left at the umask default.
+cat > /etc/sudoers.d/kokoro <<EOF
+kbuilder ALL=(ALL) NOPASSWD:ALL
+EOF
+chmod 0440 /etc/sudoers.d/kokoro
+
+# Ensure we can run Docker without sudo.
+usermod -aG docker kbuilder
+
+# Ensure that we can access kvm.
+usermod -aG kvm kbuilder
+
+# Ensure that /tmpfs exists and is writable by kokoro.
+#
+# Note that kokoro will typically attach a second disk (sdb) to the instance
+# that is used for the /tmpfs volume. In the future we could setup an init
+# script that formats and mounts this here; however, we don't expect our build
+# artifacts to be that large.
+mkdir -p /tmpfs
+chmod 0777 /tmpfs
+touch /tmpfs/READY
diff --git a/kokoro/ubuntu1604/build.sh b/kokoro/ubuntu1604/build.sh
new file mode 100755
index 000000000..d664a3a76
--- /dev/null
+++ b/kokoro/ubuntu1604/build.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+# Run the image_build.sh script with appropriate parameters.
+#
+# The script directory is quoted so that a checkout path containing spaces is
+# not word-split; the ??_*.sh glob stays outside the quotes so that it still
+# expands to the two-character-prefixed install scripts next to this file.
+IMAGE_PROJECT=ubuntu-os-cloud IMAGE_FAMILY=ubuntu-1604-lts "$(dirname "$0")"/../../tools/image_build.sh "$(dirname "$0")"/??_*.sh
diff --git a/kokoro/ubuntu1804/10_core.sh b/kokoro/ubuntu1804/10_core.sh
new file mode 120000
index 000000000..6facceeee
--- /dev/null
+++ b/kokoro/ubuntu1804/10_core.sh
@@ -0,0 +1 @@
+../ubuntu1604/10_core.sh \ No newline at end of file
diff --git a/kokoro/ubuntu1804/20_bazel.sh b/kokoro/ubuntu1804/20_bazel.sh
new file mode 120000
index 000000000..39194c0f5
--- /dev/null
+++ b/kokoro/ubuntu1804/20_bazel.sh
@@ -0,0 +1 @@
+../ubuntu1604/20_bazel.sh \ No newline at end of file
diff --git a/kokoro/ubuntu1804/25_docker.sh b/kokoro/ubuntu1804/25_docker.sh
new file mode 120000
index 000000000..63269bd83
--- /dev/null
+++ b/kokoro/ubuntu1804/25_docker.sh
@@ -0,0 +1 @@
+../ubuntu1604/25_docker.sh \ No newline at end of file
diff --git a/kokoro/ubuntu1804/30_containerd.sh b/kokoro/ubuntu1804/30_containerd.sh
new file mode 120000
index 000000000..6ac2377ed
--- /dev/null
+++ b/kokoro/ubuntu1804/30_containerd.sh
@@ -0,0 +1 @@
+../ubuntu1604/30_containerd.sh \ No newline at end of file
diff --git a/kokoro/ubuntu1804/40_kokoro.sh b/kokoro/ubuntu1804/40_kokoro.sh
new file mode 120000
index 000000000..e861fb5e1
--- /dev/null
+++ b/kokoro/ubuntu1804/40_kokoro.sh
@@ -0,0 +1 @@
+../ubuntu1604/40_kokoro.sh \ No newline at end of file
diff --git a/kokoro/ubuntu1804/build.sh b/kokoro/ubuntu1804/build.sh
new file mode 100755
index 000000000..2b5c9a6f2
--- /dev/null
+++ b/kokoro/ubuntu1804/build.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+# Run the image_build.sh script with appropriate parameters.
+#
+# The script directory is quoted so that a checkout path containing spaces is
+# not word-split; the ??_*.sh glob stays outside the quotes so that it still
+# expands to the two-character-prefixed install scripts next to this file.
+IMAGE_PROJECT=ubuntu-os-cloud IMAGE_FAMILY=ubuntu-1804-lts "$(dirname "$0")"/../../tools/image_build.sh "$(dirname "$0")"/??_*.sh
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index ba233b93f..f45934466 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -2,9 +2,11 @@
# Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix
# when the host OS may not be Linux.
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "linux",
@@ -44,6 +46,7 @@ go_library(
"sem.go",
"shm.go",
"signal.go",
+ "signalfd.go",
"socket.go",
"splice.go",
"tcp.go",
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
index 7d742871a..257f67222 100644
--- a/pkg/abi/linux/file.go
+++ b/pkg/abi/linux/file.go
@@ -271,7 +271,7 @@ type Statx struct {
}
// FileMode represents a mode_t.
-type FileMode uint
+type FileMode uint16
// Permissions returns just the permission bits.
func (m FileMode) Permissions() FileMode {
diff --git a/pkg/abi/linux/signalfd.go b/pkg/abi/linux/signalfd.go
new file mode 100644
index 000000000..85fad9956
--- /dev/null
+++ b/pkg/abi/linux/signalfd.go
@@ -0,0 +1,45 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Flags for signalfd(2). Both values are written as Go octal literals (note
+// the leading zero).
+const (
+ // SFD_NONBLOCK is a signalfd(2) flag.
+ SFD_NONBLOCK = 00004000
+
+ // SFD_CLOEXEC is a signalfd(2) flag.
+ SFD_CLOEXEC = 02000000
+)
+
+// SignalfdSiginfo is the siginfo encoding for signalfds.
+//
+// NOTE(review): the field order and widths appear to mirror struct
+// signalfd_siginfo as described in signalfd(2) - confirm against the kernel
+// header before changing any field.
+//
+// NOTE(review): the named fields add up to 82 bytes when packed (12 four-byte
+// fields, 4 eight-byte fields and one two-byte field), so the 48-byte trailing
+// pad gives a packed size of 130 bytes. signalfd(2) describes the structure as
+// padded to 128 bytes - confirm which size the consumers of this type rely on.
+type SignalfdSiginfo struct {
+ Signo uint32
+ Errno int32
+ Code int32
+ PID uint32
+ UID uint32
+ FD int32
+ TID uint32
+ Band uint32
+ Overrun uint32
+ TrapNo uint32
+ Status int32
+ Int int32
+ Ptr uint64
+ UTime uint64
+ STime uint64
+ Addr uint64
+ AddrLSB uint16
+ // Trailing padding; the blank identifier means it is never referenced.
+ _ [48]uint8
+}
diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD
index 39d253b98..6bc486b62 100644
--- a/pkg/amutex/BUILD
+++ b/pkg/amutex/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD
index 47ab65346..36beaade9 100644
--- a/pkg/atomicbitops/BUILD
+++ b/pkg/atomicbitops/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -7,6 +8,7 @@ go_library(
srcs = [
"atomic_bitops.go",
"atomic_bitops_amd64.s",
+ "atomic_bitops_arm64.s",
"atomic_bitops_common.go",
],
importpath = "gvisor.dev/gvisor/pkg/atomicbitops",
diff --git a/pkg/atomicbitops/atomic_bitops.go b/pkg/atomicbitops/atomic_bitops.go
index 63aa2b7f1..fcc41a9ea 100644
--- a/pkg/atomicbitops/atomic_bitops.go
+++ b/pkg/atomicbitops/atomic_bitops.go
@@ -12,11 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64
+// +build amd64 arm64
// Package atomicbitops provides basic bitwise operations in an atomic way.
// The implementation on amd64 leverages the LOCK prefix directly instead of
-// relying on the generic cas primitives.
+// relying on the generic cas primitives, and the implementation on arm64
+// leverages the LDAXR and STLXR instruction pair.
//
// WARNING: the bitwise ops provided in this package doesn't imply any memory
// ordering. Using them to construct locks must employ proper memory barriers.
diff --git a/pkg/atomicbitops/atomic_bitops_arm64.s b/pkg/atomicbitops/atomic_bitops_arm64.s
new file mode 100644
index 000000000..97f8808c1
--- /dev/null
+++ b/pkg/atomicbitops/atomic_bitops_arm64.s
@@ -0,0 +1,139 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+#include "textflag.h"
+
+// AndUint32 ANDs the 32-bit argument val into the 32-bit word addressed by
+// ptr, using an exclusive load/store pair that is retried until the store
+// succeeds. It has no result.
+TEXT ·AndUint32(SB),$0-12
+ MOVD ptr+0(FP), R0
+ MOVW val+8(FP), R1
+again:
+ LDAXRW (R0), R2 // exclusive load of the current word
+ ANDW R1, R2
+ STLXRW R2, (R0), R3 // exclusive store; status is written to R3
+ CBNZ R3, again // non-zero status: store did not happen, retry
+ RET
+
+// OrUint32 ORs the 32-bit argument val into the 32-bit word addressed by
+// ptr, using an exclusive load/store pair that is retried until the store
+// succeeds. It has no result.
+TEXT ·OrUint32(SB),$0-12
+ MOVD ptr+0(FP), R0
+ MOVW val+8(FP), R1
+again:
+ LDAXRW (R0), R2 // exclusive load of the current word
+ ORRW R1, R2
+ STLXRW R2, (R0), R3 // exclusive store; status is written to R3
+ CBNZ R3, again // non-zero status: store did not happen, retry
+ RET
+
+// XorUint32 XORs the 32-bit argument val into the 32-bit word addressed by
+// ptr, using an exclusive load/store pair that is retried until the store
+// succeeds. It has no result.
+TEXT ·XorUint32(SB),$0-12
+ MOVD ptr+0(FP), R0
+ MOVW val+8(FP), R1
+again:
+ LDAXRW (R0), R2 // exclusive load of the current word
+ EORW R1, R2
+ STLXRW R2, (R0), R3 // exclusive store; status is written to R3
+ CBNZ R3, again // non-zero status: store did not happen, retry
+ RET
+
+// CompareAndSwapUint32 stores new into the 32-bit word addressed by addr if
+// that word equals old. The value loaded from addr is written to the prev
+// result whether or not the store took place.
+TEXT ·CompareAndSwapUint32(SB),$0-20
+ MOVD addr+0(FP), R0
+ MOVW old+8(FP), R1
+ MOVW new+12(FP), R2
+
+again:
+ LDAXRW (R0), R3
+ CMPW R1, R3
+ BNE done // current value differs from old: skip the store
+ STLXRW R2, (R0), R4
+ CBNZ R4, again // exclusive store failed: reload and compare again
+done:
+ MOVW R3, prev+16(FP)
+ RET
+
+// AndUint64 ANDs the 64-bit argument val into the 64-bit word addressed by
+// ptr, using an exclusive load/store pair that is retried until the store
+// succeeds. It has no result.
+TEXT ·AndUint64(SB),$0-16
+ MOVD ptr+0(FP), R0
+ MOVD val+8(FP), R1
+again:
+ LDAXR (R0), R2 // exclusive load of the current word
+ AND R1, R2
+ STLXR R2, (R0), R3 // exclusive store; status is written to R3
+ CBNZ R3, again // non-zero status: store did not happen, retry
+ RET
+
+// OrUint64 ORs the 64-bit argument val into the 64-bit word addressed by
+// ptr, using an exclusive load/store pair that is retried until the store
+// succeeds. It has no result.
+TEXT ·OrUint64(SB),$0-16
+ MOVD ptr+0(FP), R0
+ MOVD val+8(FP), R1
+again:
+ LDAXR (R0), R2 // exclusive load of the current word
+ ORR R1, R2
+ STLXR R2, (R0), R3 // exclusive store; status is written to R3
+ CBNZ R3, again // non-zero status: store did not happen, retry
+ RET
+
+// XorUint64 XORs the 64-bit argument val into the 64-bit word addressed by
+// ptr, using an exclusive load/store pair that is retried until the store
+// succeeds. It has no result.
+TEXT ·XorUint64(SB),$0-16
+ MOVD ptr+0(FP), R0
+ MOVD val+8(FP), R1
+again:
+ LDAXR (R0), R2 // exclusive load of the current word
+ EOR R1, R2
+ STLXR R2, (R0), R3 // exclusive store; status is written to R3
+ CBNZ R3, again // non-zero status: store did not happen, retry
+ RET
+
+// CompareAndSwapUint64 stores new into the 64-bit word addressed by addr if
+// that word equals old. The value loaded from addr is written to the prev
+// result whether or not the store took place.
+TEXT ·CompareAndSwapUint64(SB),$0-32
+ MOVD addr+0(FP), R0
+ MOVD old+8(FP), R1
+ MOVD new+16(FP), R2
+
+again:
+ LDAXR (R0), R3
+ CMP R1, R3
+ BNE done // current value differs from old: skip the store
+ STLXR R2, (R0), R4
+ CBNZ R4, again // exclusive store failed: reload and compare again
+done:
+ MOVD R3, prev+24(FP)
+ RET
+
+// IncUnlessZeroInt32 adds 1 to the 32-bit word addressed by addr unless the
+// loaded value is zero. The one-byte result ret is 1 when the incremented
+// value was stored and 0 when the word was zero (nothing is stored then).
+TEXT ·IncUnlessZeroInt32(SB),NOSPLIT,$0-9
+ MOVD addr+0(FP), R0
+
+again:
+ LDAXRW (R0), R1
+ CBZ R1, fail // value is zero: leave it untouched
+ ADDW $1, R1
+ STLXRW R1, (R0), R2
+ CBNZ R2, again // exclusive store failed: reload and retry
+ MOVW $1, R2
+ MOVB R2, ret+8(FP)
+ RET
+fail:
+ MOVB ZR, ret+8(FP)
+ RET
+
+// DecUnlessOneInt32 subtracts 1 from the 32-bit word addressed by addr
+// unless the subtraction yields zero, i.e. unless the loaded value is one.
+// The one-byte result ret is 1 when the decremented value was stored and 0
+// otherwise (nothing is stored then).
+TEXT ·DecUnlessOneInt32(SB),NOSPLIT,$0-9
+ MOVD addr+0(FP), R0
+
+again:
+ LDAXRW (R0), R1
+ SUBSW $1, R1, R1 // sets flags; a zero result means the value was one
+ BEQ fail
+ STLXRW R1, (R0), R2
+ CBNZ R2, again // exclusive store failed: reload and retry
+ MOVW $1, R2
+ MOVB R2, ret+8(FP)
+ RET
+fail:
+ MOVB ZR, ret+8(FP)
+ RET
diff --git a/pkg/atomicbitops/atomic_bitops_common.go b/pkg/atomicbitops/atomic_bitops_common.go
index b2a943dcb..85163ad62 100644
--- a/pkg/atomicbitops/atomic_bitops_common.go
+++ b/pkg/atomicbitops/atomic_bitops_common.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build !amd64
+// +build !amd64,!arm64
package atomicbitops
diff --git a/pkg/binary/BUILD b/pkg/binary/BUILD
index 09d6c2c1f..543fb54bf 100644
--- a/pkg/binary/BUILD
+++ b/pkg/binary/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD
index 0c2dde4f8..1b5dac99a 100644
--- a/pkg/bits/BUILD
+++ b/pkg/bits/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -10,8 +11,9 @@ go_library(
"bits.go",
"bits32.go",
"bits64.go",
- "uint64_arch_amd64.go",
+ "uint64_arch.go",
"uint64_arch_amd64_asm.s",
+ "uint64_arch_arm64_asm.s",
"uint64_arch_generic.go",
],
importpath = "gvisor.dev/gvisor/pkg/bits",
diff --git a/pkg/bits/uint64_arch_amd64.go b/pkg/bits/uint64_arch.go
index faccaa61a..9f23eff77 100644
--- a/pkg/bits/uint64_arch_amd64.go
+++ b/pkg/bits/uint64_arch.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64
+// +build amd64 arm64
package bits
diff --git a/pkg/bits/uint64_arch_arm64_asm.s b/pkg/bits/uint64_arch_arm64_asm.s
new file mode 100644
index 000000000..814ba562d
--- /dev/null
+++ b/pkg/bits/uint64_arch_arm64_asm.s
@@ -0,0 +1,33 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+// TrailingZeros64 returns the number of trailing zero bits in x. It
+// bit-reverses x and counts the leading zeros of the result, so the result
+// is 64 when x is zero.
+TEXT ·TrailingZeros64(SB),$0-16
+ MOVD x+0(FP), R0
+ RBIT R0, R0
+ CLZ R0, R0 // return 64 if x == 0
+ MOVD R0, ret+8(FP)
+ RET
+
+// MostSignificantOne64 returns 63 minus the number of leading zero bits in
+// x, i.e. the index of the most significant set bit. When x is zero that
+// difference is negative and 64 is returned instead.
+TEXT ·MostSignificantOne64(SB),$0-16
+ MOVD x+0(FP), R0
+ CLZ R0, R0 // return 64 if x == 0
+ MOVD $63, R1
+ SUBS R0, R1, R0 // ret = 63 - CLZ
+ BPL end // result is non-negative: x had at least one bit set
+ MOVD $64, R0 // x == 0
+end:
+ MOVD R0, ret+8(FP)
+ RET
diff --git a/pkg/bits/uint64_arch_generic.go b/pkg/bits/uint64_arch_generic.go
index 7dd2d1480..9dd2098d1 100644
--- a/pkg/bits/uint64_arch_generic.go
+++ b/pkg/bits/uint64_arch_generic.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build !amd64
+// +build !amd64,!arm64
package bits
diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD
index b692aa3b1..8d31e068c 100644
--- a/pkg/bpf/BUILD
+++ b/pkg/bpf/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "bpf",
diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD
index cdec96df1..a0b21d4bd 100644
--- a/pkg/compressio/BUILD
+++ b/pkg/compressio/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD
index 830e19e07..32422f9e2 100644
--- a/pkg/cpuid/BUILD
+++ b/pkg/cpuid/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "cpuid",
diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go
index 3fabaf445..5d61dc2ff 100644
--- a/pkg/cpuid/cpuid.go
+++ b/pkg/cpuid/cpuid.go
@@ -418,6 +418,73 @@ var x86FeatureParseOnlyStrings = map[Feature]string{
X86FeaturePREFETCHWT1: "prefetchwt1",
}
+// intelCacheDescriptor is a one-byte descriptor of a cache or TLB on the
+// system. Descriptors are returned in the registers for eax=2. Intel only.
+type intelCacheDescriptor uint8
+
+// Valid cache/TLB descriptors. All descriptors can be found in Intel SDM Vol.
+// 2, Ch. 3.2, "CPUID", Table 3-12 "Encoding of CPUID Leaf 2 Descriptors".
+const (
+ intelNullDescriptor intelCacheDescriptor = 0
+ intelNoTLBDescriptor intelCacheDescriptor = 0xfe
+ intelNoCacheDescriptor intelCacheDescriptor = 0xff
+
+ // Most descriptors omitted for brevity as they are currently unused.
+)
+
+// CacheType describes the type of a cache, as returned in eax[4:0] for eax=4.
+type CacheType uint8
+
+const (
+ // cacheNull indicates that there are no more entries.
+ cacheNull CacheType = iota
+
+ // CacheData is a data cache.
+ CacheData
+
+ // CacheInstruction is an instruction cache.
+ CacheInstruction
+
+ // CacheUnified is a unified instruction and data cache.
+ CacheUnified
+)
+
+// Cache describes the parameters of a single cache on the system.
+//
+// +stateify savable
+type Cache struct {
+ // Level is the hierarchical level of this cache (L1, L2, etc).
+ Level uint32
+
+ // Type is the type of cache.
+ Type CacheType
+
+ // FullyAssociative indicates that entries may be placed in any block.
+ FullyAssociative bool
+
+ // Partitions is the number of physical partitions in the cache.
+ Partitions uint32
+
+ // Ways is the number of ways of associativity in the cache.
+ Ways uint32
+
+ // Sets is the number of sets in the cache.
+ Sets uint32
+
+ // InvalidateHierarchical indicates that WBINVD/INVD from threads
+ // sharing this cache acts upon lower level caches for threads sharing
+ // this cache.
+ InvalidateHierarchical bool
+
+ // Inclusive indicates that this cache is inclusive of lower cache
+ // levels.
+ Inclusive bool
+
+ // DirectMapped indicates that this cache is directly mapped from
+ // address, rather than using a hash function.
+ DirectMapped bool
+}
+
// Just a way to wrap cpuid function numbers.
type cpuidFunction uint32
@@ -494,7 +561,7 @@ func (f Feature) flagString(cpuinfoOnly bool) string {
return ""
}
-// FeatureSet is a set of Features for a cpu.
+// FeatureSet is a set of Features for a CPU.
//
// +stateify savable
type FeatureSet struct {
@@ -521,6 +588,15 @@ type FeatureSet struct {
// SteppingID is part of the processor signature.
SteppingID uint8
+
+ // Caches describes the caches on the CPU.
+ Caches []Cache
+
+ // CacheLine is the size of a cache line in bytes.
+ //
+ // All caches use the same line size. This is not enforced in the CPUID
+ // encoding, but is true on all known x86 processors.
+ CacheLine uint32
}
// FlagsString prints out supported CPU flags. If cpuinfoOnly is true, it is
@@ -557,22 +633,27 @@ func (fs FeatureSet) CPUInfo(cpu uint) string {
fmt.Fprintln(&b, "wp\t\t: yes")
fmt.Fprintf(&b, "flags\t\t: %s\n", fs.FlagsString(true))
fmt.Fprintf(&b, "bogomips\t: %.02f\n", cpuFreqMHz) // It's bogus anyway.
- fmt.Fprintf(&b, "clflush size\t: %d\n", 64)
- fmt.Fprintf(&b, "cache_alignment\t: %d\n", 64)
+ fmt.Fprintf(&b, "clflush size\t: %d\n", fs.CacheLine)
+ fmt.Fprintf(&b, "cache_alignment\t: %d\n", fs.CacheLine)
fmt.Fprintf(&b, "address sizes\t: %d bits physical, %d bits virtual\n", 46, 48)
fmt.Fprintln(&b, "power management:") // This is always here, but can be blank.
fmt.Fprintln(&b, "") // The /proc/cpuinfo file ends with an extra newline.
return b.String()
}
+const (
+ amdVendorID = "AuthenticAMD"
+ intelVendorID = "GenuineIntel"
+)
+
// AMD returns true if fs describes an AMD CPU.
func (fs *FeatureSet) AMD() bool {
- return fs.VendorID == "AuthenticAMD"
+ return fs.VendorID == amdVendorID
}
// Intel returns true if fs describes an Intel CPU.
func (fs *FeatureSet) Intel() bool {
- return fs.VendorID == "GenuineIntel"
+ return fs.VendorID == intelVendorID
}
// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a
@@ -589,9 +670,18 @@ func (e ErrIncompatible) Error() string {
// CheckHostCompatible returns nil if fs is a subset of the host feature set.
func (fs *FeatureSet) CheckHostCompatible() error {
hfs := HostFeatureSet()
+
if diff := fs.Subtract(hfs); diff != nil {
return ErrIncompatible{fmt.Sprintf("CPU feature set %v incompatible with host feature set %v (missing: %v)", fs.FlagsString(false), hfs.FlagsString(false), diff)}
}
+
+ // The size of a cache line must match, as it is critical to correctly
+ // utilizing CLFLUSH. Other cache properties are allowed to change, as
+ // they are not important to correctness.
+ if fs.CacheLine != hfs.CacheLine {
+ return ErrIncompatible{fmt.Sprintf("CPU cache line size %d incompatible with host cache line size %d", fs.CacheLine, hfs.CacheLine)}
+ }
+
return nil
}
@@ -732,14 +822,6 @@ func (fs *FeatureSet) HasFeature(feature Feature) bool {
return fs.Set[feature]
}
-// IsSubset returns true if the FeatureSet is a subset of the FeatureSet passed in.
-// This is useful if you want to see if a FeatureSet is compatible with another
-// FeatureSet, since you can only run with a given FeatureSet if it's a subset of
-// the host's.
-func (fs *FeatureSet) IsSubset(other *FeatureSet) bool {
- return fs.Subtract(other) == nil
-}
-
// Subtract returns the features present in fs that are not present in other.
// If all features in fs are present in other, Subtract returns nil.
func (fs *FeatureSet) Subtract(other *FeatureSet) (diff map[Feature]bool) {
@@ -755,17 +837,6 @@ func (fs *FeatureSet) Subtract(other *FeatureSet) (diff map[Feature]bool) {
return
}
-// TakeFeatureIntersection will set the features in `fs` to the intersection of
-// the features in `fs` and `other` (effectively clearing any feature bits on
-// `fs` that are not also set in `other`).
-func (fs *FeatureSet) TakeFeatureIntersection(other *FeatureSet) {
- for f := range fs.Set {
- if !other.Set[f] {
- delete(fs.Set, f)
- }
- }
-}
-
// EmulateID emulates a cpuid instruction based on the feature set.
func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
switch cpuidFunction(origAx) {
@@ -773,9 +844,8 @@ func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
ax = uint32(xSaveInfo) // 0xd (xSaveInfo) is the highest function we support.
bx, dx, cx = fs.vendorIDRegs()
case featureInfo:
- // clflush line size (ebx bits[15:8]) hardcoded as 8. This
- // means cache lines of size 64 bytes.
- bx = 8 << 8
+ // CLFLUSH line size is encoded in quadwords. Other fields in bx unsupported.
+ bx = (fs.CacheLine / 8) << 8
cx = fs.blockMask(block(0))
dx = fs.blockMask(block(1))
ax = fs.signature()
@@ -789,10 +859,46 @@ func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
// will always return 01H. Software should ignore this value
// and not interpret it as an informational descriptor." - SDM
//
- // We do not support exposing cache information, but we do set
- // this fixed field because some language runtimes (dlang) get
- // confused by ax = 0 and will loop infinitely.
- ax = 1
+ // We only support reporting cache parameters via
+ // intelDeterministicCacheParams; report as much here.
+ //
+ // We do not support exposing TLB information at all.
+ ax = 1 | (uint32(intelNoCacheDescriptor) << 8)
+ case intelDeterministicCacheParams:
+ if !fs.Intel() {
+ // Reserved on non-Intel.
+ return 0, 0, 0, 0
+ }
+
+ // cx is the index of the cache to describe.
+ if int(origCx) >= len(fs.Caches) {
+ return uint32(cacheNull), 0, 0, 0
+ }
+ c := fs.Caches[origCx]
+
+ ax = uint32(c.Type)
+ ax |= c.Level << 5
+ ax |= 1 << 8 // Always claim the cache is "self-initializing".
+ if c.FullyAssociative {
+ ax |= 1 << 9
+ }
+ // Processor topology not supported.
+
+ bx = fs.CacheLine - 1
+ bx |= (c.Partitions - 1) << 12
+ bx |= (c.Ways - 1) << 22
+
+ cx = c.Sets - 1
+
+ if !c.InvalidateHierarchical {
+ dx |= 1
+ }
+ if c.Inclusive {
+ dx |= 1 << 1
+ }
+ if !c.DirectMapped {
+ dx |= 1 << 2
+ }
case xSaveInfo:
if !fs.UseXsave() {
return 0, 0, 0, 0
@@ -845,10 +951,41 @@ func HostFeatureSet() *FeatureSet {
vendorID := vendorIDFromRegs(bx, cx, dx)
// eax=1 gets basic features in ecx:edx.
- ax, _, cx, dx := HostID(1, 0)
+ ax, bx, cx, dx := HostID(1, 0)
featureBlock0 := cx
featureBlock1 := dx
ef, em, pt, f, m, sid := signatureSplit(ax)
+ // The CLFLUSH line size is reported in ebx[15:8] in units of 8 bytes.
+ // Mask the field before scaling: *, >> and & share one precedence level
+ // in Go, so "8 * (bx >> 8) & 0xff" would instead truncate the scaled
+ // value to its low byte.
+ cacheLine := 8 * ((bx >> 8) & 0xff)
+
+ // eax=4, ecx=i gets details about cache index i. Only supported on Intel.
+ var caches []Cache
+ if vendorID == intelVendorID {
+ // ecx selects the cache index until a null type is returned.
+ for i := uint32(0); ; i++ {
+ ax, bx, cx, dx := HostID(4, i)
+ t := CacheType(ax & 0xf)
+ if t == cacheNull {
+ break
+ }
+
+ lineSize := (bx & 0xfff) + 1
+ if lineSize != cacheLine {
+ panic(fmt.Sprintf("Mismatched cache line size: %d vs %d", lineSize, cacheLine))
+ }
+
+ caches = append(caches, Cache{
+ Type: t,
+ Level: (ax >> 5) & 0x7,
+ FullyAssociative: ((ax >> 9) & 1) == 1,
+ Partitions: ((bx >> 12) & 0x3ff) + 1,
+ Ways: ((bx >> 22) & 0x3ff) + 1,
+ Sets: cx + 1,
+ InvalidateHierarchical: (dx & 1) == 0,
+ Inclusive: ((dx >> 1) & 1) == 1,
+ DirectMapped: ((dx >> 2) & 1) == 0,
+ })
+ }
+ }
// eax=7, ecx=0 gets extended features in ecx:ebx.
_, bx, cx, _ = HostID(7, 0)
@@ -883,6 +1020,8 @@ func HostFeatureSet() *FeatureSet {
Family: f,
Model: m,
SteppingID: sid,
+ CacheLine: cacheLine,
+ Caches: caches,
}
}
diff --git a/pkg/cpuid/cpuid_test.go b/pkg/cpuid/cpuid_test.go
index 6ae14d2da..a707ebb55 100644
--- a/pkg/cpuid/cpuid_test.go
+++ b/pkg/cpuid/cpuid_test.go
@@ -57,24 +57,13 @@ var justFPUandPAE = &FeatureSet{
X86FeaturePAE: true,
}}
-func TestIsSubset(t *testing.T) {
- if !justFPU.IsSubset(justFPUandPAE) {
- t.Errorf("Got %v is not subset of %v, want IsSubset being true", justFPU, justFPUandPAE)
+func TestSubtract(t *testing.T) {
+ if diff := justFPU.Subtract(justFPUandPAE); diff != nil {
+ t.Errorf("Got %v is not subset of %v, want diff (%v) to be nil", justFPU, justFPUandPAE, diff)
}
- if justFPUandPAE.IsSubset(justFPU) {
- t.Errorf("Got %v is a subset of %v, want IsSubset being false", justFPU, justFPUandPAE)
- }
-}
-
-func TestTakeFeatureIntersection(t *testing.T) {
- testFeatures := HostFeatureSet()
- testFeatures.TakeFeatureIntersection(justFPU)
- if !testFeatures.IsSubset(justFPU) {
- t.Errorf("Got more features than expected after intersecting host features with justFPU: %v, want %v", testFeatures.Set, justFPU.Set)
- }
- if !testFeatures.HasFeature(X86FeatureFPU) {
- t.Errorf("Got no features in testFeatures after intersecting, want %v", X86FeatureFPU)
+ // Subtracting the smaller set from the larger one must leave a non-nil diff.
+ if justFPUandPAE.Subtract(justFPU) == nil {
+ t.Errorf("Got %v is a subset of %v, want diff to be non-nil", justFPUandPAE, justFPU)
}
}
@@ -83,7 +72,7 @@ func TestTakeFeatureIntersection(t *testing.T) {
// if HostFeatureSet gives back junk bits.
func TestHostFeatureSet(t *testing.T) {
hostFeatures := HostFeatureSet()
- if !justFPUandPAE.IsSubset(hostFeatures) {
+ if justFPUandPAE.Subtract(hostFeatures) != nil {
t.Errorf("Got invalid feature set %v from HostFeatureSet()", hostFeatures)
}
}
@@ -175,6 +164,7 @@ func TestEmulateIDBasicFeatures(t *testing.T) {
testFeatures := newEmptyFeatureSet()
testFeatures.Add(X86FeatureCLFSH)
testFeatures.Add(X86FeatureAVX)
+ testFeatures.CacheLine = 64
ax, bx, cx, dx := testFeatures.EmulateID(1, 0)
ECXAVXBit := uint32(1 << uint(X86FeatureAVX))
diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD
index 9961baaa9..71f2abc83 100644
--- a/pkg/eventchannel/BUILD
+++ b/pkg/eventchannel/BUILD
@@ -1,5 +1,6 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD
index 785c685a0..c7f549428 100644
--- a/pkg/fd/BUILD
+++ b/pkg/fd/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -7,6 +8,9 @@ go_library(
srcs = ["fd.go"],
importpath = "gvisor.dev/gvisor/pkg/fd",
visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/unet",
+ ],
)
go_test(
diff --git a/pkg/fd/fd.go b/pkg/fd/fd.go
index 83bcfe220..7691b477b 100644
--- a/pkg/fd/fd.go
+++ b/pkg/fd/fd.go
@@ -22,6 +22,8 @@ import (
"runtime"
"sync/atomic"
"syscall"
+
+ "gvisor.dev/gvisor/pkg/unet"
)
// ReadWriter implements io.ReadWriter, io.ReaderAt, and io.WriterAt for fd. It
@@ -185,6 +187,12 @@ func OpenAt(dir *FD, path string, flags int, mode uint32) (*FD, error) {
return New(f), nil
}
+// DialUnix connects to the Unix domain socket at path and returns its file
+// descriptor wrapped in an FD.
+//
+// If the connection fails, DialUnix returns a nil FD and the error from
+// unet.Connect.
+//
+// NOTE(review): the descriptor is obtained via socket.FD() while socket
+// itself is dropped; confirm that unet.Socket cannot close the descriptor
+// behind the returned FD.
+func DialUnix(path string) (*FD, error) {
+	socket, err := unet.Connect(path, false)
+	if err != nil {
+		// The socket must not be used when Connect fails (it is
+		// presumably nil), so bail out before calling socket.FD().
+		return nil, err
+	}
+	return New(socket.FD()), nil
+}
+
// Close closes the file descriptor contained in the FD.
//
// Close is safe to call multiple times, but will return an error after the
diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD
index e54e7371c..56495cbd9 100644
--- a/pkg/fdchannel/BUILD
+++ b/pkg/fdchannel/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD
index bd1d614b6..5643d5f26 100644
--- a/pkg/flipcall/BUILD
+++ b/pkg/flipcall/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -18,6 +19,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/log",
"//pkg/memutil",
+ "//third_party/gvsync",
],
)
diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go
index 865b6f640..8390915a2 100644
--- a/pkg/flipcall/ctrl_futex.go
+++ b/pkg/flipcall/ctrl_futex.go
@@ -82,6 +82,7 @@ func (ep *Endpoint) ctrlWaitFirst() error {
*ep.dataLen() = w.Len()
// Return control to the client.
+ raceBecomeInactive()
if err := ep.futexSwitchToPeer(); err != nil {
return err
}
@@ -121,7 +122,16 @@ func (ep *Endpoint) enterFutexWait() error {
}
+// exitFutexWait atomically subtracts epsBlocked from ep.ctrl.state. A
+// resulting state of 0 needs no further action; epsShutdown triggers
+// ep.shutdownConn(); any other resulting state panics.
func (ep *Endpoint) exitFutexWait() {
- atomic.AddInt32(&ep.ctrl.state, -epsBlocked)
+ switch eps := atomic.AddInt32(&ep.ctrl.state, -epsBlocked); eps {
+ case 0:
+ return
+ case epsShutdown:
+ // ep.ctrlShutdown() was called while we were blocked, so we are
+ // responsible for indicating connection shutdown.
+ ep.shutdownConn()
+ default:
+ panic(fmt.Sprintf("invalid flipcall.Endpoint.ctrl.state after flipcall.Endpoint.exitFutexWait(): %v", eps+epsBlocked))
+ }
}
func (ep *Endpoint) ctrlShutdown() {
@@ -142,5 +152,25 @@ func (ep *Endpoint) ctrlShutdown() {
break
}
}
+ } else {
+ // There is no blocked thread, so we are responsible for indicating
+ // connection shutdown.
+ ep.shutdownConn()
+ }
+}
+
+// shutdownConn atomically swaps the connection state to csShutdown and acts
+// on the state it replaced: if that was ep.activeState it calls
+// ep.futexWakeConnState(1), logging a warning on failure; ep.inactiveState
+// and csShutdown need no action; any other value is logged as unexpected.
+func (ep *Endpoint) shutdownConn() {
+ switch cs := atomic.SwapUint32(ep.connState(), csShutdown); cs {
+ case ep.activeState:
+ if err := ep.futexWakeConnState(1); err != nil {
+ log.Warningf("failed to FUTEX_WAKE peer Endpoint for shutdown: %v", err)
+ }
+ case ep.inactiveState:
+ // The peer is currently active and will detect shutdown when it tries
+ // to update the connection state.
+ case csShutdown:
+ // The peer also called Endpoint.Shutdown().
+ default:
+ log.Warningf("unexpected connection state before Endpoint.shutdownConn(): %v", cs)
}
}
diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go
index 5c9212c33..386cee42c 100644
--- a/pkg/flipcall/flipcall.go
+++ b/pkg/flipcall/flipcall.go
@@ -42,11 +42,6 @@ type Endpoint struct {
// dataCap is immutable.
dataCap uint32
- // shutdown is non-zero if Endpoint.Shutdown() has been called, or if the
- // Endpoint has acknowledged shutdown initiated by the peer. shutdown is
- // accessed using atomic memory operations.
- shutdown uint32
-
// activeState is csClientActive if this is a client Endpoint and
// csServerActive if this is a server Endpoint.
activeState uint32
@@ -55,9 +50,27 @@ type Endpoint struct {
// csClientActive if this is a server Endpoint.
inactiveState uint32
+ // shutdown is non-zero if Endpoint.Shutdown() has been called, or if the
+ // Endpoint has acknowledged shutdown initiated by the peer. shutdown is
+ // accessed using atomic memory operations.
+ shutdown uint32
+
ctrl endpointControlImpl
}
+// EndpointSide indicates which side of a connection an Endpoint belongs to.
+// Endpoint.Init uses it to choose the Endpoint's active and inactive
+// connection states and returns an error for any value other than
+// ClientSide or ServerSide.
+type EndpointSide int
+
+const (
+ // ClientSide indicates that an Endpoint is a client (initially-active;
+ // first method call should be Connect).
+ ClientSide EndpointSide = iota
+
+ // ServerSide indicates that an Endpoint is a server (initially-inactive;
+ // first method call should be RecvFirst).
+ ServerSide
+)
+
// Init must be called on zero-value Endpoints before first use. If it
// succeeds, ep.Destroy() must be called once the Endpoint is no longer in use.
//
@@ -65,7 +78,17 @@ type Endpoint struct {
// Endpoint. FD may differ between Endpoints if they are in different
// processes, but must represent the same file. The packet window must
// initially be filled with zero bytes.
-func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) error {
+func (ep *Endpoint) Init(side EndpointSide, pwd PacketWindowDescriptor, opts ...EndpointOption) error {
+ switch side {
+ case ClientSide:
+ ep.activeState = csClientActive
+ ep.inactiveState = csServerActive
+ case ServerSide:
+ ep.activeState = csServerActive
+ ep.inactiveState = csClientActive
+ default:
+ return fmt.Errorf("invalid EndpointSide: %v", side)
+ }
if pwd.Length < pageSize {
return fmt.Errorf("packet window size (%d) less than minimum (%d)", pwd.Length, pageSize)
}
@@ -78,9 +101,6 @@ func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) err
}
ep.packet = m
ep.dataCap = uint32(pwd.Length) - uint32(PacketHeaderBytes)
- // These will be overwritten by ep.Connect() for client Endpoints.
- ep.activeState = csServerActive
- ep.inactiveState = csClientActive
if err := ep.ctrlInit(opts...); err != nil {
ep.unmapPacket()
return err
@@ -90,9 +110,9 @@ func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) err
// NewEndpoint is a convenience function that returns an initialized Endpoint
// allocated on the heap.
-func NewEndpoint(pwd PacketWindowDescriptor, opts ...EndpointOption) (*Endpoint, error) {
+func NewEndpoint(side EndpointSide, pwd PacketWindowDescriptor, opts ...EndpointOption) (*Endpoint, error) {
var ep Endpoint
- if err := ep.Init(pwd, opts...); err != nil {
+ if err := ep.Init(side, pwd, opts...); err != nil {
return nil, err
}
return &ep, nil
@@ -115,9 +135,9 @@ func (ep *Endpoint) unmapPacket() {
}
// Shutdown causes concurrent and future calls to ep.Connect(), ep.SendRecv(),
-// ep.RecvFirst(), and ep.SendLast() to unblock and return errors. It does not
-// wait for concurrent calls to return. The effect of Shutdown on the peer
-// Endpoint is unspecified. Successive calls to Shutdown have no effect.
+// ep.RecvFirst(), and ep.SendLast(), as well as the same calls in the peer
+// Endpoint, to unblock and return errors. It does not wait for concurrent
+// calls to return. Successive calls to Shutdown have no effect.
//
// Shutdown is the only Endpoint method that may be called concurrently with
// other methods on the same Endpoint.
@@ -152,28 +172,31 @@ const (
// The client is, by definition, initially active, so this must be 0.
csClientActive = 0
csServerActive = 1
+ csShutdown = 2
)
-// Connect designates ep as a client Endpoint and blocks until the peer
-// Endpoint has called Endpoint.RecvFirst().
+// Connect blocks until the peer Endpoint has called Endpoint.RecvFirst().
//
-// Preconditions: ep.Connect(), ep.RecvFirst(), ep.SendRecv(), and
-// ep.SendLast() have never been called.
+// Preconditions: ep is a client Endpoint. ep.Connect(), ep.RecvFirst(),
+// ep.SendRecv(), and ep.SendLast() have never been called.
func (ep *Endpoint) Connect() error {
- ep.activeState = csClientActive
- ep.inactiveState = csServerActive
- return ep.ctrlConnect()
+ err := ep.ctrlConnect()
+ if err == nil {
+ raceBecomeActive()
+ }
+ return err
}
// RecvFirst blocks until the peer Endpoint calls Endpoint.SendRecv(), then
// returns the datagram length specified by that call.
//
-// Preconditions: ep.SendRecv(), ep.RecvFirst(), and ep.SendLast() have never
-// been called.
+// Preconditions: ep is a server Endpoint. ep.SendRecv(), ep.RecvFirst(), and
+// ep.SendLast() have never been called.
func (ep *Endpoint) RecvFirst() (uint32, error) {
if err := ep.ctrlWaitFirst(); err != nil {
return 0, err
}
+ raceBecomeActive()
recvDataLen := atomic.LoadUint32(ep.dataLen())
if recvDataLen > ep.dataCap {
return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap)
@@ -200,9 +223,11 @@ func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) {
// after ep.ctrlRoundTrip(), so if the peer is mutating it concurrently then
// they can only shoot themselves in the foot.
*ep.dataLen() = dataLen
+ raceBecomeInactive()
if err := ep.ctrlRoundTrip(); err != nil {
return 0, err
}
+ raceBecomeActive()
recvDataLen := atomic.LoadUint32(ep.dataLen())
if recvDataLen > ep.dataCap {
return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap)
@@ -222,6 +247,7 @@ func (ep *Endpoint) SendLast(dataLen uint32) error {
panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap))
}
*ep.dataLen() = dataLen
+ raceBecomeInactive()
if err := ep.ctrlWakeLast(); err != nil {
return err
}
diff --git a/pkg/flipcall/flipcall_example_test.go b/pkg/flipcall/flipcall_example_test.go
index edb6a8bef..8d88b845d 100644
--- a/pkg/flipcall/flipcall_example_test.go
+++ b/pkg/flipcall/flipcall_example_test.go
@@ -38,12 +38,12 @@ func Example() {
panic(err)
}
var clientEP Endpoint
- if err := clientEP.Init(pwd); err != nil {
+ if err := clientEP.Init(ClientSide, pwd); err != nil {
panic(err)
}
defer clientEP.Destroy()
var serverEP Endpoint
- if err := serverEP.Init(pwd); err != nil {
+ if err := serverEP.Init(ServerSide, pwd); err != nil {
panic(err)
}
defer serverEP.Destroy()
diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go
index da9d736ab..168a487ec 100644
--- a/pkg/flipcall/flipcall_test.go
+++ b/pkg/flipcall/flipcall_test.go
@@ -39,11 +39,11 @@ func newTestConnectionWithOptions(tb testing.TB, clientOpts, serverOpts []Endpoi
c.pwa.Destroy()
tb.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err)
}
- if err := c.clientEP.Init(pwd, clientOpts...); err != nil {
+ if err := c.clientEP.Init(ClientSide, pwd, clientOpts...); err != nil {
c.pwa.Destroy()
tb.Fatalf("failed to create client Endpoint: %v", err)
}
- if err := c.serverEP.Init(pwd, serverOpts...); err != nil {
+ if err := c.serverEP.Init(ServerSide, pwd, serverOpts...); err != nil {
c.pwa.Destroy()
c.clientEP.Destroy()
tb.Fatalf("failed to create server Endpoint: %v", err)
@@ -62,17 +62,30 @@ func (c *testConnection) destroy() {
}
func testSendRecv(t *testing.T, c *testConnection) {
+ // This shared variable is used to confirm that synchronization between
+ // flipcall endpoints is visible to the Go race detector.
+ state := 0
var serverRun sync.WaitGroup
serverRun.Add(1)
go func() {
defer serverRun.Done()
t.Logf("server Endpoint waiting for packet 1")
if _, err := c.serverEP.RecvFirst(); err != nil {
- t.Fatalf("server Endpoint.RecvFirst() failed: %v", err)
+ t.Errorf("server Endpoint.RecvFirst() failed: %v", err)
+ return
+ }
+ state++
+ if state != 2 {
+ t.Errorf("shared state counter: got %d, wanted 2", state)
}
t.Logf("server Endpoint got packet 1, sending packet 2 and waiting for packet 3")
if _, err := c.serverEP.SendRecv(0); err != nil {
- t.Fatalf("server Endpoint.SendRecv() failed: %v", err)
+ t.Errorf("server Endpoint.SendRecv() failed: %v", err)
+ return
+ }
+ state++
+ if state != 4 {
+ t.Errorf("shared state counter: got %d, wanted 4", state)
}
t.Logf("server Endpoint got packet 3")
}()
@@ -87,10 +100,18 @@ func testSendRecv(t *testing.T, c *testConnection) {
if err := c.clientEP.Connect(); err != nil {
t.Fatalf("client Endpoint.Connect() failed: %v", err)
}
+ state++
+ if state != 1 {
+ t.Errorf("shared state counter: got %d, wanted 1", state)
+ }
t.Logf("client Endpoint sending packet 1 and waiting for packet 2")
if _, err := c.clientEP.SendRecv(0); err != nil {
t.Fatalf("client Endpoint.SendRecv() failed: %v", err)
}
+ state++
+ if state != 3 {
+ t.Errorf("shared state counter: got %d, wanted 3", state)
+ }
t.Logf("client Endpoint got packet 2, sending packet 3")
if err := c.clientEP.SendLast(0); err != nil {
t.Fatalf("client Endpoint.SendLast() failed: %v", err)
@@ -105,7 +126,30 @@ func TestSendRecv(t *testing.T) {
testSendRecv(t, c)
}
-func testShutdownConnect(t *testing.T, c *testConnection) {
+func testShutdownBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
+ if remoteShutdown {
+ c.serverEP.Shutdown()
+ } else {
+ c.clientEP.Shutdown()
+ }
+ if err := c.clientEP.Connect(); err == nil {
+ t.Errorf("client Endpoint.Connect() succeeded unexpectedly")
+ }
+}
+
+func TestShutdownBeforeConnectLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownBeforeConnect(t, c, false)
+}
+
+func TestShutdownBeforeConnectRemote(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownBeforeConnect(t, c, true)
+}
+
+func testShutdownDuringConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
var clientRun sync.WaitGroup
clientRun.Add(1)
go func() {
@@ -115,44 +159,86 @@ func testShutdownConnect(t *testing.T, c *testConnection) {
}
}()
time.Sleep(time.Second) // to allow c.clientEP.Connect() to block
- c.clientEP.Shutdown()
+ if remoteShutdown {
+ c.serverEP.Shutdown()
+ } else {
+ c.clientEP.Shutdown()
+ }
clientRun.Wait()
}
-func TestShutdownConnect(t *testing.T) {
+func TestShutdownDuringConnectLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringConnect(t, c, false)
+}
+
+func TestShutdownDuringConnectRemote(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringConnect(t, c, true)
+}
+
+func testShutdownBeforeRecvFirst(t *testing.T, c *testConnection, remoteShutdown bool) {
+ if remoteShutdown {
+ c.clientEP.Shutdown()
+ } else {
+ c.serverEP.Shutdown()
+ }
+ if _, err := c.serverEP.RecvFirst(); err == nil {
+ t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
+ }
+}
+
+func TestShutdownBeforeRecvFirstLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownBeforeRecvFirst(t, c, false)
+}
+
+func TestShutdownBeforeRecvFirstRemote(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
- testShutdownConnect(t, c)
+ testShutdownBeforeRecvFirst(t, c, true)
}
-func testShutdownRecvFirstBeforeConnect(t *testing.T, c *testConnection) {
+func testShutdownDuringRecvFirstBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
var serverRun sync.WaitGroup
serverRun.Add(1)
go func() {
defer serverRun.Done()
- _, err := c.serverEP.RecvFirst()
- if err == nil {
+ if _, err := c.serverEP.RecvFirst(); err == nil {
t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
}
}()
time.Sleep(time.Second) // to allow c.serverEP.RecvFirst() to block
- c.serverEP.Shutdown()
+ if remoteShutdown {
+ c.clientEP.Shutdown()
+ } else {
+ c.serverEP.Shutdown()
+ }
serverRun.Wait()
}
-func TestShutdownRecvFirstBeforeConnect(t *testing.T) {
+func TestShutdownDuringRecvFirstBeforeConnectLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringRecvFirstBeforeConnect(t, c, false)
+}
+
+func TestShutdownDuringRecvFirstBeforeConnectRemote(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
- testShutdownRecvFirstBeforeConnect(t, c)
+ testShutdownDuringRecvFirstBeforeConnect(t, c, true)
}
-func testShutdownRecvFirstAfterConnect(t *testing.T, c *testConnection) {
+func testShutdownDuringRecvFirstAfterConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
var serverRun sync.WaitGroup
serverRun.Add(1)
go func() {
defer serverRun.Done()
if _, err := c.serverEP.RecvFirst(); err == nil {
- t.Fatalf("server Endpoint.RecvFirst() succeeded unexpectedly")
+ t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
}
}()
defer func() {
@@ -164,23 +250,75 @@ func testShutdownRecvFirstAfterConnect(t *testing.T, c *testConnection) {
if err := c.clientEP.Connect(); err != nil {
t.Fatalf("client Endpoint.Connect() failed: %v", err)
}
- c.serverEP.Shutdown()
+ if remoteShutdown {
+ c.clientEP.Shutdown()
+ } else {
+ c.serverEP.Shutdown()
+ }
serverRun.Wait()
}
-func TestShutdownRecvFirstAfterConnect(t *testing.T) {
+func TestShutdownDuringRecvFirstAfterConnectLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringRecvFirstAfterConnect(t, c, false)
+}
+
+func TestShutdownDuringRecvFirstAfterConnectRemote(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringRecvFirstAfterConnect(t, c, true)
+}
+
+func testShutdownDuringClientSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) {
+ var serverRun sync.WaitGroup
+ serverRun.Add(1)
+ go func() {
+ defer serverRun.Done()
+ if _, err := c.serverEP.RecvFirst(); err != nil {
+ t.Errorf("server Endpoint.RecvFirst() failed: %v", err)
+ }
+ // At this point, the client must be blocked in c.clientEP.SendRecv().
+ if remoteShutdown {
+ c.serverEP.Shutdown()
+ } else {
+ c.clientEP.Shutdown()
+ }
+ }()
+ defer func() {
+ // Ensure that the server goroutine is cleaned up before
+ // c.serverEP.Destroy(), even if the test fails.
+ c.serverEP.Shutdown()
+ serverRun.Wait()
+ }()
+ if err := c.clientEP.Connect(); err != nil {
+ t.Fatalf("client Endpoint.Connect() failed: %v", err)
+ }
+ if _, err := c.clientEP.SendRecv(0); err == nil {
+ t.Errorf("client Endpoint.SendRecv() succeeded unexpectedly")
+ }
+}
+
+func TestShutdownDuringClientSendRecvLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringClientSendRecv(t, c, false)
+}
+
+func TestShutdownDuringClientSendRecvRemote(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
- testShutdownRecvFirstAfterConnect(t, c)
+ testShutdownDuringClientSendRecv(t, c, true)
}
-func testShutdownSendRecv(t *testing.T, c *testConnection) {
+func testShutdownDuringServerSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) {
var serverRun sync.WaitGroup
serverRun.Add(1)
go func() {
defer serverRun.Done()
if _, err := c.serverEP.RecvFirst(); err != nil {
- t.Fatalf("server Endpoint.RecvFirst() failed: %v", err)
+ t.Errorf("server Endpoint.RecvFirst() failed: %v", err)
+ return
}
if _, err := c.serverEP.SendRecv(0); err == nil {
t.Errorf("server Endpoint.SendRecv() succeeded unexpectedly")
@@ -199,14 +337,24 @@ func testShutdownSendRecv(t *testing.T, c *testConnection) {
t.Fatalf("client Endpoint.SendRecv() failed: %v", err)
}
time.Sleep(time.Second) // to allow serverEP.SendRecv() to block
- c.serverEP.Shutdown()
+ if remoteShutdown {
+ c.clientEP.Shutdown()
+ } else {
+ c.serverEP.Shutdown()
+ }
serverRun.Wait()
}
-func TestShutdownSendRecv(t *testing.T) {
+func TestShutdownDuringServerSendRecvLocal(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
- testShutdownSendRecv(t, c)
+ testShutdownDuringServerSendRecv(t, c, false)
+}
+
+func TestShutdownDuringServerSendRecvRemote(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringServerSendRecv(t, c, true)
}
func benchmarkSendRecv(b *testing.B, c *testConnection) {
@@ -218,15 +366,17 @@ func benchmarkSendRecv(b *testing.B, c *testConnection) {
return
}
if _, err := c.serverEP.RecvFirst(); err != nil {
- b.Fatalf("server Endpoint.RecvFirst() failed: %v", err)
+ b.Errorf("server Endpoint.RecvFirst() failed: %v", err)
+ return
}
for i := 1; i < b.N; i++ {
if _, err := c.serverEP.SendRecv(0); err != nil {
- b.Fatalf("server Endpoint.SendRecv() failed: %v", err)
+ b.Errorf("server Endpoint.SendRecv() failed: %v", err)
+ return
}
}
if err := c.serverEP.SendLast(0); err != nil {
- b.Fatalf("server Endpoint.SendLast() failed: %v", err)
+ b.Errorf("server Endpoint.SendLast() failed: %v", err)
}
}()
defer func() {
diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go
index 7c8977893..a37952637 100644
--- a/pkg/flipcall/flipcall_unsafe.go
+++ b/pkg/flipcall/flipcall_unsafe.go
@@ -17,17 +17,19 @@ package flipcall
import (
"reflect"
"unsafe"
+
+ "gvisor.dev/gvisor/third_party/gvsync"
)
-// Packets consist of an 8-byte header followed by an arbitrarily-sized
+// Packets consist of a 16-byte header followed by an arbitrarily-sized
// datagram. The header consists of:
//
// - A 4-byte native-endian connection state.
//
// - A 4-byte native-endian datagram length in bytes.
+//
+// - 8 reserved bytes.
const (
- sizeofUint32 = unsafe.Sizeof(uint32(0))
-
// PacketHeaderBytes is the size of a flipcall packet header in bytes. The
// maximum datagram size supported by a flipcall connection is equal to the
// length of the packet window minus PacketHeaderBytes.
@@ -35,7 +37,7 @@ const (
// PacketHeaderBytes is exported to support its use in constant
// expressions. Non-constant expressions may prefer to use
// PacketWindowLengthForDataCap().
- PacketHeaderBytes = 2 * sizeofUint32
+ PacketHeaderBytes = 16
)
func (ep *Endpoint) connState() *uint32 {
@@ -43,7 +45,7 @@ func (ep *Endpoint) connState() *uint32 {
}
func (ep *Endpoint) dataLen() *uint32 {
- return (*uint32)((unsafe.Pointer)(ep.packet + sizeofUint32))
+ return (*uint32)((unsafe.Pointer)(ep.packet + 4))
}
// Data returns the datagram part of ep's packet window as a byte slice.
@@ -67,3 +69,19 @@ func (ep *Endpoint) Data() []byte {
bsReflect.Cap = int(ep.dataCap)
return bs
}
+
+// ioSync is a dummy variable used to indicate synchronization to the Go race
+// detector. Compare syscall.ioSync.
+var ioSync int64
+
+func raceBecomeActive() {
+ if gvsync.RaceEnabled {
+ gvsync.RaceAcquire((unsafe.Pointer)(&ioSync))
+ }
+}
+
+func raceBecomeInactive() {
+ if gvsync.RaceEnabled {
+ gvsync.RaceReleaseMerge((unsafe.Pointer)(&ioSync))
+ }
+}
diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go
index e7dd812b3..b127a2bbb 100644
--- a/pkg/flipcall/futex_linux.go
+++ b/pkg/flipcall/futex_linux.go
@@ -59,7 +59,12 @@ func (ep *Endpoint) futexConnect(req *ctrlHandshakeRequest) (ctrlHandshakeRespon
func (ep *Endpoint) futexSwitchToPeer() error {
// Update connection state to indicate that the peer should be active.
if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) {
- return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", atomic.LoadUint32(ep.connState()))
+ switch cs := atomic.LoadUint32(ep.connState()); cs {
+ case csShutdown:
+ return shutdownError{}
+ default:
+ return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs)
+ }
}
// Wake the peer's Endpoint.futexSwitchFromPeer().
@@ -75,16 +80,18 @@ func (ep *Endpoint) futexSwitchFromPeer() error {
case ep.activeState:
return nil
case ep.inactiveState:
- // Continue to FUTEX_WAIT.
+ if ep.isShutdownLocally() {
+ return shutdownError{}
+ }
+ if err := ep.futexWaitConnState(ep.inactiveState); err != nil {
+ return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err)
+ }
+ continue
+ case csShutdown:
+ return shutdownError{}
default:
return fmt.Errorf("unexpected connection state before FUTEX_WAIT: %v", cs)
}
- if ep.isShutdownLocally() {
- return shutdownError{}
- }
- if err := ep.futexWaitConnState(ep.inactiveState); err != nil {
- return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err)
- }
}
}
diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD
index 11716af81..0c5f50397 100644
--- a/pkg/fspath/BUILD
+++ b/pkg/fspath/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(
default_visibility = ["//visibility:public"],
diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD
index e6a8dbd02..4b9321711 100644
--- a/pkg/gate/BUILD
+++ b/pkg/gate/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/ilist/BUILD b/pkg/ilist/BUILD
index 8f3defa25..34d2673ef 100644
--- a/pkg/ilist/BUILD
+++ b/pkg/ilist/BUILD
@@ -1,5 +1,6 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD
index c8e923a74..a5d980d14 100644
--- a/pkg/linewriter/BUILD
+++ b/pkg/linewriter/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/log/BUILD b/pkg/log/BUILD
index 12615240c..fc5f5779b 100644
--- a/pkg/log/BUILD
+++ b/pkg/log/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD
index 3b8a691f4..dd6ca6d39 100644
--- a/pkg/metric/BUILD
+++ b/pkg/metric/BUILD
@@ -1,5 +1,7 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -21,6 +23,12 @@ proto_library(
visibility = ["//:sandbox"],
)
+cc_proto_library(
+ name = "metric_cc_proto",
+ visibility = ["//:sandbox"],
+ deps = [":metric_proto"],
+)
+
go_proto_library(
name = "metric_go_proto",
importpath = "gvisor.dev/gvisor/pkg/metric/metric_go_proto",
diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD
index c6737bf97..f32244c69 100644
--- a/pkg/p9/BUILD
+++ b/pkg/p9/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(
default_visibility = ["//visibility:public"],
@@ -19,11 +20,14 @@ go_library(
"pool.go",
"server.go",
"transport.go",
+ "transport_flipcall.go",
"version.go",
],
importpath = "gvisor.dev/gvisor/pkg/p9",
deps = [
"//pkg/fd",
+ "//pkg/fdchannel",
+ "//pkg/flipcall",
"//pkg/log",
"//pkg/unet",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 7dc20aeef..221516c6c 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -20,6 +20,8 @@ import (
"sync"
"syscall"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/flipcall"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -77,6 +79,47 @@ type Client struct {
// fidPool is the collection of available fids.
fidPool pool
+ // messageSize is the maximum total size of a message.
+ messageSize uint32
+
+ // payloadSize is the maximum payload size of a read or write.
+ //
+ // For large reads and writes this means that the read or write is
+ // broken up into buffer-size/payloadSize requests.
+ payloadSize uint32
+
+ // version is the agreed upon version X of 9P2000.L.Google.X.
+ // version 0 implies 9P2000.L.
+ version uint32
+
+ // closedWg is marked as done when the Client.watch() goroutine, which is
+ // responsible for closing channels and the socket fd, returns.
+ closedWg sync.WaitGroup
+
+ // sendRecv is the transport function.
+ //
+ // This is determined dynamically based on whether or not the server
+ // supports flipcall channels (preferred as it is faster and more
+ // efficient, and does not require tags).
+ sendRecv func(message, message) error
+
+ // -- below corresponds to sendRecvChannel --
+
+ // channelsMu protects channels.
+ channelsMu sync.Mutex
+
+ // channelsWg counts the number of channels for which channel.active ==
+ // true.
+ channelsWg sync.WaitGroup
+
+ // channels is the set of all initialized channels.
+ channels []*channel
+
+ // availableChannels is a FIFO of inactive channels.
+ availableChannels []*channel
+
+ // -- below corresponds to sendRecvLegacy --
+
// pending is the set of pending messages.
pending map[Tag]*response
pendingMu sync.Mutex
@@ -89,25 +132,12 @@ type Client struct {
// Whoever writes to this channel is permitted to call recv. When
// finished calling recv, this channel should be emptied.
recvr chan bool
-
- // messageSize is the maximum total size of a message.
- messageSize uint32
-
- // payloadSize is the maximum payload size of a read or write
- // request. For large reads and writes this means that the
- // read or write is broken up into buffer-size/payloadSize
- // requests.
- payloadSize uint32
-
- // version is the agreed upon version X of 9P2000.L.Google.X.
- // version 0 implies 9P2000.L.
- version uint32
}
// NewClient creates a new client. It performs a Tversion exchange with
// the server to assert that messageSize is ok to use.
//
-// You should not use the same socket for multiple clients.
+// If NewClient succeeds, ownership of socket is transferred to the new Client.
func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) {
// Need at least one byte of payload.
if messageSize <= msgRegistry.largestFixedSize {
@@ -138,8 +168,15 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
return nil, ErrBadVersionString
}
for {
+ // Always exchange the version using the legacy version of the
+ // protocol. If the protocol supports flipcall, then we switch
+ // our sendRecv function to use that functionality. Otherwise,
+ // we stick to sendRecvLegacy.
rversion := Rversion{}
- err := c.sendRecv(&Tversion{Version: versionString(requested), MSize: messageSize}, &rversion)
+ err := c.sendRecvLegacy(&Tversion{
+ Version: versionString(requested),
+ MSize: messageSize,
+ }, &rversion)
// The server told us to try again with a lower version.
if err == syscall.EAGAIN {
@@ -165,9 +202,155 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
c.version = version
break
}
+
+ // Can we switch to use the more advanced channels and create
+ // independent channels for communication? Prefer it if possible.
+ if versionSupportsFlipcall(c.version) {
+ // Attempt to initialize IPC-based communication.
+ for i := 0; i < channelsPerClient; i++ {
+ if err := c.openChannel(i); err != nil {
+ log.Warningf("error opening flipcall channel: %v", err)
+ break // Stop.
+ }
+ }
+ if len(c.channels) >= 1 {
+ // At least one channel created.
+ c.sendRecv = c.sendRecvChannel
+ } else {
+ // Channel setup failed; fallback.
+ c.sendRecv = c.sendRecvLegacy
+ }
+ } else {
+ // No channels available: use the legacy mechanism.
+ c.sendRecv = c.sendRecvLegacy
+ }
+
+ // Ensure that the socket and channels are closed when the socket is shut
+ // down.
+ c.closedWg.Add(1)
+ go c.watch(socket) // S/R-SAFE: not relevant.
+
return c, nil
}
+// watch watches the given socket and releases resources on hangup events.
+//
+// This is intended to be called as a goroutine.
+func (c *Client) watch(socket *unet.Socket) {
+ defer c.closedWg.Done()
+
+ events := []unix.PollFd{
+ unix.PollFd{
+ Fd: int32(socket.FD()),
+ Events: unix.POLLHUP | unix.POLLRDHUP,
+ },
+ }
+
+ // Wait for a shutdown event.
+ for {
+ n, err := unix.Ppoll(events, nil, nil)
+ if err == syscall.EINTR || err == syscall.EAGAIN {
+ continue
+ }
+ if err != nil {
+ log.Warningf("p9.Client.watch(): %v", err)
+ break
+ }
+ if n != 1 {
+ log.Warningf("p9.Client.watch(): got %d events, wanted 1", n)
+ }
+ break
+ }
+
+ // Set availableChannels to nil so that future calls to c.sendRecvChannel()
+ // don't attempt to activate a channel, and concurrent calls to
+ // c.sendRecvChannel() don't mark released channels as available.
+ c.channelsMu.Lock()
+ c.availableChannels = nil
+
+ // Shut down all active channels.
+ for _, ch := range c.channels {
+ if ch.active {
+ log.Debugf("shutting down active channel@%p...", ch)
+ ch.Shutdown()
+ }
+ }
+ c.channelsMu.Unlock()
+
+ // Wait for active channels to become inactive.
+ c.channelsWg.Wait()
+
+ // Close all channels.
+ c.channelsMu.Lock()
+ for _, ch := range c.channels {
+ ch.Close()
+ }
+ c.channelsMu.Unlock()
+
+ // Close the main socket.
+ c.socket.Close()
+}
+
+// openChannel attempts to open a client channel.
+//
+// Note that this function returns naked errors which should not be propagated
+// directly to a caller. It is expected that the errors will be logged and a
+// fallback path will be used instead.
+func (c *Client) openChannel(id int) error {
+ var (
+ rchannel0 Rchannel
+ rchannel1 Rchannel
+ res = new(channel)
+ )
+
+ // Open the data channel.
+ if err := c.sendRecvLegacy(&Tchannel{
+ ID: uint32(id),
+ Control: 0,
+ }, &rchannel0); err != nil {
+ return fmt.Errorf("error handling Tchannel message: %v", err)
+ }
+ if rchannel0.FilePayload() == nil {
+ return fmt.Errorf("missing file descriptor on primary channel")
+ }
+
+ // We don't need to hold this.
+ defer rchannel0.FilePayload().Close()
+
+ // Open the channel for file descriptors.
+ if err := c.sendRecvLegacy(&Tchannel{
+ ID: uint32(id),
+ Control: 1,
+ }, &rchannel1); err != nil {
+ return err
+ }
+ if rchannel1.FilePayload() == nil {
+ return fmt.Errorf("missing file descriptor on file descriptor channel")
+ }
+
+ // Construct the endpoints.
+ res.desc = flipcall.PacketWindowDescriptor{
+ FD: rchannel0.FilePayload().FD(),
+ Offset: int64(rchannel0.Offset),
+ Length: int(rchannel0.Length),
+ }
+ if err := res.data.Init(flipcall.ClientSide, res.desc); err != nil {
+ rchannel1.FilePayload().Close()
+ return err
+ }
+
+ // The fds channel owns the control payload, and it will be closed when
+ // the channel object is closed.
+ res.fds.Init(rchannel1.FilePayload().Release())
+
+ // Save the channel.
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+ c.channels = append(c.channels, res)
+ c.availableChannels = append(c.availableChannels, res)
+ return nil
+}
+
// handleOne handles a single incoming message.
//
// This should only be called with the token from recvr. Note that the received
@@ -247,10 +430,10 @@ func (c *Client) waitAndRecv(done chan error) error {
}
}
-// sendRecv performs a roundtrip message exchange.
+// sendRecvLegacy performs a roundtrip message exchange.
//
// This is called by internal functions.
-func (c *Client) sendRecv(t message, r message) error {
+func (c *Client) sendRecvLegacy(t message, r message) error {
tag, ok := c.tagPool.Get()
if !ok {
return ErrOutOfTags
@@ -296,12 +479,77 @@ func (c *Client) sendRecv(t message, r message) error {
return nil
}
+// sendRecvChannel uses channels to send a message.
+func (c *Client) sendRecvChannel(t message, r message) error {
+ // Acquire an available channel.
+ c.channelsMu.Lock()
+ if len(c.availableChannels) == 0 {
+ c.channelsMu.Unlock()
+ return c.sendRecvLegacy(t, r)
+ }
+ idx := len(c.availableChannels) - 1
+ ch := c.availableChannels[idx]
+ c.availableChannels = c.availableChannels[:idx]
+ ch.active = true
+ c.channelsWg.Add(1)
+ c.channelsMu.Unlock()
+
+ // Ensure that it's connected.
+ if !ch.connected {
+ ch.connected = true
+ if err := ch.data.Connect(); err != nil {
+ // The channel is unusable, so don't return it to
+ // c.availableChannels. However, we still have to mark it as
+ // inactive so c.watch() doesn't wait for it.
+ c.channelsMu.Lock()
+ ch.active = false
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+ // Map all transport errors to EIO, but ensure that the real error
+ // is logged.
+ log.Warningf("p9.Client.sendRecvChannel: flipcall.Endpoint.Connect: %v", err)
+ return syscall.EIO
+ }
+ }
+
+ // Send the request and receive the server's response.
+ rsz, err := ch.send(t)
+ if err != nil {
+ // See above.
+ c.channelsMu.Lock()
+ ch.active = false
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+ log.Warningf("p9.Client.sendRecvChannel: p9.channel.send: %v", err)
+ return syscall.EIO
+ }
+
+ // Parse the server's response.
+ _, retErr := ch.recv(r, rsz)
+
+ // Release the channel.
+ c.channelsMu.Lock()
+ ch.active = false
+ // If c.availableChannels is nil, c.watch() has fired and we should not
+ // mark this channel as available.
+ if c.availableChannels != nil {
+ c.availableChannels = append(c.availableChannels, ch)
+ }
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+
+ return retErr
+}
+
// Version returns the negotiated 9P2000.L.Google version number.
func (c *Client) Version() uint32 {
return c.version
}
-// Close closes the underlying socket.
-func (c *Client) Close() error {
- return c.socket.Close()
+// Close closes the underlying socket and channels.
+func (c *Client) Close() {
+ // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already
+ // been called (by c.watch()).
+ c.socket.Shutdown()
+ c.closedWg.Wait()
}
diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go
index 87b2dd61e..29a0afadf 100644
--- a/pkg/p9/client_test.go
+++ b/pkg/p9/client_test.go
@@ -35,23 +35,23 @@ func TestVersion(t *testing.T) {
go s.Handle(serverSocket)
// NewClient does a Tversion exchange, so this is our test for success.
- c, err := NewClient(clientSocket, 1024*1024 /* 1M message size */, HighestVersionString())
+ c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString())
if err != nil {
t.Fatalf("got %v, expected nil", err)
}
// Check a bogus version string.
- if err := c.sendRecv(&Tversion{Version: "notokay", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL {
+ if err := c.sendRecv(&Tversion{Version: "notokay", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL {
t.Errorf("got %v expected %v", err, syscall.EINVAL)
}
// Check a bogus version number.
- if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL {
+ if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL {
t.Errorf("got %v expected %v", err, syscall.EINVAL)
}
// Check a too high version number.
- if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: 1024 * 1024}, &Rversion{}); err != syscall.EAGAIN {
+ if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EAGAIN {
t.Errorf("got %v expected %v", err, syscall.EAGAIN)
}
@@ -60,3 +60,45 @@ func TestVersion(t *testing.T) {
t.Errorf("got %v expected %v", err, syscall.EINVAL)
}
}
+
+func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) error) {
+ // See above.
+ serverSocket, clientSocket, err := unet.SocketPair(false)
+ if err != nil {
+ b.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer clientSocket.Close()
+
+ // See above.
+ s := NewServer(nil)
+ go s.Handle(serverSocket)
+
+ // See above.
+ c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString())
+ if err != nil {
+ b.Fatalf("got %v, expected nil", err)
+ }
+
+ // Initialize messages.
+ sendRecv := fn(c)
+ tversion := &Tversion{
+ Version: versionString(highestSupportedVersion),
+ MSize: DefaultMessageSize,
+ }
+ rversion := new(Rversion)
+
+ // Run in a loop.
+ for i := 0; i < b.N; i++ {
+ if err := sendRecv(tversion, rversion); err != nil {
+ b.Fatalf("got unexpected err: %v", err)
+ }
+ }
+}
+
+func BenchmarkSendRecvLegacy(b *testing.B) {
+ benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvLegacy })
+}
+
+func BenchmarkSendRecvChannel(b *testing.B) {
+ benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvChannel })
+}
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index 999b4f684..ba9a55d6d 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -305,7 +305,9 @@ func (t *Tlopen) handle(cs *connState) message {
ref.opened = true
ref.openFlags = t.Flags
- return &Rlopen{QID: qid, IoUnit: ioUnit, File: osFile}
+ rlopen := &Rlopen{QID: qid, IoUnit: ioUnit}
+ rlopen.SetFilePayload(osFile)
+ return rlopen
}
func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) {
@@ -364,7 +366,9 @@ func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) {
// Replace the FID reference.
cs.InsertFID(t.FID, newRef)
- return &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit, File: osFile}}, nil
+ rlcreate := &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit}}
+ rlcreate.SetFilePayload(osFile)
+ return rlcreate, nil
}
// handle implements handler.handle.
@@ -1287,5 +1291,48 @@ func (t *Tlconnect) handle(cs *connState) message {
return newErr(err)
}
- return &Rlconnect{File: osFile}
+ rlconnect := &Rlconnect{}
+ rlconnect.SetFilePayload(osFile)
+ return rlconnect
+}
+
+// handle implements handler.handle.
+func (t *Tchannel) handle(cs *connState) message {
+ // Ensure that channels are enabled.
+ if err := cs.initializeChannels(); err != nil {
+ return newErr(err)
+ }
+
+ // Lookup the given channel.
+ ch := cs.lookupChannel(t.ID)
+ if ch == nil {
+ return newErr(syscall.ENOSYS)
+ }
+
+ // Return the payload. Note that we need to duplicate the file
+ // descriptor for the channel allocator, because sending is a
+ // destructive operation between sendRecvLegacy (and now the newer
+ // channel send operations). Same goes for the client FD.
+ rchannel := &Rchannel{
+ Offset: uint64(ch.desc.Offset),
+ Length: uint64(ch.desc.Length),
+ }
+ switch t.Control {
+ case 0:
+ // Open the main data channel.
+ mfd, err := syscall.Dup(int(cs.channelAlloc.FD()))
+ if err != nil {
+ return newErr(err)
+ }
+ rchannel.SetFilePayload(fd.New(mfd))
+ case 1:
+ cfd, err := syscall.Dup(ch.client.FD())
+ if err != nil {
+ return newErr(err)
+ }
+ rchannel.SetFilePayload(fd.New(cfd))
+ default:
+ return newErr(syscall.EINVAL)
+ }
+ return rchannel
}
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index fd9eb1c5d..ffdd7e8c6 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -64,6 +64,21 @@ type filer interface {
SetFilePayload(*fd.FD)
}
+// filePayload embeds a File object.
+type filePayload struct {
+ File *fd.FD
+}
+
+// FilePayload returns the file payload.
+func (f *filePayload) FilePayload() *fd.FD {
+ return f.File
+}
+
+// SetFilePayload sets the received file.
+func (f *filePayload) SetFilePayload(file *fd.FD) {
+ f.File = file
+}
+
// Tversion is a version request.
type Tversion struct {
// MSize is the message size to use.
@@ -524,10 +539,7 @@ type Rlopen struct {
// IoUnit is the recommended I/O unit.
IoUnit uint32
- // File may be attached via the socket.
- //
- // This is an extension specific to this package.
- File *fd.FD
+ filePayload
}
// Decode implements encoder.Decode.
@@ -547,16 +559,6 @@ func (*Rlopen) Type() MsgType {
return MsgRlopen
}
-// FilePayload returns the file payload.
-func (r *Rlopen) FilePayload() *fd.FD {
- return r.File
-}
-
-// SetFilePayload sets the received file.
-func (r *Rlopen) SetFilePayload(file *fd.FD) {
- r.File = file
-}
-
// String implements fmt.Stringer.
func (r *Rlopen) String() string {
return fmt.Sprintf("Rlopen{QID: %s, IoUnit: %d, File: %v}", r.QID, r.IoUnit, r.File)
@@ -2171,8 +2173,7 @@ func (t *Tlconnect) String() string {
// Rlconnect is a connect response.
type Rlconnect struct {
- // File is a host socket.
- File *fd.FD
+ filePayload
}
// Decode implements encoder.Decode.
@@ -2186,19 +2187,71 @@ func (*Rlconnect) Type() MsgType {
return MsgRlconnect
}
-// FilePayload returns the file payload.
-func (r *Rlconnect) FilePayload() *fd.FD {
- return r.File
+// String implements fmt.Stringer.
+func (r *Rlconnect) String() string {
+ return fmt.Sprintf("Rlconnect{File: %v}", r.File)
}
-// SetFilePayload sets the received file.
-func (r *Rlconnect) SetFilePayload(file *fd.FD) {
- r.File = file
+// Tchannel creates a new channel.
+type Tchannel struct {
+ // ID is the channel ID.
+ ID uint32
+
+ // Control is 0 if the Rchannel response should provide the flipcall
+ // component of the channel, and 1 if the Rchannel response should
+ // provide the fdchannel component of the channel.
+ Control uint32
+}
+
+// Decode implements encoder.Decode.
+func (t *Tchannel) Decode(b *buffer) {
+ t.ID = b.Read32()
+ t.Control = b.Read32()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tchannel) Encode(b *buffer) {
+ b.Write32(t.ID)
+ b.Write32(t.Control)
+}
+
+// Type implements message.Type.
+func (*Tchannel) Type() MsgType {
+ return MsgTchannel
}
// String implements fmt.Stringer.
-func (r *Rlconnect) String() string {
- return fmt.Sprintf("Rlconnect{File: %v}", r.File)
+func (t *Tchannel) String() string {
+ return fmt.Sprintf("Tchannel{ID: %d, Control: %d}", t.ID, t.Control)
+}
+
+// Rchannel is the channel response.
+//
+// Only Offset and Length are part of the encoded message body; the embedded
+// filePayload is not written or read by Encode/Decode.
+type Rchannel struct {
+ // Offset and Length are encoded as 64-bit values, in that order.
+ //
+ // NOTE(review): presumably these locate the flipcall packet window
+ // inside the attached file -- confirm against the Tchannel handler.
+ Offset uint64
+ Length uint64
+
+ // filePayload carries the file attached to this response, if any.
+ filePayload
+}
+
+// Decode implements encoder.Decode.
+//
+// Fields are read in wire order: Offset, then Length (64 bits each). The
+// file payload is not decoded here.
+func (r *Rchannel) Decode(b *buffer) {
+ r.Offset = b.Read64()
+ r.Length = b.Read64()
+}
+
+// Encode implements encoder.Encode.
+//
+// Fields are written in wire order: Offset, then Length (64 bits each). The
+// file payload is not encoded here.
+func (r *Rchannel) Encode(b *buffer) {
+ b.Write64(r.Offset)
+ b.Write64(r.Length)
+}
+
+// Type implements message.Type.
+func (*Rchannel) Type() MsgType {
+ return MsgRchannel
+}
+
+// String implements fmt.Stringer.
+func (r *Rchannel) String() string {
+ return fmt.Sprintf("Rchannel{Offset: %d, Length: %d}", r.Offset, r.Length)
}
const maxCacheSize = 3
@@ -2356,4 +2409,6 @@ func init() {
msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} })
msgRegistry.register(MsgTallocate, func() message { return &Tallocate{} })
msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} })
+ msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} })
+ msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} })
}
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
index e12831dbd..25530adca 100644
--- a/pkg/p9/p9.go
+++ b/pkg/p9/p9.go
@@ -378,6 +378,8 @@ const (
MsgRlconnect = 137
MsgTallocate = 138
MsgRallocate = 139
+ MsgTchannel = 250
+ MsgRchannel = 251
)
// QIDType represents the file type for QIDs.
diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD
index 6e939a49a..28707c0ca 100644
--- a/pkg/p9/p9test/BUILD
+++ b/pkg/p9/p9test/BUILD
@@ -1,5 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test")
package(licenses = ["notice"])
@@ -77,7 +77,7 @@ go_library(
go_test(
name = "client_test",
- size = "small",
+ size = "medium",
srcs = ["client_test.go"],
embed = [":p9test"],
deps = [
diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go
index fe649c2e8..8bbdb2488 100644
--- a/pkg/p9/p9test/client_test.go
+++ b/pkg/p9/p9test/client_test.go
@@ -2127,3 +2127,98 @@ func TestConcurrency(t *testing.T) {
}
}
}
+
+// TestReadWriteConcurrent opens several files and, from one goroutine per
+// file, issues a random interleaving of reads and writes at offset zero,
+// checking that the data seen by the client and by the mock backend match.
+func TestReadWriteConcurrent(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ const (
+ instances = 10
+ iterations = 10000
+ dataSize = 1024
+ )
+ var (
+ dataSets [instances][dataSize]byte
+ backends [instances]*Mock
+ files [instances]p9.File
+ )
+
+ // Walk to the file normally.
+ for i := 0; i < instances; i++ {
+ _, backends[i], files[i] = walkHelper(h, "file", root)
+ defer files[i].Close()
+ }
+
+ // Open the files.
+ for i := 0; i < instances; i++ {
+ backends[i].EXPECT().Open(p9.ReadWrite)
+ if _, _, _, err := files[i].Open(p9.ReadWrite); err != nil {
+ t.Fatalf("open got %v, wanted nil", err)
+ }
+ }
+
+ // Initialize random data for each instance.
+ for i := 0; i < instances; i++ {
+ if _, err := rand.Read(dataSets[i][:]); err != nil {
+ t.Fatalf("error initializing dataSet#%d, got %v", i, err)
+ }
+ }
+
+ // Define our random read/write mechanism.
+ //
+ // Only Errorf (never Fatalf) is used below, since these helpers run on
+ // goroutines other than the one running the test function.
+ randRead := func(h *Harness, backend *Mock, f p9.File, data, test []byte) {
+ // Prepare the backend.
+ backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if n := copy(p, data); n != len(data) {
+ // Note that we have to assert the result here, as the Return statement
+ // below cannot be dynamic: it will be bound before this call is made.
+ h.t.Errorf("wanted length %d, got %d", len(data), n)
+ }
+ }).Return(len(data), nil)
+
+ // Execute the read.
+ if n, err := f.ReadAt(test, 0); n != len(test) || err != nil {
+ t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(test), n, err)
+ return // No sense doing check below.
+ }
+ if !bytes.Equal(test, data) {
+ t.Errorf("data integrity failed during read") // Not as expected.
+ }
+ }
+ randWrite := func(h *Harness, backend *Mock, f p9.File, data []byte) {
+ // Prepare the backend.
+ backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if !bytes.Equal(p, data) {
+ h.t.Errorf("data integrity failed during write") // Not as expected.
+ }
+ }).Return(len(data), nil)
+
+ // Execute the write.
+ if n, err := f.WriteAt(data, 0); n != len(data) || err != nil {
+ t.Errorf("failed write: wanted (%d, nil), got (%d, %v)", len(data), n, err)
+ }
+ }
+ randReadWrite := func(n int, h *Harness, backend *Mock, f p9.File, data []byte) {
+ test := make([]byte, len(data))
+ for i := 0; i < n; i++ {
+ if rand.Intn(2) == 0 {
+ randRead(h, backend, f, data, test)
+ } else {
+ randWrite(h, backend, f, data)
+ }
+ }
+ }
+
+ // Start reading and writing.
+ var wg sync.WaitGroup
+ for i := 0; i < instances; i++ {
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+ randReadWrite(iterations, h, backends[i], files[i], dataSets[i][:])
+ }(i)
+ }
+ wg.Wait()
+}
diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go
index 95846e5f7..4d3271b37 100644
--- a/pkg/p9/p9test/p9test.go
+++ b/pkg/p9/p9test/p9test.go
@@ -279,7 +279,7 @@ func (h *Harness) NewSocket() Generator {
// Finish completes all checks and shuts down the server.
func (h *Harness) Finish() {
- h.clientSocket.Close()
+ h.clientSocket.Shutdown()
h.wg.Wait()
h.mockCtrl.Finish()
}
@@ -315,7 +315,7 @@ func NewHarness(t *testing.T) (*Harness, *p9.Client) {
}()
// Create the client.
- client, err := p9.NewClient(clientSocket, 1024, p9.HighestVersionString())
+ client, err := p9.NewClient(clientSocket, p9.DefaultMessageSize, p9.HighestVersionString())
if err != nil {
serverSocket.Close()
clientSocket.Close()
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
index b294efbb0..e717e6161 100644
--- a/pkg/p9/server.go
+++ b/pkg/p9/server.go
@@ -21,6 +21,9 @@ import (
"sync/atomic"
"syscall"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/fdchannel"
+ "gvisor.dev/gvisor/pkg/flipcall"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -45,7 +48,6 @@ type Server struct {
}
// NewServer returns a new server.
-//
func NewServer(attacher Attacher) *Server {
return &Server{
attacher: attacher,
@@ -85,6 +87,8 @@ type connState struct {
// version 0 implies 9P2000.L.
version uint32
+ // -- below relates to the legacy handler --
+
// recvOkay indicates that a receive may start.
recvOkay chan bool
@@ -93,6 +97,20 @@ type connState struct {
// sendDone is signalled when a send is finished.
sendDone chan error
+
+ // -- below relates to the flipcall handler --
+
+ // channelMu protects below.
+ channelMu sync.Mutex
+
+ // channelWg represents active workers.
+ channelWg sync.WaitGroup
+
+ // channelAlloc allocates channel memory.
+ channelAlloc *flipcall.PacketWindowAllocator
+
+ // channels are the set of initialized channels.
+ channels []*channel
}
// fidRef wraps a node and tracks references.
@@ -386,6 +404,101 @@ func (cs *connState) WaitTag(t Tag) {
<-ch
}
+// initializeChannels initializes all channels.
+//
+// This is a no-op if channels are already initialized.
+//
+// The allocator is created on first use, and channels are added until there
+// are channelsPerClient of them. If an iteration fails, the error is
+// returned immediately; channels appended by earlier iterations stay in
+// cs.channels and continue to be serviced.
+func (cs *connState) initializeChannels() (err error) {
+ cs.channelMu.Lock()
+ defer cs.channelMu.Unlock()
+
+ // Initialize our channel allocator.
+ if cs.channelAlloc == nil {
+ alloc, err := flipcall.NewPacketWindowAllocator()
+ if err != nil {
+ return err
+ }
+ cs.channelAlloc = alloc
+ }
+
+ // Create all the channels.
+ for len(cs.channels) < channelsPerClient {
+ res := &channel{
+ done: make(chan struct{}),
+ }
+
+ // This assigns the named result err; the other err variables in
+ // this function are block-scoped shadows declared with :=.
+ res.desc, err = cs.channelAlloc.Allocate(channelSize)
+ if err != nil {
+ return err
+ }
+ if err := res.data.Init(flipcall.ServerSide, res.desc); err != nil {
+ return err
+ }
+
+ // socks[0] backs the server's fdchannel endpoint; socks[1] is kept
+ // as res.client (closed again in channel.Close).
+ socks, err := fdchannel.NewConnectedSockets()
+ if err != nil {
+ res.data.Destroy() // Cleanup.
+ return err
+ }
+ res.fds.Init(socks[0])
+ res.client = fd.New(socks[1])
+
+ cs.channels = append(cs.channels, res)
+
+ // Start servicing the channel.
+ //
+ // When we call stop, we will close all the channels and these
+ // routines should finish. We need the wait group to ensure
+ // that active handlers are actually finished before cleanup.
+ cs.channelWg.Add(1)
+ go func() { // S/R-SAFE: Server side.
+ defer cs.channelWg.Done()
+ if err := res.service(cs); err != nil {
+ log.Warningf("p9.channel.service: %v", err)
+ }
+ }()
+ }
+
+ return nil
+}
+
+// lookupChannel returns the channel stored at index id.
+//
+// A nil result means id does not refer to an initialized channel. The
+// channel table is read under channelMu.
+func (cs *connState) lookupChannel(id uint32) *channel {
+ cs.channelMu.Lock()
+ defer cs.channelMu.Unlock()
+
+ if count := uint32(len(cs.channels)); id < count {
+ return cs.channels[id]
+ }
+ return nil
+}
+
+// handle handles a single message.
+//
+// The response is a named result so that the deferred function can inspect
+// and replace it. A message that does not implement handler yields an ENOSYS
+// error response. If r is still nil when the function unwinds (the handler
+// panicked before returning, or it returned nil), an EFAULT error response
+// is substituted.
+func (cs *connState) handle(m message) (r message) {
+ defer func() {
+ if r == nil {
+ // Don't allow a panic to propagate.
+ //
+ // The recovered value itself is discarded; only the stack
+ // is logged below.
+ recover()
+
+ // Include a useful log message.
+ //
+ // Note this is also emitted when the handler returned nil
+ // without panicking.
+ log.Warningf("panic in handler: %s", debug.Stack())
+
+ // Wrap in an EFAULT error; we don't really have a
+ // better way to describe this kind of error. It will
+ // usually manifest as a result of the test framework.
+ r = newErr(syscall.EFAULT)
+ }
+ }()
+ if handler, ok := m.(handler); ok {
+ // Call the message handler.
+ r = handler.handle(cs)
+ } else {
+ // Produce an ENOSYS error.
+ r = newErr(syscall.ENOSYS)
+ }
+ return
+}
+
// handleRequest handles a single request.
//
// The recvDone channel is signaled when recv is done (with a error if
@@ -428,41 +541,20 @@ func (cs *connState) handleRequest() {
}
// Handle the message.
- var r message // r is the response.
- defer func() {
- if r == nil {
- // Don't allow a panic to propagate.
- recover()
-
- // Include a useful log message.
- log.Warningf("panic in handler: %s", debug.Stack())
+ r := cs.handle(m)
- // Wrap in an EFAULT error; we don't really have a
- // better way to describe this kind of error. It will
- // usually manifest as a result of the test framework.
- r = newErr(syscall.EFAULT)
- }
+ // Clear the tag before sending. That's because as soon as this hits
+ // the wire, the client can legally send the same tag.
+ cs.ClearTag(tag)
- // Clear the tag before sending. That's because as soon as this
- // hits the wire, the client can legally send another message
- // with the same tag.
- cs.ClearTag(tag)
+ // Send back the result.
+ cs.sendMu.Lock()
+ err = send(cs.conn, tag, r)
+ cs.sendMu.Unlock()
+ cs.sendDone <- err
- // Send back the result.
- cs.sendMu.Lock()
- err = send(cs.conn, tag, r)
- cs.sendMu.Unlock()
- cs.sendDone <- err
- }()
- if handler, ok := m.(handler); ok {
- // Call the message handler.
- r = handler.handle(cs)
- } else {
- // Produce an ENOSYS error.
- r = newErr(syscall.ENOSYS)
- }
+ // Return the message to the cache.
msgRegistry.put(m)
- m = nil // 'm' should not be touched after this point.
}
func (cs *connState) handleRequests() {
@@ -477,7 +569,27 @@ func (cs *connState) stop() {
close(cs.recvDone)
close(cs.sendDone)
- for _, fidRef := range cs.fids {
+ // Free the channels.
+ cs.channelMu.Lock()
+ for _, ch := range cs.channels {
+ ch.Shutdown()
+ }
+ cs.channelWg.Wait()
+ for _, ch := range cs.channels {
+ ch.Close()
+ }
+ cs.channels = nil // Clear.
+ cs.channelMu.Unlock()
+
+ // Free the channel memory.
+ if cs.channelAlloc != nil {
+ cs.channelAlloc.Destroy()
+ }
+
+ // Close all remaining fids.
+ for fid, fidRef := range cs.fids {
+ delete(cs.fids, fid)
+
// Drop final reference in the FID table. Note this should
// always close the file, since we've ensured that there are no
// handlers running via the wait for Pending => 0 below.
@@ -510,7 +622,7 @@ func (cs *connState) service() error {
for i := 0; i < pending; i++ {
<-cs.sendDone
}
- return err
+ return nil
}
// This handler is now pending.
diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go
index 5648df589..6e8b4bbcd 100644
--- a/pkg/p9/transport.go
+++ b/pkg/p9/transport.go
@@ -54,7 +54,10 @@ const (
headerLength uint32 = 7
// maximumLength is the largest possible message.
- maximumLength uint32 = 4 * 1024 * 1024
+ maximumLength uint32 = 1 << 20
+
+ // DefaultMessageSize is a sensible default.
+ DefaultMessageSize uint32 = 64 << 10
// initialBufferLength is the initial data buffer we allocate.
initialBufferLength uint32 = 64
diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go
new file mode 100644
index 000000000..233f825e3
--- /dev/null
+++ b/pkg/p9/transport_flipcall.go
@@ -0,0 +1,243 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "runtime"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/fdchannel"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// channelsPerClient is the number of channels created for each client.
+//
+// The value is chosen by the server alone: it is runtime.NumCPU() clamped to
+// the range [2, 4]. The upper bound keeps memory use in check, since every
+// channel accounts for channelSize bytes, which can be large.
+var channelsPerClient = func() int {
+ const (
+ minChannels = 2
+ maxChannels = 4
+ )
+ switch cpus := runtime.NumCPU(); {
+ case cpus < minChannels:
+ return minChannels
+ case cpus > maxChannels:
+ return maxChannels
+ default:
+ return cpus
+ }
+}()
+
+// channelSize is the channel size to create.
+//
+// We simply ensure that this is larger than the largest possible message size,
+// plus the flipcall packet header, plus the two bytes we write below.
+//
+// The "two bytes" are the message type and the FD flag that send writes ahead
+// of the encoded message.
+// NOTE(review): the expression contains a second constant 2 whose purpose is
+// not evident from this file -- confirm whether it is intentional slack.
+const channelSize = int(2 + flipcall.PacketHeaderBytes + 2 + maximumLength)
+
+// channel is a fast IPC channel.
+//
+// The same object is used by both the server and client implementations. In
+// general, the client will use only the send and recv methods.
+type channel struct {
+ // desc describes the packet window that data is initialized with.
+ desc flipcall.PacketWindowDescriptor
+
+ // data is the flipcall endpoint; encoded messages are written to and
+ // read from data.Data() and exchanged with SendRecv.
+ data flipcall.Endpoint
+
+ // fds is the fdchannel endpoint used to pass file payloads.
+ fds fdchannel.Endpoint
+
+ // buf is the encode/decode buffer; reset points it at data.Data().
+ buf buffer
+
+ // -- client only --
+ // NOTE(review): neither field is referenced in this file; their exact
+ // meaning is defined by the client code -- confirm there.
+ connected bool
+ active bool
+
+ // -- server only --
+
+ // client is the second socket of the pair created for fds; it is
+ // closed by Close if non-nil.
+ client *fd.FD
+
+ // done is allocated when the server creates the channel.
+ // NOTE(review): not otherwise used in this file.
+ done chan struct{}
+}
+
+// reset resets the channel buffer.
+//
+// After the call, ch.buf.data is the first sz bytes of the flipcall data
+// region: sz is 0 before encoding (see send) and the received size before
+// decoding (see recv).
+func (ch *channel) reset(sz uint32) {
+ ch.buf.data = ch.data.Data()[:sz]
+}
+
+// service services the channel.
+//
+// It waits for the first request, then repeatedly decodes a request, handles
+// it via cs.handle, and sends the response; send's return value is the size
+// of the next request. The loop ends with a nil error when a size of zero is
+// received, and with the error when RecvFirst, recv or send fails.
+func (ch *channel) service(cs *connState) error {
+ rsz, err := ch.data.RecvFirst()
+ if err != nil {
+ return err
+ }
+ for rsz > 0 {
+ m, err := ch.recv(nil, rsz)
+ if err != nil {
+ return err
+ }
+ r := cs.handle(m)
+ // The request is returned to the registry before the response is
+ // sent.
+ msgRegistry.put(m)
+ rsz, err = ch.send(r)
+ if err != nil {
+ return err
+ }
+ }
+ return nil // Done.
+}
+
+// Shutdown shuts down the channel.
+//
+// This must be called before Close.
+//
+// Only the flipcall endpoint (ch.data) is shut down here; the fdchannel
+// endpoint and the client socket are released by Close.
+func (ch *channel) Shutdown() {
+ ch.data.Shutdown()
+}
+
+// Close closes the channel.
+//
+// This must only be called once, and cannot return an error. Note that
+// synchronization for this method is provided at a high-level, depending on
+// whether it is the client or server. This cannot be called while there are
+// active callers in either service or sendRecv.
+//
+// The returned error is always nil; any error from closing the client
+// socket is discarded.
+//
+// Precondition: the channel should be shutdown.
+func (ch *channel) Close() error {
+ // Close all backing transports.
+ ch.fds.Destroy()
+ ch.data.Destroy()
+ if ch.client != nil {
+ ch.client.Close()
+ }
+ return nil
+}
+
+// send sends the given message.
+//
+// The return value is the size of the received response. Note that in the
+// server case, this is the size of the next request.
+//
+// The packet layout is: message type, a one-byte FD flag, the encoded
+// message, and then any payload bytes. A file payload, if present, is sent
+// over the fdchannel before the packet and closed locally afterwards.
+func (ch *channel) send(m message) (uint32, error) {
+ if log.IsLogging(log.Debug) {
+ log.Debugf("send [channel @%p] %s", ch, m.String())
+ }
+
+ // Send any file payload.
+ sentFD := false
+ if filer, ok := m.(filer); ok {
+ if f := filer.FilePayload(); f != nil {
+ if err := ch.fds.SendFD(f.FD()); err != nil {
+ return 0, err
+ }
+ f.Close() // Per sendRecvLegacy.
+ sentFD = true // To mark below.
+ }
+ }
+
+ // Encode the message.
+ //
+ // Note that IPC itself encodes the length of messages, so we don't
+ // need to encode a standard 9P header. We write only the message type.
+ ch.reset(0)
+
+ ch.buf.WriteMsgType(m.Type())
+ if sentFD {
+ ch.buf.Write8(1) // Incoming FD.
+ } else {
+ ch.buf.Write8(0) // No incoming FD.
+ }
+ m.Encode(&ch.buf)
+ ssz := uint32(len(ch.buf.data)) // Updated below.
+
+ // Is there a payload?
+ //
+ // The payload is copied directly into the data region after the
+ // encoded portion. Note that ssz grows by len(p) regardless of how
+ // many bytes copy actually transferred.
+ if payloader, ok := m.(payloader); ok {
+ p := payloader.Payload()
+ copy(ch.data.Data()[ssz:], p)
+ ssz += uint32(len(p))
+ }
+
+ // Perform the one-shot communication.
+ return ch.data.SendRecv(ssz)
+}
+
+// recv decodes a message that exists on the channel.
+//
+// If the passed r is non-nil, then the type must match or an error will be
+// generated. If the passed r is nil, then a new message will be created and
+// returned.
+//
+// rsz is the number of valid bytes in the channel's data region. A packet
+// that is too short for the message being decoded yields ErrNoValidMessage.
+// An Rlerror message is never returned as a message; it is converted to a
+// syscall.Errno error instead.
+func (ch *channel) recv(r message, rsz uint32) (message, error) {
+ // Decode the response from the inline buffer.
+ ch.reset(rsz)
+ t := ch.buf.ReadMsgType()
+ hasFD := ch.buf.Read8() != 0
+ if t == MsgRlerror {
+ // Change the message type. We check for this special case
+ // after decoding below, and transform into an error.
+ r = &Rlerror{}
+ } else if r == nil {
+ nr, err := msgRegistry.get(0, t)
+ if err != nil {
+ return nil, err
+ }
+ r = nr // New message.
+ } else if t != r.Type() {
+ // Not an error and not the expected response; propagate.
+ return nil, &ErrBadResponse{Got: t, Want: r.Type()}
+ }
+
+ // Is there a payload? Copy from the latter portion.
+ if payloader, ok := r.(payloader); ok {
+ fs := payloader.FixedSize()
+ if int(fs) > len(ch.buf.data) {
+ // The packet is shorter than the fixed portion of the
+ // message. Slicing below would panic, and on the server
+ // this runs outside of the recover in handle, so reject
+ // the packet the same way as an overrun.
+ log.Debugf("recv [got %d bytes, needed more]", rsz)
+ return nil, ErrNoValidMessage
+ }
+ p := payloader.Payload()
+ payloadData := ch.buf.data[fs:]
+ if len(p) < len(payloadData) {
+ // The existing payload buffer is too small; allocate.
+ p = make([]byte, len(payloadData))
+ copy(p, payloadData)
+ payloader.SetPayload(p)
+ } else if n := copy(p, payloadData); n < len(p) {
+ // Reuse the existing buffer, trimmed to what arrived.
+ payloader.SetPayload(p[:n])
+ }
+ // Leave only the fixed portion for Decode.
+ ch.buf.data = ch.buf.data[:fs]
+ }
+
+ r.Decode(&ch.buf)
+ if ch.buf.isOverrun() {
+ // Nothing valid was available.
+ log.Debugf("recv [got %d bytes, needed more]", rsz)
+ return nil, ErrNoValidMessage
+ }
+
+ // Read any FD result.
+ if hasFD {
+ if rfd, err := ch.fds.RecvFDNonblock(); err == nil {
+ f := fd.New(rfd)
+ if filer, ok := r.(filer); ok {
+ // Set the payload.
+ filer.SetFilePayload(f)
+ } else {
+ // Don't want the FD.
+ f.Close()
+ }
+ } else {
+ // The header bit was set but nothing came in.
+ log.Warningf("expected FD, got err: %v", err)
+ }
+ }
+
+ // Log a message.
+ if log.IsLogging(log.Debug) {
+ log.Debugf("recv [channel @%p] %s", ch, r.String())
+ }
+
+ // Convert errors appropriately; see above.
+ if rlerr, ok := r.(*Rlerror); ok {
+ return nil, syscall.Errno(rlerr.Error)
+ }
+
+ return r, nil
+}
diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go
index cdb3bc841..2f50ff3ea 100644
--- a/pkg/p9/transport_test.go
+++ b/pkg/p9/transport_test.go
@@ -124,7 +124,9 @@ func TestSendRecvWithFile(t *testing.T) {
t.Fatalf("unable to create file: %v", err)
}
- if err := send(client, Tag(1), &Rlopen{File: f}); err != nil {
+ rlopen := &Rlopen{}
+ rlopen.SetFilePayload(f)
+ if err := send(client, Tag(1), rlopen); err != nil {
t.Fatalf("send got err %v expected nil", err)
}
diff --git a/pkg/p9/version.go b/pkg/p9/version.go
index c2a2885ae..f1ffdd23a 100644
--- a/pkg/p9/version.go
+++ b/pkg/p9/version.go
@@ -26,7 +26,7 @@ const (
//
// Clients are expected to start requesting this version number and
// to continuously decrement it until a Tversion request succeeds.
- highestSupportedVersion uint32 = 7
+ highestSupportedVersion uint32 = 8
// lowestSupportedVersion is the lowest supported version X in a
// version string of the format 9P2000.L.Google.X.
@@ -148,3 +148,10 @@ func VersionSupportsMultiUser(v uint32) bool {
func versionSupportsTallocate(v uint32) bool {
return v >= 7
}
+
+// versionSupportsFlipcall returns true if version v supports IPC channels from
+// the flipcall package. Note that these must be negotiated, but this version
+// string indicates that such a facility exists.
+func versionSupportsFlipcall(v uint32) bool {
+ return v >= 8
+}
diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD
index 697e7a2f4..078f084b2 100644
--- a/pkg/procid/BUILD
+++ b/pkg/procid/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD
index 9c08452fc..827385139 100644
--- a/pkg/refs/BUILD
+++ b/pkg/refs/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "weak_ref_list",
diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go
index 828e9b5c1..ad69e0757 100644
--- a/pkg/refs/refcounter.go
+++ b/pkg/refs/refcounter.go
@@ -215,8 +215,8 @@ type AtomicRefCount struct {
type LeakMode uint32
const (
- // uninitializedLeakChecking indicates that the leak checker has not yet been initialized.
- uninitializedLeakChecking LeakMode = iota
+ // UninitializedLeakChecking indicates that the leak checker has not yet been initialized.
+ UninitializedLeakChecking LeakMode = iota
// NoLeakChecking indicates that no effort should be made to check for
// leaks.
@@ -318,7 +318,7 @@ func (r *AtomicRefCount) finalize() {
switch LeakMode(atomic.LoadUint32(&leakMode)) {
case NoLeakChecking:
return
- case uninitializedLeakChecking:
+ case UninitializedLeakChecking:
note = "(Leak checker uninitialized): "
}
if n := r.ReadRefs(); n != 0 {
diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD
index d1024e49d..af94e944d 100644
--- a/pkg/seccomp/BUILD
+++ b/pkg/seccomp/BUILD
@@ -1,5 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
-load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/seccomp/seccomp_unsafe.go b/pkg/seccomp/seccomp_unsafe.go
index 0a3d92854..be328db12 100644
--- a/pkg/seccomp/seccomp_unsafe.go
+++ b/pkg/seccomp/seccomp_unsafe.go
@@ -35,7 +35,7 @@ type sockFprog struct {
//go:nosplit
func SetFilter(instrs []linux.BPFInstruction) syscall.Errno {
// PR_SET_NO_NEW_PRIVS is required in order to enable seccomp. See seccomp(2) for details.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PRCTL, linux.PR_SET_NO_NEW_PRIVS, 1, 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PRCTL, linux.PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0, 0); errno != 0 {
return errno
}
diff --git a/pkg/secio/BUILD b/pkg/secio/BUILD
index f38fb39f3..22abdc69f 100644
--- a/pkg/secio/BUILD
+++ b/pkg/secio/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD
index 694486296..12d7c77d2 100644
--- a/pkg/segment/test/BUILD
+++ b/pkg/segment/test/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(
default_visibility = ["//visibility:private"],
diff --git a/pkg/sentry/BUILD b/pkg/sentry/BUILD
index 53989301f..2d6379c86 100644
--- a/pkg/sentry/BUILD
+++ b/pkg/sentry/BUILD
@@ -8,5 +8,7 @@ package_group(
packages = [
"//pkg/sentry/...",
"//runsc/...",
+ # Code generated by go_marshal relies on go_marshal libraries.
+ "//tools/go_marshal/...",
],
)
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index 7aace2d7b..c71cff9f3 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -1,4 +1,5 @@
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -42,6 +43,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+cc_proto_library(
+ name = "registers_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":registers_proto"],
+)
+
go_proto_library(
name = "registers_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto",
diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD
index bf802d1b6..5522cecd0 100644
--- a/pkg/sentry/control/BUILD
+++ b/pkg/sentry/control/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD
index 7e8918722..0c86197f7 100644
--- a/pkg/sentry/device/BUILD
+++ b/pkg/sentry/device/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "device",
diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD
index d7259b47b..3119a61b6 100644
--- a/pkg/sentry/fs/BUILD
+++ b/pkg/sentry/fs/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "fs",
diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go
index fbca06761..3cb73bd78 100644
--- a/pkg/sentry/fs/dirent.go
+++ b/pkg/sentry/fs/dirent.go
@@ -1126,7 +1126,7 @@ func (d *Dirent) unmount(ctx context.Context, replacement *Dirent) error {
// Remove removes the given file or symlink. The root dirent is used to
// resolve name, and must not be nil.
-func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string) error {
+func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath bool) error {
// Check the root.
if root == nil {
panic("Dirent.Remove: root must not be nil")
@@ -1151,6 +1151,8 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string) error {
// Remove cannot remove directories.
if IsDir(child.Inode.StableAttr) {
return syscall.EISDIR
+ } else if dirPath {
+ return syscall.ENOTDIR
}
// Remove cannot remove a mount point.
diff --git a/pkg/sentry/fs/dirent_refs_test.go b/pkg/sentry/fs/dirent_refs_test.go
index 884e3ff06..47bc72a88 100644
--- a/pkg/sentry/fs/dirent_refs_test.go
+++ b/pkg/sentry/fs/dirent_refs_test.go
@@ -343,7 +343,7 @@ func TestRemoveExtraRefs(t *testing.T) {
}
d := f.Dirent
- if err := test.root.Remove(contexttest.Context(t), test.root, name); err != nil {
+ if err := test.root.Remove(contexttest.Context(t), test.root, name, false /* dirPath */); err != nil {
t.Fatalf("root.Remove(root, %q) failed: %v", name, err)
}
diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD
index bf00b9c09..b9bd9ed17 100644
--- a/pkg/sentry/fs/fdpipe/BUILD
+++ b/pkg/sentry/fs/fdpipe/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "fdpipe",
diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go
index bb8117f89..c0a6e884b 100644
--- a/pkg/sentry/fs/file.go
+++ b/pkg/sentry/fs/file.go
@@ -515,6 +515,11 @@ type lockedReader struct {
// File is the file to read from.
File *File
+
+ // Offset is the offset to start at.
+ //
+ // This applies only to Read, not ReadAt.
+ Offset int64
}
// Read implements io.Reader.Read.
@@ -522,7 +527,8 @@ func (r *lockedReader) Read(buf []byte) (int, error) {
if r.Ctx.Interrupted() {
return 0, syserror.ErrInterrupted
}
- n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.File.offset)
+ n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.Offset)
+ r.Offset += n
return int(n), err
}
@@ -544,11 +550,21 @@ type lockedWriter struct {
// File is the file to write to.
File *File
+
+ // Offset is the offset to start at.
+ //
+ // This applies only to Write, not WriteAt.
+ Offset int64
}
// Write implements io.Writer.Write.
func (w *lockedWriter) Write(buf []byte) (int, error) {
- return w.WriteAt(buf, w.File.offset)
+ if w.Ctx.Interrupted() {
+ return 0, syserror.ErrInterrupted
+ }
+ n, err := w.WriteAt(buf, w.Offset)
+ w.Offset += int64(n)
+ return int(n), err
}
// WriteAt implements io.Writer.WriteAt.
@@ -562,6 +578,9 @@ func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) {
// io.Copy, since our own Write interface does not have this same
// contract. Enforce that here.
for written < len(buf) {
+ if w.Ctx.Interrupted() {
+ return written, syserror.ErrInterrupted
+ }
var n int64
n, err = w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf[written:]), offset+int64(written))
if n > 0 {
diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go
index d86f5bf45..b88303f17 100644
--- a/pkg/sentry/fs/file_operations.go
+++ b/pkg/sentry/fs/file_operations.go
@@ -15,6 +15,8 @@
package fs
import (
+ "io"
+
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -105,8 +107,11 @@ type FileOperations interface {
// on the destination, following by a buffered copy with standard Read
// and Write operations.
//
+ // If dup is set, the data should be duplicated into the destination
+ // and retained.
+ //
// The same preconditions as Read apply.
- WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (int64, error)
+ WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (int64, error)
// Write writes src to file at offset and returns the number of bytes
// written which must be greater than or equal to 0. Like Read, file
@@ -126,7 +131,7 @@ type FileOperations interface {
// source. See WriteTo for details regarding how this is called.
//
// The same preconditions as Write apply; FileFlags.Write must be set.
- ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (int64, error)
+ ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (int64, error)
// Fsync writes buffered modifications of file and/or flushes in-flight
// operations to backing storage based on syncType. The range to sync is
diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go
index 9820f0b13..225e40186 100644
--- a/pkg/sentry/fs/file_overlay.go
+++ b/pkg/sentry/fs/file_overlay.go
@@ -15,6 +15,7 @@
package fs
import (
+ "io"
"sync"
"gvisor.dev/gvisor/pkg/refs"
@@ -268,9 +269,9 @@ func (f *overlayFileOperations) Read(ctx context.Context, file *File, dst userme
}
// WriteTo implements FileOperations.WriteTo.
-func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (n int64, err error) {
+func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (n int64, err error) {
err = f.onTop(ctx, file, func(file *File, ops FileOperations) error {
- n, err = ops.WriteTo(ctx, file, dst, opts)
+ n, err = ops.WriteTo(ctx, file, dst, count, dup)
return err // Will overwrite itself.
})
return
@@ -285,9 +286,9 @@ func (f *overlayFileOperations) Write(ctx context.Context, file *File, src userm
}
// ReadFrom implements FileOperations.ReadFrom.
-func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (n int64, err error) {
+func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (n int64, err error) {
// See above; f.upper must be non-nil.
- return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, opts)
+ return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, count)
}
// Fsync implements FileOperations.Fsync.
diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD
index 6499f87ac..b4ac83dc4 100644
--- a/pkg/sentry/fs/fsutil/BUILD
+++ b/pkg/sentry/fs/fsutil/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "dirty_set_impl",
diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go
index 626b9126a..fc5b3b1a1 100644
--- a/pkg/sentry/fs/fsutil/file.go
+++ b/pkg/sentry/fs/fsutil/file.go
@@ -15,6 +15,8 @@
package fsutil
import (
+ "io"
+
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -228,12 +230,12 @@ func (FileNoIoctl) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArgu
type FileNoSplice struct{}
// WriteTo implements fs.FileOperations.WriteTo.
-func (FileNoSplice) WriteTo(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) {
+func (FileNoSplice) WriteTo(context.Context, *fs.File, io.Writer, int64, bool) (int64, error) {
return 0, syserror.ENOSYS
}
// ReadFrom implements fs.FileOperations.ReadFrom.
-func (FileNoSplice) ReadFrom(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) {
+func (FileNoSplice) ReadFrom(context.Context, *fs.File, io.Reader, int64) (int64, error) {
return 0, syserror.ENOSYS
}
diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go
index d2495cb83..693625ddc 100644
--- a/pkg/sentry/fs/fsutil/host_mappable.go
+++ b/pkg/sentry/fs/fsutil/host_mappable.go
@@ -144,7 +144,7 @@ func (h *HostMappable) Truncate(ctx context.Context, newSize int64) error {
mask := fs.AttrMask{Size: true}
attr := fs.UnstableAttr{Size: newSize}
- if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr); err != nil {
+ if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr, false); err != nil {
return err
}
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index e70bc28fb..dd80757dc 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -66,10 +66,8 @@ type CachingInodeOperations struct {
// mfp is used to allocate memory that caches backingFile's contents.
mfp pgalloc.MemoryFileProvider
- // forcePageCache indicates the sentry page cache should be used regardless
- // of whether the platform supports host mapped I/O or not. This must not be
- // modified after inode creation.
- forcePageCache bool
+ // opts contains options. opts is immutable.
+ opts CachingInodeOperationsOptions
attrMu sync.Mutex `state:"nosave"`
@@ -116,6 +114,20 @@ type CachingInodeOperations struct {
refs frameRefSet
}
+// CachingInodeOperationsOptions configures a CachingInodeOperations.
+//
+// +stateify savable
+type CachingInodeOperationsOptions struct {
+ // If ForcePageCache is true, use the sentry page cache even if a host file
+ // descriptor is available.
+ ForcePageCache bool
+
+ // If LimitHostFDTranslation is true, apply maxFillRange() constraints to
+ // host file descriptor mappings returned by
+ // CachingInodeOperations.Translate().
+ LimitHostFDTranslation bool
+}
+
// CachedFileObject is a file that may require caching.
type CachedFileObject interface {
// ReadToBlocksAt reads up to dsts.NumBytes() bytes from the file to dsts,
@@ -128,12 +140,16 @@ type CachedFileObject interface {
// WriteFromBlocksAt may return a partial write without an error.
WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)
- // SetMaskedAttributes sets the attributes in attr that are true in mask
- // on the backing file.
+ // SetMaskedAttributes sets the attributes in attr that are true in
+ // mask on the backing file. If the mask contains only AccessTime and/or
+ // ModificationTime and the CachedFileObject has an FD to the file, then
+ // this operation is a no-op unless forceSetTimestamps is true. This
+ // avoids an extra RPC to the gofer in the open-read/write-close case,
+ // when the timestamps on the file will be updated by the host kernel for us.
//
// SetMaskedAttributes may be called at any point, regardless of whether
// the file was opened.
- SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error
+ SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, forceSetTimestamps bool) error
// Allocate allows the caller to reserve disk space for the inode.
// It's equivalent to fallocate(2) with 'mode=0'.
@@ -159,7 +175,7 @@ type CachedFileObject interface {
// NewCachingInodeOperations returns a new CachingInodeOperations backed by
// a CachedFileObject and its initial unstable attributes.
-func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject, uattr fs.UnstableAttr, forcePageCache bool) *CachingInodeOperations {
+func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject, uattr fs.UnstableAttr, opts CachingInodeOperationsOptions) *CachingInodeOperations {
mfp := pgalloc.MemoryFileProviderFromContext(ctx)
if mfp == nil {
panic(fmt.Sprintf("context.Context %T lacks non-nil value for key %T", ctx, pgalloc.CtxMemoryFileProvider))
@@ -167,7 +183,7 @@ func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject
return &CachingInodeOperations{
backingFile: backingFile,
mfp: mfp,
- forcePageCache: forcePageCache,
+ opts: opts,
attr: uattr,
hostFileMapper: NewHostFileMapper(),
}
@@ -212,7 +228,7 @@ func (c *CachingInodeOperations) SetPermissions(ctx context.Context, inode *fs.I
now := ktime.NowFromContext(ctx)
masked := fs.AttrMask{Perms: true}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Perms: perms}); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Perms: perms}, false); err != nil {
return false
}
c.attr.Perms = perms
@@ -234,7 +250,7 @@ func (c *CachingInodeOperations) SetOwner(ctx context.Context, inode *fs.Inode,
UID: owner.UID.Ok(),
GID: owner.GID.Ok(),
}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Owner: owner}); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Owner: owner}, false); err != nil {
return err
}
if owner.UID.Ok() {
@@ -270,7 +286,9 @@ func (c *CachingInodeOperations) SetTimestamps(ctx context.Context, inode *fs.In
AccessTime: !ts.ATimeOmit,
ModificationTime: !ts.MTimeOmit,
}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{AccessTime: ts.ATime, ModificationTime: ts.MTime}); err != nil {
+ // Call SetMaskedAttributes with forceSetTimestamps = true to make sure
+ // the timestamp is updated.
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{AccessTime: ts.ATime, ModificationTime: ts.MTime}, true); err != nil {
return err
}
if !ts.ATimeOmit {
@@ -293,7 +311,7 @@ func (c *CachingInodeOperations) Truncate(ctx context.Context, inode *fs.Inode,
now := ktime.NowFromContext(ctx)
masked := fs.AttrMask{Size: true}
attr := fs.UnstableAttr{Size: size}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr, false); err != nil {
c.dataMu.Unlock()
return err
}
@@ -382,7 +400,7 @@ func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode)
c.dirtyAttr.Size = false
// Write out cached attributes.
- if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr, false); err != nil {
c.attrMu.Unlock()
return err
}
@@ -763,7 +781,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error
// and memory mappings, and false if c.cache may contain data cached from
// c.backingFile.
func (c *CachingInodeOperations) useHostPageCache() bool {
- return !c.forcePageCache && c.backingFile.FD() >= 0
+ return !c.opts.ForcePageCache && c.backingFile.FD() >= 0
}
// AddMapping implements memmap.Mappable.AddMapping.
@@ -784,11 +802,6 @@ func (c *CachingInodeOperations) AddMapping(ctx context.Context, ms memmap.Mappi
mf.MarkUnevictable(c, pgalloc.EvictableRange{r.Start, r.End})
}
}
- if c.useHostPageCache() && !usage.IncrementalMappedAccounting {
- for _, r := range mapped {
- usage.MemoryAccounting.Inc(r.Length(), usage.Mapped)
- }
- }
c.mapsMu.Unlock()
return nil
}
@@ -802,11 +815,6 @@ func (c *CachingInodeOperations) RemoveMapping(ctx context.Context, ms memmap.Ma
c.hostFileMapper.DecRefOn(r)
}
if c.useHostPageCache() {
- if !usage.IncrementalMappedAccounting {
- for _, r := range unmapped {
- usage.MemoryAccounting.Dec(r.Length(), usage.Mapped)
- }
- }
c.mapsMu.Unlock()
return
}
@@ -835,11 +843,15 @@ func (c *CachingInodeOperations) CopyMapping(ctx context.Context, ms memmap.Mapp
func (c *CachingInodeOperations) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
// Hot path. Avoid defer.
if c.useHostPageCache() {
+ mr := optional
+ if c.opts.LimitHostFDTranslation {
+ mr = maxFillRange(required, optional)
+ }
return []memmap.Translation{
{
- Source: optional,
+ Source: mr,
File: c,
- Offset: optional.Start,
+ Offset: mr.Start,
Perms: usermem.AnyAccess,
},
}, nil
@@ -985,9 +997,7 @@ func (c *CachingInodeOperations) IncRef(fr platform.FileRange) {
seg, gap = seg.NextNonEmpty()
case gap.Ok() && gap.Start() < fr.End:
newRange := gap.Range().Intersect(fr)
- if usage.IncrementalMappedAccounting {
- usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped)
- }
+ usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped)
seg, gap = c.refs.InsertWithoutMerging(gap, newRange, 1).NextNonEmpty()
default:
c.refs.MergeAdjacent(fr)
@@ -1008,9 +1018,7 @@ func (c *CachingInodeOperations) DecRef(fr platform.FileRange) {
for seg.Ok() && seg.Start() < fr.End {
seg = c.refs.Isolate(seg, fr)
if old := seg.Value(); old == 1 {
- if usage.IncrementalMappedAccounting {
- usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped)
- }
+ usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped)
seg = c.refs.Remove(seg).NextSegment()
} else {
seg.SetValue(old - 1)
diff --git a/pkg/sentry/fs/fsutil/inode_cached_test.go b/pkg/sentry/fs/fsutil/inode_cached_test.go
index dc19255ed..129f314c8 100644
--- a/pkg/sentry/fs/fsutil/inode_cached_test.go
+++ b/pkg/sentry/fs/fsutil/inode_cached_test.go
@@ -39,7 +39,7 @@ func (noopBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.Block
return srcs.NumBytes(), nil
}
-func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr) error {
+func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error {
return nil
}
@@ -61,7 +61,7 @@ func TestSetPermissions(t *testing.T) {
uattr := fs.WithCurrentTime(ctx, fs.UnstableAttr{
Perms: fs.FilePermsFromMode(0444),
})
- iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{})
defer iops.Release()
perms := fs.FilePermsFromMode(0777)
@@ -150,7 +150,7 @@ func TestSetTimestamps(t *testing.T) {
ModificationTime: epoch,
StatusChangeTime: epoch,
}
- iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{})
defer iops.Release()
if err := iops.SetTimestamps(ctx, nil, test.ts); err != nil {
@@ -188,7 +188,7 @@ func TestTruncate(t *testing.T) {
uattr := fs.UnstableAttr{
Size: 0,
}
- iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{})
defer iops.Release()
if err := iops.Truncate(ctx, nil, uattr.Size); err != nil {
@@ -230,7 +230,7 @@ func (f *sliceBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.B
return w.WriteFromBlocks(srcs)
}
-func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr) error {
+func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error {
return nil
}
@@ -280,7 +280,7 @@ func TestRead(t *testing.T) {
uattr := fs.UnstableAttr{
Size: int64(len(buf)),
}
- iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, CachingInodeOperationsOptions{})
defer iops.Release()
// Expect the cache to be initially empty.
@@ -336,7 +336,7 @@ func TestWrite(t *testing.T) {
uattr := fs.UnstableAttr{
Size: int64(len(buf)),
}
- iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, CachingInodeOperationsOptions{})
defer iops.Release()
// Expect the cache to be initially empty.
diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD
index 6b993928c..2b71ca0e1 100644
--- a/pkg/sentry/fs/gofer/BUILD
+++ b/pkg/sentry/fs/gofer/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "gofer",
diff --git a/pkg/sentry/fs/gofer/fs.go b/pkg/sentry/fs/gofer/fs.go
index 69999dc28..8f8ab5d29 100644
--- a/pkg/sentry/fs/gofer/fs.go
+++ b/pkg/sentry/fs/gofer/fs.go
@@ -54,6 +54,10 @@ const (
// sandbox using files backed by the gofer. If set to false, unix sockets
// cannot be bound to gofer files without an overlay on top.
privateUnixSocketKey = "privateunixsocket"
+
+ // If present, sets CachingInodeOperationsOptions.LimitHostFDTranslation to
+ // true.
+ limitHostFDTranslationKey = "limit_host_fd_translation"
)
// defaultAname is the default attach name.
@@ -134,12 +138,13 @@ func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSou
// opts are parsed 9p mount options.
type opts struct {
- fd int
- aname string
- policy cachePolicy
- msize uint32
- version string
- privateunixsocket bool
+ fd int
+ aname string
+ policy cachePolicy
+ msize uint32
+ version string
+ privateunixsocket bool
+ limitHostFDTranslation bool
}
// options parses mount(2) data into structured options.
@@ -237,6 +242,11 @@ func options(data string) (opts, error) {
delete(options, privateUnixSocketKey)
}
+ if _, ok := options[limitHostFDTranslationKey]; ok {
+ o.limitHostFDTranslation = true
+ delete(options, limitHostFDTranslationKey)
+ }
+
// Fail to attach if the caller wanted us to do something that we
// don't support.
if len(options) > 0 {
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index 95b064aea..d918d6620 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -215,8 +215,8 @@ func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.Blo
}
// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
-func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error {
- if i.skipSetAttr(mask) {
+func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, forceSetTimestamps bool) error {
+ if i.skipSetAttr(mask, forceSetTimestamps) {
return nil
}
as, ans := attr.AccessTime.Unix()
@@ -251,13 +251,14 @@ func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMa
// when:
// - Mask is empty
// - Mask contains only attributes that cannot be set in the gofer
-// - Mask contains only atime and/or mtime, and host FD exists
+// - forceSetTimestamps is false and mask contains only atime and/or mtime
+// and host FD exists
//
// Updates to atime and mtime can be skipped because cached value will be
// "close enough" to host value, given that operation went directly to host FD.
// Skipping atime updates is particularly important to reduce the number of
// operations sent to the Gofer for readonly files.
-func (i *inodeFileState) skipSetAttr(mask fs.AttrMask) bool {
+func (i *inodeFileState) skipSetAttr(mask fs.AttrMask, forceSetTimestamps bool) bool {
// First remove attributes that cannot be updated.
cpy := mask
cpy.Type = false
@@ -277,6 +278,12 @@ func (i *inodeFileState) skipSetAttr(mask fs.AttrMask) bool {
return false
}
+ // If forceSetTimestamps is true, then we cannot skip.
+ if forceSetTimestamps {
+ return false
+ }
+
+ // Skip if we have a host FD.
i.handlesMu.RLock()
defer i.handlesMu.RUnlock()
return (i.readHandles != nil && i.readHandles.Host != nil) ||
diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go
index 69d08a627..50da865c1 100644
--- a/pkg/sentry/fs/gofer/session.go
+++ b/pkg/sentry/fs/gofer/session.go
@@ -117,6 +117,11 @@ type session struct {
// Flags provided to the mount.
superBlockFlags fs.MountSourceFlags `state:"wait"`
+ // limitHostFDTranslation is the value used for
+ // CachingInodeOperationsOptions.LimitHostFDTranslation for all
+ // CachingInodeOperations created by the session.
+ limitHostFDTranslation bool
+
// connID is a unique identifier for the session connection.
connID string `state:"wait"`
@@ -218,8 +223,11 @@ func newInodeOperations(ctx context.Context, s *session, file contextFile, qid p
uattr := unstable(ctx, valid, attr, s.mounter, s.client)
return sattr, &inodeOperations{
- fileState: fileState,
- cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, s.superBlockFlags.ForcePageCache),
+ fileState: fileState,
+ cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, fsutil.CachingInodeOperationsOptions{
+ ForcePageCache: s.superBlockFlags.ForcePageCache,
+ LimitHostFDTranslation: s.limitHostFDTranslation,
+ }),
}
}
@@ -242,13 +250,14 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF
// Construct the session.
s := session{
- connID: dev,
- msize: o.msize,
- version: o.version,
- cachePolicy: o.policy,
- aname: o.aname,
- superBlockFlags: superBlockFlags,
- mounter: mounter,
+ connID: dev,
+ msize: o.msize,
+ version: o.version,
+ cachePolicy: o.policy,
+ aname: o.aname,
+ superBlockFlags: superBlockFlags,
+ limitHostFDTranslation: o.limitHostFDTranslation,
+ mounter: mounter,
}
s.EnableLeakCheck("gofer.session")
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index b1080fb1a..3e532332e 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "host",
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index 679d8321a..a6e4a09e3 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -114,7 +114,7 @@ func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.Blo
}
// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
-func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error {
+func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, _ bool) error {
if mask.Empty() {
return nil
}
@@ -163,7 +163,7 @@ func (i *inodeFileState) unstableAttr(ctx context.Context) (fs.UnstableAttr, err
return unstableAttr(i.mops, &s), nil
}
-// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
+// Allocate implements fsutil.CachedFileObject.Allocate.
func (i *inodeFileState) Allocate(_ context.Context, offset, length int64) error {
return syscall.Fallocate(i.FD(), 0, offset, length)
}
@@ -200,8 +200,10 @@ func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool,
// Build the fs.InodeOperations.
uattr := unstableAttr(msrc.MountSourceOperations.(*superOperations), &s)
iops := &inodeOperations{
- fileState: fileState,
- cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, msrc.Flags.ForcePageCache),
+ fileState: fileState,
+ cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, fsutil.CachingInodeOperationsOptions{
+ ForcePageCache: msrc.Flags.ForcePageCache,
+ }),
}
// Return the fs.Inode.
diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go
index 2526412a4..90331e3b2 100644
--- a/pkg/sentry/fs/host/tty.go
+++ b/pkg/sentry/fs/host/tty.go
@@ -43,12 +43,15 @@ type TTYFileOperations struct {
// fgProcessGroup is the foreground process group that is currently
// connected to this TTY.
fgProcessGroup *kernel.ProcessGroup
+
+ termios linux.KernelTermios
}
// newTTYFile returns a new fs.File that wraps a TTY FD.
func newTTYFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags, iops *inodeOperations) *fs.File {
return fs.NewFile(ctx, dirent, flags, &TTYFileOperations{
fileOperations: fileOperations{iops: iops},
+ termios: linux.DefaultSlaveTermios,
})
}
@@ -97,9 +100,12 @@ func (t *TTYFileOperations) Write(ctx context.Context, file *fs.File, src userme
t.mu.Lock()
defer t.mu.Unlock()
- // Are we allowed to do the write?
- if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
- return 0, err
+ // Check whether TOSTOP is enabled. This corresponds to the check in
+ // drivers/tty/n_tty.c:n_tty_write().
+ if t.termios.LEnabled(linux.TOSTOP) {
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
}
return t.fileOperations.Write(ctx, file, src, offset)
}
@@ -144,6 +150,9 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO
return 0, err
}
err := ioctlSetTermios(fd, ioctl, &termios)
+ if err == nil {
+ t.termios.FromTermios(termios)
+ }
return 0, err
case linux.TIOCGPGRP:
diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go
index 246b97161..5a388dad1 100644
--- a/pkg/sentry/fs/inode_overlay.go
+++ b/pkg/sentry/fs/inode_overlay.go
@@ -15,6 +15,7 @@
package fs
import (
+ "fmt"
"strings"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -207,6 +208,11 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name
}
func overlayCreate(ctx context.Context, o *overlayEntry, parent *Dirent, name string, flags FileFlags, perm FilePermissions) (*File, error) {
+ // Sanity check.
+ if parent.Inode.overlay == nil {
+ panic(fmt.Sprintf("overlayCreate called with non-overlay parent inode (parent InodeOperations type is %T)", parent.Inode.InodeOperations))
+ }
+
// Dirent.Create takes renameMu if the Inode is an overlay Inode.
if err := copyUpLockedForRename(ctx, parent); err != nil {
return nil, err
diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go
index c7f4e2d13..ba3e0233d 100644
--- a/pkg/sentry/fs/inotify.go
+++ b/pkg/sentry/fs/inotify.go
@@ -15,6 +15,7 @@
package fs
import (
+ "io"
"sync"
"sync/atomic"
@@ -172,7 +173,7 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i
}
// WriteTo implements FileOperations.WriteTo.
-func (*Inotify) WriteTo(context.Context, *File, *File, SpliceOpts) (int64, error) {
+func (*Inotify) WriteTo(context.Context, *File, io.Writer, int64, bool) (int64, error) {
return 0, syserror.ENOSYS
}
@@ -182,7 +183,7 @@ func (*Inotify) Fsync(context.Context, *File, int64, int64, SyncType) error {
}
// ReadFrom implements FileOperations.ReadFrom.
-func (*Inotify) ReadFrom(context.Context, *File, *File, SpliceOpts) (int64, error) {
+func (*Inotify) ReadFrom(context.Context, *File, io.Reader, int64) (int64, error) {
return 0, syserror.ENOSYS
}
diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD
index 08d7c0c57..5a7a5b8cd 100644
--- a/pkg/sentry/fs/lock/BUILD
+++ b/pkg/sentry/fs/lock/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "lock_range",
diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go
index 9b713e785..ac0398bd9 100644
--- a/pkg/sentry/fs/mounts.go
+++ b/pkg/sentry/fs/mounts.go
@@ -171,8 +171,6 @@ type MountNamespace struct {
// NewMountNamespace returns a new MountNamespace, with the provided node at the
// root, and the given cache size. A root must always be provided.
func NewMountNamespace(ctx context.Context, root *Inode) (*MountNamespace, error) {
- creds := auth.CredentialsFromContext(ctx)
-
// Set the root dirent and id on the root mount. The reference returned from
// NewDirent will be donated to the MountNamespace constructed below.
d := NewDirent(ctx, root, "/")
@@ -181,6 +179,7 @@ func NewMountNamespace(ctx context.Context, root *Inode) (*MountNamespace, error
d: newRootMount(1, d),
}
+ creds := auth.CredentialsFromContext(ctx)
mns := MountNamespace{
userns: creds.UserNamespace,
root: d,
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD
index 70ed854a8..1c93e8886 100644
--- a/pkg/sentry/fs/proc/BUILD
+++ b/pkg/sentry/fs/proc/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "proc",
@@ -31,7 +33,6 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/log",
"//pkg/sentry/context",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index 9adb23608..f70239449 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -17,10 +17,10 @@ package proc
import (
"bytes"
"fmt"
+ "io"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -28,9 +28,11 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
)
// newNet creates a new proc net entry.
@@ -57,15 +59,14 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo
"ptype": newStaticProcInode(ctx, msrc, []byte("Type Device Function")),
"route": newStaticProcInode(ctx, msrc, []byte("Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT")),
"tcp": seqfile.NewSeqFileInode(ctx, &netTCP{k: k}, msrc),
- "udp": newStaticProcInode(ctx, msrc, []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode ref pointer drops")),
-
- "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc),
+ "udp": seqfile.NewSeqFileInode(ctx, &netUDP{k: k}, msrc),
+ "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc),
}
if s.SupportsIPv6() {
contents["if_inet6"] = seqfile.NewSeqFileInode(ctx, &ifinet6{s: s}, msrc)
contents["ipv6_route"] = newStaticProcInode(ctx, msrc, []byte(""))
- contents["tcp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode"))
+ contents["tcp6"] = seqfile.NewSeqFileInode(ctx, &netTCP6{k: k}, msrc)
contents["udp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode"))
}
}
@@ -216,7 +217,7 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s
for _, se := range n.k.ListSockets() {
s := se.Sock.Get()
if s == nil {
- log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", se.Sock)
+ log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID)
continue
}
sfile := s.(*fs.File)
@@ -297,20 +298,66 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s
return data, 0
}
-// netTCP implements seqfile.SeqSource for /proc/net/tcp.
-//
-// +stateify savable
-type netTCP struct {
- k *kernel.Kernel
+func networkToHost16(n uint16) uint16 {
+ // n is in network byte order, so is big-endian. The most-significant byte
+ // should be stored in the lower address.
+ //
+ // We manually inline binary.BigEndian.PutUint16() because Go does not
+ // support non-primitive consts: binary.BigEndian is a (mutable) var, so
+ // calling binary.BigEndian.PutUint16() requires a read of binary.BigEndian
+ // and an interface method call, defeating inlining.
+ buf := [2]byte{byte(n >> 8 & 0xff), byte(n & 0xff)}
+ return usermem.ByteOrder.Uint16(buf[:])
}
-// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
-func (*netTCP) NeedsUpdate(generation int64) bool {
- return true
+func writeInetAddr(w io.Writer, family int, i linux.SockAddr) {
+ switch family {
+ case linux.AF_INET:
+ var a linux.SockAddrInet
+ if i != nil {
+ a = *i.(*linux.SockAddrInet)
+ }
+
+ // linux.SockAddrInet.Port is stored in the network byte order and is
+ // printed like a number in host byte order. Note that all numbers in host
+ // byte order are printed with the most-significant byte first when
+ // formatted with %X. See get_tcp4_sock() and udp4_format_sock() in Linux.
+ port := networkToHost16(a.Port)
+
+ // linux.SockAddrInet.Addr is stored as a byte slice in big-endian order
+ // (i.e. most-significant byte in index 0). Linux represents this as a
+ // __be32 which is a typedef for an unsigned int, and is printed with
+ // %X. This means that for a little-endian machine, Linux prints the
+ // least-significant byte of the address first. To emulate this, we first
+ // invert the byte order for the address using usermem.ByteOrder.Uint32,
+ // which makes it have the equivalent encoding to a __be32 on a little
+ // endian machine. Note that this operation is a no-op on a big endian
+ // machine. Then similar to Linux, we format it with %X, which will print
+ // the most-significant byte of the __be32 address first, which is now
+ // actually the least-significant byte of the original address in
+ // linux.SockAddrInet.Addr on little endian machines, due to the conversion.
+ addr := usermem.ByteOrder.Uint32(a.Addr[:])
+
+ fmt.Fprintf(w, "%08X:%04X ", addr, port)
+ case linux.AF_INET6:
+ var a linux.SockAddrInet6
+ if i != nil {
+ a = *i.(*linux.SockAddrInet6)
+ }
+
+ port := networkToHost16(a.Port)
+ addr0 := usermem.ByteOrder.Uint32(a.Addr[0:4])
+ addr1 := usermem.ByteOrder.Uint32(a.Addr[4:8])
+ addr2 := usermem.ByteOrder.Uint32(a.Addr[8:12])
+ addr3 := usermem.ByteOrder.Uint32(a.Addr[12:16])
+ fmt.Fprintf(w, "%08X%08X%08X%08X:%04X ", addr0, addr1, addr2, addr3, port)
+ }
}
-// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
-func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+func commonReadSeqFileDataTCP(ctx context.Context, n seqfile.SeqHandle, k *kernel.Kernel, h seqfile.SeqHandle, fa int, header []byte) ([]seqfile.SeqData, int64) {
+ // t may be nil here if our caller is not part of a task goroutine. This can
+ // happen for example if we're here for "sentryctl cat". When t is nil,
+ // degrade gracefully and retrieve what we can.
t := kernel.TaskFromContext(ctx)
if h != nil {
@@ -318,10 +365,10 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
}
var buf bytes.Buffer
- for _, se := range n.k.ListSockets() {
+ for _, se := range k.ListSockets() {
s := se.Sock.Get()
if s == nil {
- log.Debugf("Couldn't resolve weakref %+v in socket table, racing with destruction?", se.Sock)
+ log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID)
continue
}
sfile := s.(*fs.File)
@@ -329,7 +376,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
if !ok {
panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile))
}
- if family, stype, _ := sops.Type(); !(family == linux.AF_INET && stype == linux.SOCK_STREAM) {
+ if family, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) {
s.DecRef()
// Not tcp4 sockets.
continue
@@ -343,27 +390,23 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
// Field: sl; entry number.
fmt.Fprintf(&buf, "%4d: ", se.ID)
- portBuf := make([]byte, 2)
-
// Field: local_adddress.
- var localAddr linux.SockAddrInet
- if local, _, err := sops.GetSockName(t); err == nil {
- localAddr = *local.(*linux.SockAddrInet)
+ var localAddr linux.SockAddr
+ if t != nil {
+ if local, _, err := sops.GetSockName(t); err == nil {
+ localAddr = local
+ }
}
- binary.LittleEndian.PutUint16(portBuf, localAddr.Port)
- fmt.Fprintf(&buf, "%08X:%04X ",
- binary.LittleEndian.Uint32(localAddr.Addr[:]),
- portBuf)
+ writeInetAddr(&buf, fa, localAddr)
// Field: rem_address.
- var remoteAddr linux.SockAddrInet
- if remote, _, err := sops.GetPeerName(t); err == nil {
- remoteAddr = *remote.(*linux.SockAddrInet)
+ var remoteAddr linux.SockAddr
+ if t != nil {
+ if remote, _, err := sops.GetPeerName(t); err == nil {
+ remoteAddr = remote
+ }
}
- binary.LittleEndian.PutUint16(portBuf, remoteAddr.Port)
- fmt.Fprintf(&buf, "%08X:%04X ",
- binary.LittleEndian.Uint32(remoteAddr.Addr[:]),
- portBuf)
+ writeInetAddr(&buf, fa, remoteAddr)
// Field: state; socket state.
fmt.Fprintf(&buf, "%02X ", sops.State())
@@ -386,7 +429,8 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
log.Warningf("Failed to retrieve unstable attr for socket file: %v", err)
fmt.Fprintf(&buf, "%5d ", 0)
} else {
- fmt.Fprintf(&buf, "%5d ", uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()))
+ creds := auth.CredentialsFromContext(ctx)
+ fmt.Fprintf(&buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow()))
}
// Field: timeout; number of unanswered 0-window probes.
@@ -428,7 +472,165 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
data := []seqfile.SeqData{
{
- Buf: []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n"),
+ Buf: header,
+ Handle: n,
+ },
+ {
+ Buf: buf.Bytes(),
+ Handle: n,
+ },
+ }
+ return data, 0
+}
+
+// netTCP implements seqfile.SeqSource for /proc/net/tcp.
+//
+// +stateify savable
+type netTCP struct {
+ k *kernel.Kernel
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*netTCP) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ header := []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n")
+ return commonReadSeqFileDataTCP(ctx, n, n.k, h, linux.AF_INET, header)
+}
+
+// netTCP6 implements seqfile.SeqSource for /proc/net/tcp6.
+//
+// +stateify savable
+type netTCP6 struct {
+ k *kernel.Kernel
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*netTCP6) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (n *netTCP6) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ header := []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n")
+ return commonReadSeqFileDataTCP(ctx, n, n.k, h, linux.AF_INET6, header)
+}
+
+// netUDP implements seqfile.SeqSource for /proc/net/udp.
+//
+// +stateify savable
+type netUDP struct {
+ k *kernel.Kernel
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*netUDP) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ // t may be nil here if our caller is not part of a task goroutine. This can
+ // happen for example if we're here for "sentryctl cat". When t is nil,
+ // degrade gracefully and retrieve what we can.
+ t := kernel.TaskFromContext(ctx)
+
+ if h != nil {
+ return nil, 0
+ }
+
+ var buf bytes.Buffer
+ for _, se := range n.k.ListSockets() {
+ s := se.Sock.Get()
+ if s == nil {
+ log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID)
+ continue
+ }
+ sfile := s.(*fs.File)
+ sops, ok := sfile.FileOperations.(socket.Socket)
+ if !ok {
+ panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile))
+ }
+ if family, stype, _ := sops.Type(); family != linux.AF_INET || stype != linux.SOCK_DGRAM {
+ s.DecRef()
+ // Not udp4 socket.
+ continue
+ }
+
+ // For Linux's implementation, see net/ipv4/udp.c:udp4_format_sock().
+
+ // Field: sl; entry number.
+ fmt.Fprintf(&buf, "%5d: ", se.ID)
+
+ // Field: local_address.
+ var localAddr linux.SockAddrInet
+ if t != nil {
+ if local, _, err := sops.GetSockName(t); err == nil {
+ localAddr = *local.(*linux.SockAddrInet)
+ }
+ }
+ writeInetAddr(&buf, linux.AF_INET, &localAddr)
+
+ // Field: rem_address.
+ var remoteAddr linux.SockAddrInet
+ if t != nil {
+ if remote, _, err := sops.GetPeerName(t); err == nil {
+ remoteAddr = *remote.(*linux.SockAddrInet)
+ }
+ }
+ writeInetAddr(&buf, linux.AF_INET, &remoteAddr)
+
+ // Field: state; socket state.
+ fmt.Fprintf(&buf, "%02X ", sops.State())
+
+ // Field: tx_queue, rx_queue; number of packets in the transmit and
+ // receive queue. Unimplemented.
+ fmt.Fprintf(&buf, "%08X:%08X ", 0, 0)
+
+ // Field: tr, tm->when. Always 0 for UDP.
+ fmt.Fprintf(&buf, "%02X:%08X ", 0, 0)
+
+ // Field: retrnsmt. Always 0 for UDP.
+ fmt.Fprintf(&buf, "%08X ", 0)
+
+ // Field: uid.
+ uattr, err := sfile.Dirent.Inode.UnstableAttr(ctx)
+ if err != nil {
+ log.Warningf("Failed to retrieve unstable attr for socket file: %v", err)
+ fmt.Fprintf(&buf, "%5d ", 0)
+ } else {
+ creds := auth.CredentialsFromContext(ctx)
+ fmt.Fprintf(&buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow()))
+ }
+
+ // Field: timeout. Always 0 for UDP.
+ fmt.Fprintf(&buf, "%8d ", 0)
+
+ // Field: inode.
+ fmt.Fprintf(&buf, "%8d ", sfile.InodeID())
+
+ // Field: ref; reference count on the socket inode. Don't count the ref
+ // we obtain while dereferencing the weakref to this socket.
+ fmt.Fprintf(&buf, "%d ", sfile.ReadRefs()-1)
+
+ // Field: Socket struct address. Redacted due to the same reason as
+ // the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData.
+ fmt.Fprintf(&buf, "%#016p ", (*socket.Socket)(nil))
+
+ // Field: drops; number of dropped packets. Unimplemented.
+ fmt.Fprintf(&buf, "%d", 0)
+
+ fmt.Fprintf(&buf, "\n")
+
+ s.DecRef()
+ }
+
+ data := []seqfile.SeqData{
+ {
+ Buf: []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode ref pointer drops \n"),
Handle: n,
},
{
diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go
index 0ef13f2f5..56e92721e 100644
--- a/pkg/sentry/fs/proc/proc.go
+++ b/pkg/sentry/fs/proc/proc.go
@@ -230,7 +230,7 @@ func (rpf *rootProcFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dent
// But for whatever crazy reason, you can still walk to the given node.
for _, tg := range rpf.iops.pidns.ThreadGroups() {
if leader := tg.Leader(); leader != nil {
- name := strconv.FormatUint(uint64(tg.ID()), 10)
+ name := strconv.FormatUint(uint64(rpf.iops.pidns.IDOfThreadGroup(tg)), 10)
m[name] = fs.GenericDentAttr(fs.SpecialDirectory, device.ProcDevice)
names = append(names, name)
}
diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD
index 20c3eefc8..76433c7d0 100644
--- a/pkg/sentry/fs/proc/seqfile/BUILD
+++ b/pkg/sentry/fs/proc/seqfile/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "seqfile",
diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD
index 516efcc4c..d0f351e5a 100644
--- a/pkg/sentry/fs/ramfs/BUILD
+++ b/pkg/sentry/fs/ramfs/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "ramfs",
diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go
index eed1c2854..311798811 100644
--- a/pkg/sentry/fs/splice.go
+++ b/pkg/sentry/fs/splice.go
@@ -18,7 +18,6 @@ import (
"io"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/secio"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -33,146 +32,131 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
}
// Check whether or not the objects being sliced are stream-oriented
- // (i.e. pipes or sockets). If yes, we elide checks and offset locks.
- srcPipe := IsPipe(src.Dirent.Inode.StableAttr) || IsSocket(src.Dirent.Inode.StableAttr)
- dstPipe := IsPipe(dst.Dirent.Inode.StableAttr) || IsSocket(dst.Dirent.Inode.StableAttr)
+ // (i.e. pipes or sockets). For all stream-oriented files and files
+ // where a specific offset is not requested, we acquire the file mutex.
+ // This has two important side effects. First, it provides the standard
+ // protection against concurrent writes that would mutate the offset.
+ // Second, it prevents Splice deadlocks. Only internal anonymous files
+ // implement the ReadFrom and WriteTo methods directly, and since such
+ // anonymous files are referred to by a unique fs.File object, we know
+ // that the file mutex takes strict precedence over internal locks.
+ // Since we enforce lock ordering here, we can't deadlock by using
+ // a file in two different splice operations simultaneously.
+ srcPipe := !IsRegular(src.Dirent.Inode.StableAttr)
+ dstPipe := !IsRegular(dst.Dirent.Inode.StableAttr)
+ dstAppend := !dstPipe && dst.Flags().Append
+ srcLock := srcPipe || !opts.SrcOffset
+ dstLock := dstPipe || !opts.DstOffset || dstAppend
- if !dstPipe && !opts.DstOffset && !srcPipe && !opts.SrcOffset {
+ switch {
+ case srcLock && dstLock:
switch {
case dst.UniqueID < src.UniqueID:
// Acquire dst first.
if !dst.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer dst.mu.Unlock()
if !src.mu.Lock(ctx) {
+ dst.mu.Unlock()
return 0, syserror.ErrInterrupted
}
- defer src.mu.Unlock()
case dst.UniqueID > src.UniqueID:
// Acquire src first.
if !src.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer src.mu.Unlock()
if !dst.mu.Lock(ctx) {
+ src.mu.Unlock()
return 0, syserror.ErrInterrupted
}
- defer dst.mu.Unlock()
case dst.UniqueID == src.UniqueID:
// Acquire only one lock; it's the same file. This is a
// bit of a edge case, but presumably it's possible.
if !dst.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer dst.mu.Unlock()
+ srcLock = false // Only need one unlock.
}
// Use both offsets (locked).
opts.DstStart = dst.offset
opts.SrcStart = src.offset
- } else if !dstPipe && !opts.DstOffset {
+ case dstLock:
// Acquire only dst.
if !dst.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer dst.mu.Unlock()
opts.DstStart = dst.offset // Safe: locked.
- } else if !srcPipe && !opts.SrcOffset {
+ case srcLock:
// Acquire only src.
if !src.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer src.mu.Unlock()
opts.SrcStart = src.offset // Safe: locked.
}
- // Check append-only mode and the limit.
- if !dstPipe {
+ var err error
+ if dstAppend {
unlock := dst.Dirent.Inode.lockAppendMu(dst.Flags().Append)
defer unlock()
- if dst.Flags().Append {
- if opts.DstOffset {
- // We need to acquire the lock.
- if !dst.mu.Lock(ctx) {
- return 0, syserror.ErrInterrupted
- }
- defer dst.mu.Unlock()
- }
- // Figure out the appropriate offset to use.
- if err := dst.offsetForAppend(ctx, &opts.DstStart); err != nil {
- return 0, err
- }
- }
+ // Figure out the appropriate offset to use.
+ err = dst.offsetForAppend(ctx, &opts.DstStart)
+ }
+ if err == nil && !dstPipe {
// Enforce file limits.
limit, ok := dst.checkLimit(ctx, opts.DstStart)
switch {
case ok && limit == 0:
- return 0, syserror.ErrExceedsFileSizeLimit
+ err = syserror.ErrExceedsFileSizeLimit
case ok && limit < opts.Length:
opts.Length = limit // Cap the write.
}
}
+ if err != nil {
+ if dstLock {
+ dst.mu.Unlock()
+ }
+ if srcLock {
+ src.mu.Unlock()
+ }
+ return 0, err
+ }
- // Attempt to do a WriteTo; this is likely the most efficient.
- //
- // The underlying implementation may be able to donate buffers.
- newOpts := SpliceOpts{
- Length: opts.Length,
- SrcStart: opts.SrcStart,
- SrcOffset: !srcPipe,
- Dup: opts.Dup,
- DstStart: opts.DstStart,
- DstOffset: !dstPipe,
+ // Construct readers and writers for the splice. This is used to
+ // provide a safer locking path for the WriteTo/ReadFrom operations
+ // (since they will otherwise go through public interface methods which
+ // conflict with locking done above), and simplifies the fallback path.
+ w := &lockedWriter{
+ Ctx: ctx,
+ File: dst,
+ Offset: opts.DstStart,
}
- n, err := src.FileOperations.WriteTo(ctx, src, dst, newOpts)
- if n == 0 && err != nil {
- // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also
- // be more efficient than a copy if buffers are cached or readily
- // available. (It's unlikely that they can actually be donate
- n, err = dst.FileOperations.ReadFrom(ctx, dst, src, newOpts)
+ r := &lockedReader{
+ Ctx: ctx,
+ File: src,
+ Offset: opts.SrcStart,
}
- if n == 0 && err != nil {
- // If we've failed up to here, and at least one of the sources
- // is a pipe or socket, then we can't properly support dup.
- // Return an error indicating that this operation is not
- // supported.
- if (srcPipe || dstPipe) && newOpts.Dup {
- return 0, syserror.EINVAL
- }
- // We failed to splice the files. But that's fine; we just fall
- // back to a slow path in this case. This copies without doing
- // any mode changes, so should still be more efficient.
- var (
- r io.Reader
- w io.Writer
- )
- fw := &lockedWriter{
- Ctx: ctx,
- File: dst,
- }
- if newOpts.DstOffset {
- // Use the provided offset.
- w = secio.NewOffsetWriter(fw, newOpts.DstStart)
- } else {
- // Writes will proceed with no offset.
- w = fw
- }
- fr := &lockedReader{
- Ctx: ctx,
- File: src,
- }
- if newOpts.SrcOffset {
- // Limit to the given offset and length.
- r = io.NewSectionReader(fr, opts.SrcStart, opts.Length)
- } else {
- // Limit just to the given length.
- r = &io.LimitedReader{fr, opts.Length}
- }
+ // Attempt to do a WriteTo; this is likely the most efficient.
+ n, err := src.FileOperations.WriteTo(ctx, src, w, opts.Length, opts.Dup)
+ if n == 0 && err == syserror.ENOSYS && !opts.Dup {
+ // Attempt as a ReadFrom. Like a WriteTo, a ReadFrom may also be
+ // more efficient than a copy if buffers are cached or readily
+ // available. (It's unlikely that they can actually be donated).
+ n, err = dst.FileOperations.ReadFrom(ctx, dst, r, opts.Length)
+ }
- // Copy between the two.
- n, err = io.Copy(w, r)
+ // Support one last fallback option, but only if at least one of
+ // the source and destination are regular files. This is because
+ // if we block at some point, we could lose data. If the source is
+ // not a pipe then reading is not destructive; if the destination
+ // is a regular file, then it is guaranteed not to block writing.
+ if n == 0 && err == syserror.ENOSYS && !opts.Dup && (!dstPipe || !srcPipe) {
+ // Fallback to an in-kernel copy.
+ n, err = io.Copy(w, &io.LimitedReader{
+ R: r,
+ N: opts.Length,
+ })
}
// Update offsets, if required.
@@ -185,5 +169,13 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
}
}
+ // Drop locks.
+ if dstLock {
+ dst.mu.Unlock()
+ }
+ if srcLock {
+ src.mu.Unlock()
+ }
+
return n, err
}
diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go
index 59403d9db..f8bf663bb 100644
--- a/pkg/sentry/fs/timerfd/timerfd.go
+++ b/pkg/sentry/fs/timerfd/timerfd.go
@@ -141,9 +141,10 @@ func (t *TimerOperations) Write(context.Context, *fs.File, usermem.IOSequence, i
}
// Notify implements ktime.TimerListener.Notify.
-func (t *TimerOperations) Notify(exp uint64) {
+func (t *TimerOperations) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
atomic.AddUint64(&t.val, exp)
t.events.Notify(waiter.EventIn)
+ return ktime.Setting{}, false
}
// Destroy implements ktime.TimerListener.Destroy.
diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD
index 8f7eb5757..11b680929 100644
--- a/pkg/sentry/fs/tmpfs/BUILD
+++ b/pkg/sentry/fs/tmpfs/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "tmpfs",
diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go
index 159fb7c08..69089c8a8 100644
--- a/pkg/sentry/fs/tmpfs/tmpfs.go
+++ b/pkg/sentry/fs/tmpfs/tmpfs.go
@@ -324,7 +324,7 @@ type Fifo struct {
// NewFifo creates a new named pipe.
func NewFifo(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode {
// First create a pipe.
- p := pipe.NewPipe(ctx, true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)
+ p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)
// Build pipe InodeOperations.
iops := pipe.NewInodeOperations(ctx, perms, p)
diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD
index 291164986..25811f668 100644
--- a/pkg/sentry/fs/tty/BUILD
+++ b/pkg/sentry/fs/tty/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "tty",
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index a41101339..b0c286b7a 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
go_template_instance(
@@ -79,7 +81,7 @@ go_test(
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
"//pkg/syserror",
- "//runsc/test/testutil",
+ "//runsc/testutil",
"@com_github_google_go-cmp//cmp:go_default_library",
"@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
],
diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD
index 9fddb4c4c..bfc46dfa6 100644
--- a/pkg/sentry/fsimpl/ext/benchmark/BUILD
+++ b/pkg/sentry/fsimpl/ext/benchmark/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_test")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index b51f3e18d..91802dc1e 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -190,10 +190,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
}
if !cb.Handle(vfs.Dirent{
- Name: child.diskDirent.FileName(),
- Type: fs.ToDirentType(childType),
- Ino: uint64(child.diskDirent.Inode()),
- Off: fd.off,
+ Name: child.diskDirent.FileName(),
+ Type: fs.ToDirentType(childType),
+ Ino: uint64(child.diskDirent.Inode()),
+ NextOff: fd.off + 1,
}) {
dir.childList.InsertBefore(child, fd.iter)
return nil
@@ -301,8 +301,8 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in
return offset, nil
}
-// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
-func (fd *directoryFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error {
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *directoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
// mmap(2) specifies that EACCESS should be returned for non-regular file fds.
return syserror.EACCES
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD
index 907d35b7e..2d50e30aa 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/BUILD
+++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "disklayout",
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
index 49b57a2d6..1aa2bd6a4 100644
--- a/pkg/sentry/fsimpl/ext/ext_test.go
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -33,7 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
const (
@@ -584,7 +584,7 @@ func TestIterDirents(t *testing.T) {
// Ignore the inode number and offset of dirents because those are likely to
// change as the underlying image changes.
cmpIgnoreFields := cmp.FilterPath(func(p cmp.Path) bool {
- return p.String() == "Ino" || p.String() == "Off"
+ return p.String() == "Ino" || p.String() == "NextOff"
}, cmp.Ignore())
if diff := cmp.Diff(cb.dirents, test.want, cmpIgnoreFields); diff != "" {
t.Errorf("dirents mismatch (-want +got):\n%s", diff)
diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go
index a0065343b..4d18b28cb 100644
--- a/pkg/sentry/fsimpl/ext/file_description.go
+++ b/pkg/sentry/fsimpl/ext/file_description.go
@@ -43,9 +43,6 @@ func (fd *fileDescription) inode() *inode {
return fd.vfsfd.VirtualDentry().Dentry().Impl().(*dentry).inode
}
-// OnClose implements vfs.FileDescriptionImpl.OnClose.
-func (fd *fileDescription) OnClose() error { return nil }
-
// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags.
func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) {
return fd.flags, nil
diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go
index ffc76ba5b..aec33e00a 100644
--- a/pkg/sentry/fsimpl/ext/regular_file.go
+++ b/pkg/sentry/fsimpl/ext/regular_file.go
@@ -152,8 +152,8 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (
return offset, nil
}
-// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
-func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error {
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
// TODO(b/134676337): Implement mmap(2).
return syserror.ENODEV
}
diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go
index e06548a98..bdf8705c1 100644
--- a/pkg/sentry/fsimpl/ext/symlink.go
+++ b/pkg/sentry/fsimpl/ext/symlink.go
@@ -105,7 +105,7 @@ func (fd *symlinkFD) Seek(ctx context.Context, offset int64, whence int32) (int6
return 0, syserror.EBADF
}
-// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
-func (fd *symlinkFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error {
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *symlinkFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
return syserror.EBADF
}
diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/memfs/BUILD
index d2450e810..7e364c5fd 100644
--- a/pkg/sentry/fsimpl/memfs/BUILD
+++ b/pkg/sentry/fsimpl/memfs/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/memfs/directory.go
index c52dc781c..0bd82e480 100644
--- a/pkg/sentry/fsimpl/memfs/directory.go
+++ b/pkg/sentry/fsimpl/memfs/directory.go
@@ -32,7 +32,7 @@ type directory struct {
childList dentryList
}
-func (fs *filesystem) newDirectory(creds *auth.Credentials, mode uint16) *inode {
+func (fs *filesystem) newDirectory(creds *auth.Credentials, mode linux.FileMode) *inode {
dir := &directory{}
dir.inode.init(dir, fs, creds, mode)
dir.inode.nlink = 2 // from "." and parent directory or ".." for root
@@ -75,10 +75,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
if fd.off == 0 {
if !cb.Handle(vfs.Dirent{
- Name: ".",
- Type: linux.DT_DIR,
- Ino: vfsd.Impl().(*dentry).inode.ino,
- Off: 0,
+ Name: ".",
+ Type: linux.DT_DIR,
+ Ino: vfsd.Impl().(*dentry).inode.ino,
+ NextOff: 1,
}) {
return nil
}
@@ -87,10 +87,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
if fd.off == 1 {
parentInode := vfsd.ParentOrSelf().Impl().(*dentry).inode
if !cb.Handle(vfs.Dirent{
- Name: "..",
- Type: parentInode.direntType(),
- Ino: parentInode.ino,
- Off: 1,
+ Name: "..",
+ Type: parentInode.direntType(),
+ Ino: parentInode.ino,
+ NextOff: 2,
}) {
return nil
}
@@ -112,10 +112,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
// Skip other directoryFD iterators.
if child.inode != nil {
if !cb.Handle(vfs.Dirent{
- Name: child.vfsd.Name(),
- Type: child.inode.direntType(),
- Ino: child.inode.ino,
- Off: fd.off,
+ Name: child.vfsd.Name(),
+ Type: child.inode.direntType(),
+ Ino: child.inode.ino,
+ NextOff: fd.off + 1,
}) {
dir.childList.InsertBefore(child, fd.iter)
return nil
diff --git a/pkg/sentry/fsimpl/memfs/memfs.go b/pkg/sentry/fsimpl/memfs/memfs.go
index 45cd42b3e..b78471c0f 100644
--- a/pkg/sentry/fsimpl/memfs/memfs.go
+++ b/pkg/sentry/fsimpl/memfs/memfs.go
@@ -137,7 +137,7 @@ type inode struct {
impl interface{} // immutable
}
-func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode uint16) {
+func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode linux.FileMode) {
i.refs = 1
i.mode = uint32(mode)
i.uid = uint32(creds.EffectiveKUID)
diff --git a/pkg/sentry/fsimpl/memfs/regular_file.go b/pkg/sentry/fsimpl/memfs/regular_file.go
index 55f869798..b7f4853b3 100644
--- a/pkg/sentry/fsimpl/memfs/regular_file.go
+++ b/pkg/sentry/fsimpl/memfs/regular_file.go
@@ -37,7 +37,7 @@ type regularFile struct {
dataLen int64
}
-func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode uint16) *inode {
+func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode {
file := &regularFile{}
file.inode.init(file, fs, creds, mode)
file.inode.nlink = 1 // from parent directory
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index 3d8a4deaf..ade6ac946 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/hostcpu/BUILD b/pkg/sentry/hostcpu/BUILD
index f989f2f8b..359468ccc 100644
--- a/pkg/sentry/hostcpu/BUILD
+++ b/pkg/sentry/hostcpu/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -6,6 +7,7 @@ go_library(
name = "hostcpu",
srcs = [
"getcpu_amd64.s",
+ "getcpu_arm64.s",
"hostcpu.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/hostcpu",
diff --git a/pkg/sentry/hostcpu/getcpu_arm64.s b/pkg/sentry/hostcpu/getcpu_arm64.s
new file mode 100644
index 000000000..caf9abb89
--- /dev/null
+++ b/pkg/sentry/hostcpu/getcpu_arm64.s
@@ -0,0 +1,28 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// GetCPU makes the getcpu(unsigned *cpu, unsigned *node, NULL) syscall, since
+// arm64 lacks an optimized way of getting the current CPU number.
+
+// func GetCPU() (cpu uint32)
+TEXT ·GetCPU(SB), NOSPLIT, $0-4
+ MOVW ZR, cpu+0(FP)
+ MOVD $cpu+0(FP), R0
+ MOVD $0x0, R1 // unused
+ MOVD $0x0, R2 // unused
+ MOVD $0xA8, R8 // SYS_GETCPU
+ SVC
+ RET
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index 41bee9a22..aba2414d4 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -1,9 +1,11 @@
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "pending_signals_list",
@@ -83,6 +85,12 @@ proto_library(
deps = ["//pkg/sentry/arch:registers_proto"],
)
+cc_proto_library(
+ name = "uncaught_signal_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":uncaught_signal_proto"],
+)
+
go_proto_library(
name = "uncaught_signal_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto",
diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD
index f46c43128..65427b112 100644
--- a/pkg/sentry/kernel/epoll/BUILD
+++ b/pkg/sentry/kernel/epoll/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "epoll_list",
diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD
index 1c5f979d4..983ca67ed 100644
--- a/pkg/sentry/kernel/eventfd/BUILD
+++ b/pkg/sentry/kernel/eventfd/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "eventfd",
diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD
index 6a31dc044..41f44999c 100644
--- a/pkg/sentry/kernel/futex/BUILD
+++ b/pkg/sentry/kernel/futex/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "atomicptr_bucket",
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 8c1f79ab5..3cda03891 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -24,6 +24,7 @@
// TaskSet.mu
// SignalHandlers.mu
// Task.mu
+// runningTasksMu
//
// Locking SignalHandlers.mu in multiple SignalHandlers requires locking
// TaskSet.mu exclusively first. Locking Task.mu in multiple Tasks at the same
@@ -135,6 +136,22 @@ type Kernel struct {
// syslog is the kernel log.
syslog syslog
+ // runningTasksMu synchronizes disable/enable of cpuClockTicker when
+ // the kernel is idle (runningTasks == 0).
+ //
+ // runningTasksMu is used to exclude critical sections when the timer
+ // disables itself and when the first active task enables the timer,
+ // ensuring that tasks always see a valid cpuClock value.
+ runningTasksMu sync.Mutex `state:"nosave"`
+
+ // runningTasks is the total count of tasks currently in
+ // TaskGoroutineRunningSys or TaskGoroutineRunningApp. i.e., they are
+ // not blocked or stopped.
+ //
+ // runningTasks must be accessed atomically. Increments from 0 to 1 are
+ // further protected by runningTasksMu (see incRunningTasks).
+ runningTasks int64
+
// cpuClock is incremented every linux.ClockTick. cpuClock is used to
// measure task CPU usage, since sampling monotonicClock twice on every
// syscall turns out to be unreasonably expensive. This is similar to how
@@ -150,6 +167,22 @@ type Kernel struct {
// cpuClockTicker increments cpuClock.
cpuClockTicker *ktime.Timer `state:"nosave"`
+ // cpuClockTickerDisabled indicates that cpuClockTicker has been
+ // disabled because no tasks are running.
+ //
+ // cpuClockTickerDisabled is protected by runningTasksMu.
+ cpuClockTickerDisabled bool
+
+ // cpuClockTickerSetting is the ktime.Setting of cpuClockTicker at the
+ // point it was disabled. It is cached here to avoid a lock ordering
+ // violation with cpuClockTicker.mu when runningTasksMu is held.
+ //
+ // cpuClockTickerSetting is only valid when cpuClockTickerDisabled is
+ // true.
+ //
+ // cpuClockTickerSetting is protected by runningTasksMu.
+ cpuClockTickerSetting ktime.Setting
+
// fdMapUids is an ever-increasing counter for generating FDTable uids.
//
// fdMapUids is mutable, and is accessed using atomic memory operations.
@@ -912,6 +945,102 @@ func (k *Kernel) resumeTimeLocked() {
}
}
+func (k *Kernel) incRunningTasks() {
+ for {
+ tasks := atomic.LoadInt64(&k.runningTasks)
+ if tasks != 0 {
+ // Standard case. Simply increment.
+ if !atomic.CompareAndSwapInt64(&k.runningTasks, tasks, tasks+1) {
+ continue
+ }
+ return
+ }
+
+ // Transition from 0 -> 1. Synchronize with other transitions and timer.
+ k.runningTasksMu.Lock()
+ tasks = atomic.LoadInt64(&k.runningTasks)
+ if tasks != 0 {
+ // We're no longer the first task, no need to
+ // re-enable.
+ atomic.AddInt64(&k.runningTasks, 1)
+ k.runningTasksMu.Unlock()
+ return
+ }
+
+ if !k.cpuClockTickerDisabled {
+ // Timer was never disabled.
+ atomic.StoreInt64(&k.runningTasks, 1)
+ k.runningTasksMu.Unlock()
+ return
+ }
+
+ // We need to update cpuClock for all of the ticks missed while we
+ // slept, and then re-enable the timer.
+ //
+ // The Notify in Swap isn't sufficient. kernelCPUClockTicker.Notify
+ // always increments cpuClock by 1 regardless of the number of
+ // expirations as a heuristic to avoid over-accounting in cases of CPU
+ // throttling.
+ //
+ // We want to cover the normal case, when all time should be accounted,
+ // so we increment for all expirations. Throttling is less concerning
+ // here because the ticker is only disabled from Notify. This means
+ // that Notify must schedule and compensate for the throttled period
+ // before the timer is disabled. Throttling while the timer is disabled
+ // doesn't matter, as nothing is running or reading cpuClock anyways.
+ //
+ // S/R also adds complication, as there are two cases. Recall that
+ // monotonicClock will jump forward on restore.
+ //
+ // 1. If the ticker is enabled during save, then on Restore Notify is
+ // called with many expirations, covering the time jump, but cpuClock
+ // is only incremented by 1.
+ //
+ // 2. If the ticker is disabled during save, then after Restore the
+ // first wakeup will call this function and cpuClock will be
+ // incremented by the number of expirations across the S/R.
+ //
+ // These cause very different values of cpuClock. But again, since
+ // nothing was running while the ticker was disabled, those differences
+ // don't matter.
+ setting, exp := k.cpuClockTickerSetting.At(k.monotonicClock.Now())
+ if exp > 0 {
+ atomic.AddUint64(&k.cpuClock, exp)
+ }
+
+ // Now that cpuClock is updated it is safe to allow other tasks to
+ // transition to running.
+ atomic.StoreInt64(&k.runningTasks, 1)
+
+ // N.B. we must unlock before calling Swap to maintain lock ordering.
+ //
+ // cpuClockTickerDisabled need not wait until after Swap to become
+ // false. It is sufficient that the timer *will* be enabled.
+ k.cpuClockTickerDisabled = false
+ k.runningTasksMu.Unlock()
+
+ // This won't call Notify (unless it's been ClockTick since setting.At
+ // above). This means we skip the thread group work in Notify. However,
+ // since nothing was running while we were disabled, none of the timers
+ // could have expired.
+ k.cpuClockTicker.Swap(setting)
+
+ return
+ }
+}
+
+func (k *Kernel) decRunningTasks() {
+ tasks := atomic.AddInt64(&k.runningTasks, -1)
+ if tasks < 0 {
+ panic(fmt.Sprintf("Invalid running count %d", tasks))
+ }
+
+ // Nothing to do. The next CPU clock tick will disable the timer if
+ // there is still nothing running. This provides approximately one tick
+ // of slack in which we can switch back and forth between idle and
+ // active without an expensive transition.
+}
+
// WaitExited blocks until all tasks in k have exited.
func (k *Kernel) WaitExited() {
k.tasks.liveGoroutines.Wait()
diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD
index ebcfaa619..d7a7d1169 100644
--- a/pkg/sentry/kernel/memevent/BUILD
+++ b/pkg/sentry/kernel/memevent/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -24,6 +25,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+cc_proto_library(
+ name = "memory_events_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":memory_events_proto"],
+)
+
go_proto_library(
name = "memory_events_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto",
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index 4d15cca85..2ce8952e2 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "buffer_list",
diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go
index 69ef2a720..95bee2d37 100644
--- a/pkg/sentry/kernel/pipe/buffer.go
+++ b/pkg/sentry/kernel/pipe/buffer.go
@@ -15,6 +15,7 @@
package pipe
import (
+ "io"
"sync"
"gvisor.dev/gvisor/pkg/sentry/safemem"
@@ -67,6 +68,17 @@ func (b *buffer) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
return n, err
}
+// WriteFromReader writes to the buffer from an io.Reader.
+func (b *buffer) WriteFromReader(r io.Reader, count int64) (int64, error) {
+ dst := b.data[b.write:]
+ if count < int64(len(dst)) {
+ dst = b.data[b.write:][:count]
+ }
+ n, err := r.Read(dst)
+ b.write += n
+ return int64(n), err
+}
+
// ReadToBlocks implements safemem.Reader.ReadToBlocks.
func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.read:b.write]))
@@ -75,6 +87,19 @@ func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
return n, err
}
+// ReadToWriter reads from the buffer into an io.Writer.
+func (b *buffer) ReadToWriter(w io.Writer, count int64, dup bool) (int64, error) {
+ src := b.data[b.read:b.write]
+ if count < int64(len(src)) {
+ src = b.data[b.read:][:count]
+ }
+ n, err := w.Write(src)
+ if !dup {
+ b.read += n
+ }
+ return int64(n), err
+}
+
// bufferPool is a pool for buffers.
var bufferPool = sync.Pool{
New: func() interface{} {
diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go
index adbad7764..16fa80abe 100644
--- a/pkg/sentry/kernel/pipe/node_test.go
+++ b/pkg/sentry/kernel/pipe/node_test.go
@@ -85,11 +85,11 @@ func testOpen(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs.
}
func newNamedPipe(t *testing.T) *Pipe {
- return NewPipe(contexttest.Context(t), true, DefaultPipeSize, usermem.PageSize)
+ return NewPipe(true, DefaultPipeSize, usermem.PageSize)
}
func newAnonPipe(t *testing.T) *Pipe {
- return NewPipe(contexttest.Context(t), false, DefaultPipeSize, usermem.PageSize)
+ return NewPipe(false, DefaultPipeSize, usermem.PageSize)
}
// assertRecvBlocks ensures that a recv attempt on c blocks for at least
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 247e2928e..8e4e8e82e 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -99,7 +98,7 @@ type Pipe struct {
// NewPipe initializes and returns a pipe.
//
// N.B. The size and atomicIOBytes will be bounded.
-func NewPipe(ctx context.Context, isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe {
+func NewPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe {
if sizeBytes < MinimumPipeSize {
sizeBytes = MinimumPipeSize
}
@@ -122,7 +121,7 @@ func NewPipe(ctx context.Context, isNamed bool, sizeBytes, atomicIOBytes int64)
// NewConnectedPipe initializes a pipe and returns a pair of objects
// representing the read and write ends of the pipe.
func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs.File, *fs.File) {
- p := NewPipe(ctx, false /* isNamed */, sizeBytes, atomicIOBytes)
+ p := NewPipe(false /* isNamed */, sizeBytes, atomicIOBytes)
// Build an fs.Dirent for the pipe which will be shared by both
// returned files.
@@ -173,13 +172,24 @@ func (p *Pipe) Open(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) *fs.F
}
}
+type readOps struct {
+ // left returns the bytes remaining.
+ left func() int64
+
+ // limit limits subsequent reads.
+ limit func(int64)
+
+ // read performs the actual read operation.
+ read func(*buffer) (int64, error)
+}
+
// read reads data from the pipe into dst and returns the number of bytes
// read, or returns ErrWouldBlock if the pipe is empty.
//
// Precondition: this pipe must have readers.
-func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) {
// Don't block for a zero-length read even if the pipe is empty.
- if dst.NumBytes() == 0 {
+ if ops.left() == 0 {
return 0, nil
}
@@ -196,12 +206,12 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error)
}
// Limit how much we consume.
- if dst.NumBytes() > p.size {
- dst = dst.TakeFirst64(p.size)
+ if ops.left() > p.size {
+ ops.limit(p.size)
}
done := int64(0)
- for dst.NumBytes() > 0 {
+ for ops.left() > 0 {
// Pop the first buffer.
first := p.data.Front()
if first == nil {
@@ -209,10 +219,9 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error)
}
// Copy user data.
- n, err := dst.CopyOutFrom(ctx, first)
+ n, err := ops.read(first)
done += int64(n)
p.size -= n
- dst = dst.DropFirst64(n)
// Empty buffer?
if first.Empty() {
@@ -230,12 +239,57 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error)
return done, nil
}
+// dup duplicates all data from this pipe into the given writer.
+//
+// There is no blocking behavior implemented here. The writer may propagate
+// some blocking error. All the writes must be complete writes.
+func (p *Pipe) dup(ctx context.Context, ops readOps) (int64, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ // Is the pipe empty?
+ if p.size == 0 {
+ if !p.HasWriters() {
+ // See above.
+ return 0, nil
+ }
+ return 0, syserror.ErrWouldBlock
+ }
+
+ // Limit how much we consume.
+ if ops.left() > p.size {
+ ops.limit(p.size)
+ }
+
+ done := int64(0)
+ for buf := p.data.Front(); buf != nil; buf = buf.Next() {
+ n, err := ops.read(buf)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ }
+
+ return done, nil
+}
+
+type writeOps struct {
+ // left returns the bytes remaining.
+ left func() int64
+
+ // limit should limit subsequent writes.
+ limit func(int64)
+
+ // write should write to the provided buffer.
+ write func(*buffer) (int64, error)
+}
+
// write writes data from sv into the pipe and returns the number of bytes
// written. If no bytes are written because the pipe is full (or has less than
// atomicIOBytes free capacity), write returns ErrWouldBlock.
//
// Precondition: this pipe must have writers.
-func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) {
+func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -246,17 +300,16 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error)
// POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be
// atomic, but requires no atomicity for writes larger than this.
- wanted := src.NumBytes()
+ wanted := ops.left()
if avail := p.max - p.size; wanted > avail {
if wanted <= p.atomicIOBytes {
return 0, syserror.ErrWouldBlock
}
- // Limit to the available capacity.
- src = src.TakeFirst64(avail)
+ ops.limit(avail)
}
done := int64(0)
- for src.NumBytes() > 0 {
+ for ops.left() > 0 {
// Need a new buffer?
last := p.data.Back()
if last == nil || last.Full() {
@@ -266,10 +319,9 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error)
}
// Copy user data.
- n, err := src.CopyInTo(ctx, last)
+ n, err := ops.write(last)
done += int64(n)
p.size += n
- src = src.DropFirst64(n)
// Handle errors.
if err != nil {
diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go
index f69dbf27b..7c307f013 100644
--- a/pkg/sentry/kernel/pipe/reader_writer.go
+++ b/pkg/sentry/kernel/pipe/reader_writer.go
@@ -15,6 +15,7 @@
package pipe
import (
+ "io"
"math"
"syscall"
@@ -55,7 +56,45 @@ func (rw *ReaderWriter) Release() {
// Read implements fs.FileOperations.Read.
func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
- n, err := rw.Pipe.read(ctx, dst)
+ n, err := rw.Pipe.read(ctx, readOps{
+ left: func() int64 {
+ return dst.NumBytes()
+ },
+ limit: func(l int64) {
+ dst = dst.TakeFirst64(l)
+ },
+ read: func(buf *buffer) (int64, error) {
+ n, err := dst.CopyOutFrom(ctx, buf)
+ dst = dst.DropFirst64(n)
+ return n, err
+ },
+ })
+ if n > 0 {
+ rw.Pipe.Notify(waiter.EventOut)
+ }
+ return n, err
+}
+
+// WriteTo implements fs.FileOperations.WriteTo.
+func (rw *ReaderWriter) WriteTo(ctx context.Context, _ *fs.File, w io.Writer, count int64, dup bool) (int64, error) {
+ ops := readOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ read: func(buf *buffer) (int64, error) {
+ n, err := buf.ReadToWriter(w, count, dup)
+ count -= n
+ return n, err
+ },
+ }
+ if dup {
+ // There is no notification for dup operations.
+ return rw.Pipe.dup(ctx, ops)
+ }
+ n, err := rw.Pipe.read(ctx, ops)
if n > 0 {
rw.Pipe.Notify(waiter.EventOut)
}
@@ -64,7 +103,40 @@ func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequ
// Write implements fs.FileOperations.Write.
func (rw *ReaderWriter) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
- n, err := rw.Pipe.write(ctx, src)
+ n, err := rw.Pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return src.NumBytes()
+ },
+ limit: func(l int64) {
+ src = src.TakeFirst64(l)
+ },
+ write: func(buf *buffer) (int64, error) {
+ n, err := src.CopyInTo(ctx, buf)
+ src = src.DropFirst64(n)
+ return n, err
+ },
+ })
+ if n > 0 {
+ rw.Pipe.Notify(waiter.EventIn)
+ }
+ return n, err
+}
+
+// ReadFrom implements fs.FileOperations.ReadFrom.
+func (rw *ReaderWriter) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
+ n, err := rw.Pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ write: func(buf *buffer) (int64, error) {
+ n, err := buf.WriteFromReader(r, count)
+ count -= n
+ return n, err
+ },
+ })
if n > 0 {
rw.Pipe.Notify(waiter.EventIn)
}
diff --git a/pkg/sentry/kernel/posixtimer.go b/pkg/sentry/kernel/posixtimer.go
index c5d095af7..2e861a5a8 100644
--- a/pkg/sentry/kernel/posixtimer.go
+++ b/pkg/sentry/kernel/posixtimer.go
@@ -117,9 +117,9 @@ func (it *IntervalTimer) signalRejectedLocked() {
}
// Notify implements ktime.TimerListener.Notify.
-func (it *IntervalTimer) Notify(exp uint64) {
+func (it *IntervalTimer) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
if it.target == nil {
- return
+ return ktime.Setting{}, false
}
it.target.tg.pidns.owner.mu.RLock()
@@ -129,7 +129,7 @@ func (it *IntervalTimer) Notify(exp uint64) {
if it.sigpending {
it.overrunCur += exp
- return
+ return ktime.Setting{}, false
}
// sigpending must be set before sendSignalTimerLocked() so that it can be
@@ -148,6 +148,8 @@ func (it *IntervalTimer) Notify(exp uint64) {
if err := it.target.sendSignalTimerLocked(si, it.group, it); err != nil {
it.signalRejectedLocked()
}
+
+ return ktime.Setting{}, false
}
// Destroy implements ktime.TimerListener.Destroy. Users of Timer should call
diff --git a/pkg/sentry/kernel/sched/BUILD b/pkg/sentry/kernel/sched/BUILD
index 1725b8562..98ea7a0d8 100644
--- a/pkg/sentry/kernel/sched/BUILD
+++ b/pkg/sentry/kernel/sched/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD
index 36edf10f3..80e5e5da3 100644
--- a/pkg/sentry/kernel/semaphore/BUILD
+++ b/pkg/sentry/kernel/semaphore/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "waiter_list",
diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go
index e5f297478..047b5214d 100644
--- a/pkg/sentry/kernel/sessions.go
+++ b/pkg/sentry/kernel/sessions.go
@@ -328,8 +328,14 @@ func (tg *ThreadGroup) createSession() error {
childTG.processGroup.incRefWithParent(pg)
childTG.processGroup.decRefWithParent(oldParentPG)
})
- tg.processGroup.decRefWithParent(oldParentPG)
+ // If tg.processGroup is an orphan, decRefWithParent will lock
+ // the signal mutex of each thread group in tg.processGroup.
+ // However, tg's signal mutex may already be locked at this
+ // point. We change tg's process group before calling
+ // decRefWithParent to avoid locking tg's signal mutex twice.
+ oldPG := tg.processGroup
tg.processGroup = pg
+ oldPG.decRefWithParent(oldParentPG)
} else {
// The current process group may be nil only in the case of an
// unparented thread group (i.e. the init process). This would
diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD
new file mode 100644
index 000000000..50b69d154
--- /dev/null
+++ b/pkg/sentry/kernel/signalfd/BUILD
@@ -0,0 +1,22 @@
+package(licenses = ["notice"])
+
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+go_library(
+ name = "signalfd",
+ srcs = ["signalfd.go"],
+ importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd",
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/sentry/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/anon",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/usermem",
+ "//pkg/syserror",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go
new file mode 100644
index 000000000..4b08d7d72
--- /dev/null
+++ b/pkg/sentry/kernel/signalfd/signalfd.go
@@ -0,0 +1,140 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package signalfd provides an implementation of signal file descriptors.
+package signalfd
+
+import (
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/anon"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SignalOperations represents a file with signalfd semantics.
+//
+// +stateify savable
+type SignalOperations struct {
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ // target is the original task target.
+ //
+ // The semantics here are a bit broken. Linux will always use current
+ // for all reads, regardless of where the signalfd originated. We can't
+ // do exactly that because we need to plumb the context through
+ // EventRegister in order to support proper blocking behavior. This
+ // will undoubtedly become very complicated quickly.
+ target *kernel.Task
+
+ // mu protects below.
+ mu sync.Mutex `state:"nosave"`
+
+ // mask is the signal mask. Protected by mu.
+ mask linux.SignalSet
+}
+
+// New creates a new signalfd object with the supplied mask.
+func New(ctx context.Context, mask linux.SignalSet) (*fs.File, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ // No task context? Not valid.
+ return nil, syserror.EINVAL
+ }
+ // name matches fs/signalfd.c:signalfd4.
+ dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[signalfd]")
+ return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &SignalOperations{
+ target: t,
+ mask: mask,
+ }), nil
+}
+
+// Release implements fs.FileOperations.Release.
+func (s *SignalOperations) Release() {}
+
+// Mask returns the signal mask.
+func (s *SignalOperations) Mask() linux.SignalSet {
+ s.mu.Lock()
+ mask := s.mask
+ s.mu.Unlock()
+ return mask
+}
+
+// SetMask sets the signal mask.
+func (s *SignalOperations) SetMask(mask linux.SignalSet) {
+ s.mu.Lock()
+ s.mask = mask
+ s.mu.Unlock()
+}
+
+// Read implements fs.FileOperations.Read.
+func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ // Attempt to dequeue relevant signals.
+ info, err := s.target.Sigtimedwait(s.Mask(), 0)
+ if err != nil {
+ // There must be no signal available.
+ return 0, syserror.ErrWouldBlock
+ }
+
+ // Copy out the signal info using the specified format.
+ var buf [128]byte
+ binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{
+ Signo: uint32(info.Signo),
+ Errno: info.Errno,
+ Code: info.Code,
+ PID: uint32(info.Pid()),
+ UID: uint32(info.Uid()),
+ Status: info.Status(),
+ Overrun: uint32(info.Overrun()),
+ Addr: info.Addr(),
+ })
+ n, err := dst.CopyOut(ctx, buf[:])
+ return int64(n), err
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SignalOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ if mask&waiter.EventIn != 0 && s.target.PendingSignals()&s.Mask() != 0 {
+ return waiter.EventIn // Pending signals.
+ }
+ return 0
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SignalOperations) EventRegister(entry *waiter.Entry, _ waiter.EventMask) {
+ // Register for the signal set; ignore the passed events.
+ s.target.SignalRegister(entry, waiter.EventMask(s.Mask()))
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SignalOperations) EventUnregister(entry *waiter.Entry) {
+ // Unregister the original entry.
+ s.target.SignalUnregister(entry)
+}
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index e91f82bb3..c82ef5486 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -35,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
"gvisor.dev/gvisor/third_party/gvsync"
)
@@ -133,6 +134,13 @@ type Task struct {
// signalStack is exclusive to the task goroutine.
signalStack arch.SignalStack
+ // signalQueue is a set of registered waiters for signal-related events.
+ //
+ // signalQueue is protected by the signalMutex. Note that the task does
+ // not implement all queue methods, specifically the readiness checks.
+ // The task only broadcasts a notification on signal delivery.
+ signalQueue waiter.Queue `state:"zerovalue"`
+
// If groupStopPending is true, the task should participate in a group
// stop in the interrupt path.
//
diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go
index 2a2e6f662..dd69939f9 100644
--- a/pkg/sentry/kernel/task_block.go
+++ b/pkg/sentry/kernel/task_block.go
@@ -15,6 +15,7 @@
package kernel
import (
+ "runtime"
"time"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -121,6 +122,17 @@ func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error {
// Deactive our address space, we don't need it.
interrupt := t.SleepStart()
+ // If the request is not completed, but the timer has already expired,
+ // then ensure that we run through a scheduler cycle. This is because
+ // we may see applications relying on timer slack to yield the thread.
+ // For example, they may attempt to sleep for some number of nanoseconds,
+ // and expect that this will actually yield the CPU and sleep for at
+ // least microseconds, e.g.:
+ // https://github.com/LMAX-Exchange/disruptor/commit/6ca210f2bcd23f703c479804d583718e16f43c07
+ if len(timerChan) > 0 {
+ runtime.Gosched()
+ }
+
select {
case <-C:
t.SleepFinish(true)
diff --git a/pkg/sentry/kernel/task_identity.go b/pkg/sentry/kernel/task_identity.go
index 78ff14b20..ce3e6ef28 100644
--- a/pkg/sentry/kernel/task_identity.go
+++ b/pkg/sentry/kernel/task_identity.go
@@ -465,8 +465,8 @@ func (t *Task) SetKeepCaps(k bool) {
// disables the features we don't support anyway, is always set. This
// drastically simplifies this function.
//
-// - We don't implement AT_SECURE, because no_new_privs always being set means
-// that the conditions that require AT_SECURE never arise. (Compare Linux's
+// - We don't set AT_SECURE = 1, because no_new_privs always being set means
+// that the conditions that require AT_SECURE = 1 never arise. (Compare Linux's
// security/commoncap.c:cap_bprm_set_creds() and cap_bprm_secureexec().)
//
// - We don't check for CAP_SYS_ADMIN in prctl(PR_SET_SECCOMP), since
diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go
index e76c069b0..8b148db35 100644
--- a/pkg/sentry/kernel/task_sched.go
+++ b/pkg/sentry/kernel/task_sched.go
@@ -126,12 +126,22 @@ func (t *Task) accountTaskGoroutineEnter(state TaskGoroutineState) {
t.gosched.Timestamp = now
t.gosched.State = state
t.goschedSeq.EndWrite()
+
+ if state != TaskGoroutineRunningApp {
+ // Task is blocking/stopping.
+ t.k.decRunningTasks()
+ }
}
// Preconditions: The caller must be running on the task goroutine, and leaving
// a state indicated by a previous call to
// t.accountTaskGoroutineEnter(state).
func (t *Task) accountTaskGoroutineLeave(state TaskGoroutineState) {
+ if state != TaskGoroutineRunningApp {
+ // Task is unblocking/continuing.
+ t.k.incRunningTasks()
+ }
+
now := t.k.CPUClockNow()
if t.gosched.State != state {
panic(fmt.Sprintf("Task goroutine switching from state %v (expected %v) to %v", t.gosched.State, state, TaskGoroutineRunningSys))
@@ -330,7 +340,7 @@ func newKernelCPUClockTicker(k *Kernel) *kernelCPUClockTicker {
}
// Notify implements ktime.TimerListener.Notify.
-func (ticker *kernelCPUClockTicker) Notify(exp uint64) {
+func (ticker *kernelCPUClockTicker) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
// Only increment cpuClock by 1 regardless of the number of expirations.
// This approximately compensates for cases where thread throttling or bad
// Go runtime scheduling prevents the kernelCPUClockTicker goroutine, and
@@ -426,6 +436,27 @@ func (ticker *kernelCPUClockTicker) Notify(exp uint64) {
tgs[i] = nil
}
ticker.tgs = tgs[:0]
+
+ // If nothing is running, we can disable the timer.
+ tasks := atomic.LoadInt64(&ticker.k.runningTasks)
+ if tasks == 0 {
+ ticker.k.runningTasksMu.Lock()
+ defer ticker.k.runningTasksMu.Unlock()
+ tasks := atomic.LoadInt64(&ticker.k.runningTasks)
+ if tasks != 0 {
+ // Raced with a 0 -> 1 transition.
+ return setting, false
+ }
+
+ // Stop the timer. We must cache the current setting so the
+ // kernel can access it without violating the lock order.
+ ticker.k.cpuClockTickerSetting = setting
+ ticker.k.cpuClockTickerDisabled = true
+ setting.Enabled = false
+ return setting, true
+ }
+
+ return setting, false
}
// Destroy implements ktime.TimerListener.Destroy.
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index 266959a07..39cd1340d 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -28,6 +28,7 @@ import (
ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// SignalAction is an internal signal action.
@@ -497,6 +498,9 @@ func (tg *ThreadGroup) applySignalSideEffectsLocked(sig linux.Signal) {
//
// Preconditions: The signal mutex must be locked.
func (t *Task) canReceiveSignalLocked(sig linux.Signal) bool {
+ // Notify that the signal is queued.
+ t.signalQueue.Notify(waiter.EventMask(linux.MakeSignalSet(sig)))
+
// - Do not choose tasks that are blocking the signal.
if linux.SignalSetOf(sig)&t.signalMask != 0 {
return false
@@ -1108,3 +1112,17 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState {
t.tg.signalHandlers.mu.Unlock()
return t.deliverSignal(info, act)
}
+
+// SignalRegister registers a waiter for pending signals.
+func (t *Task) SignalRegister(e *waiter.Entry, mask waiter.EventMask) {
+ t.tg.signalHandlers.mu.Lock()
+ t.signalQueue.EventRegister(e, mask)
+ t.tg.signalHandlers.mu.Unlock()
+}
+
+// SignalUnregister unregisters a waiter for pending signals.
+func (t *Task) SignalUnregister(e *waiter.Entry) {
+ t.tg.signalHandlers.mu.Lock()
+ t.signalQueue.EventUnregister(e)
+ t.tg.signalHandlers.mu.Unlock()
+}
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 0eef24bfb..72568d296 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -511,8 +511,9 @@ type itimerRealListener struct {
}
// Notify implements ktime.TimerListener.Notify.
-func (l *itimerRealListener) Notify(exp uint64) {
+func (l *itimerRealListener) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
l.tg.SendSignal(SignalInfoPriv(linux.SIGALRM))
+ return ktime.Setting{}, false
}
// Destroy implements ktime.TimerListener.Destroy.
diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go
index aa6c75d25..107394183 100644
--- a/pkg/sentry/kernel/time/time.go
+++ b/pkg/sentry/kernel/time/time.go
@@ -280,13 +280,16 @@ func (ClockEventsQueue) Readiness(mask waiter.EventMask) waiter.EventMask {
// A TimerListener receives expirations from a Timer.
type TimerListener interface {
// Notify is called when its associated Timer expires. exp is the number of
- // expirations.
+ // expirations. setting is the next timer Setting.
//
// Notify is called with the associated Timer's mutex locked, so Notify
// must not take any locks that precede Timer.mu in lock order.
//
+ // If Notify returns true, the timer will use the returned setting
+ // rather than the passed one.
+ //
// Preconditions: exp > 0.
- Notify(exp uint64)
+ Notify(exp uint64, setting Setting) (newSetting Setting, update bool)
// Destroy is called when the timer is destroyed.
Destroy()
@@ -533,7 +536,9 @@ func (t *Timer) Tick() {
s, exp := t.setting.At(now)
t.setting = s
if exp > 0 {
- t.listener.Notify(exp)
+ if newS, ok := t.listener.Notify(exp, t.setting); ok {
+ t.setting = newS
+ }
}
t.resetKickerLocked(now)
}
@@ -588,7 +593,9 @@ func (t *Timer) Get() (Time, Setting) {
s, exp := t.setting.At(now)
t.setting = s
if exp > 0 {
- t.listener.Notify(exp)
+ if newS, ok := t.listener.Notify(exp, t.setting); ok {
+ t.setting = newS
+ }
}
t.resetKickerLocked(now)
return now, s
@@ -620,7 +627,9 @@ func (t *Timer) SwapAnd(s Setting, f func()) (Time, Setting) {
}
oldS, oldExp := t.setting.At(now)
if oldExp > 0 {
- t.listener.Notify(oldExp)
+ t.listener.Notify(oldExp, oldS)
+ // N.B. The returned Setting doesn't matter because we're about
+ // to overwrite.
}
if f != nil {
f()
@@ -628,7 +637,9 @@ func (t *Timer) SwapAnd(s Setting, f func()) (Time, Setting) {
newS, newExp := s.At(now)
t.setting = newS
if newExp > 0 {
- t.listener.Notify(newExp)
+ if newS, ok := t.listener.Notify(newExp, t.setting); ok {
+ t.setting = newS
+ }
}
t.resetKickerLocked(now)
return now, oldS
@@ -683,11 +694,13 @@ func NewChannelNotifier() (TimerListener, <-chan struct{}) {
}
// Notify implements ktime.TimerListener.Notify.
-func (c *ChannelNotifier) Notify(uint64) {
+func (c *ChannelNotifier) Notify(uint64, Setting) (Setting, bool) {
select {
case c.tchan <- struct{}{}:
default:
}
+
+ return Setting{}, false
}
// Destroy implements ktime.TimerListener.Destroy and will close the channel.
diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD
index 40025d62d..59649c770 100644
--- a/pkg/sentry/limits/BUILD
+++ b/pkg/sentry/limits/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "limits",
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index bc5b841fb..2d9251e92 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -323,18 +323,22 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f *fs.File, phdr *elf.
return syserror.ENOEXEC
}
+ // N.B. Linux uses vm_brk_flags to map these pages, which only
+ // honors the X bit, always mapping at least RW. These pages
+ // are not included in the final brk region.
+ prot := usermem.ReadWrite
+ if phdr.Flags&elf.PF_X == elf.PF_X {
+ prot.Execute = true
+ }
+
if _, err := m.MMap(ctx, memmap.MMapOpts{
Length: uint64(anonSize),
Addr: anonAddr,
// Fixed without Unmap will fail the mmap if something is
// already at addr.
- Fixed: true,
- Private: true,
- // N.B. Linux uses vm_brk to map these pages, ignoring
- // the segment protections, instead always mapping RW.
- // These pages are not included in the final brk
- // region.
- Perms: usermem.ReadWrite,
+ Fixed: true,
+ Private: true,
+ Perms: prot,
MaxPerms: usermem.AnyAccess,
}); err != nil {
ctx.Infof("Error mapping PT_LOAD segment %v anonymous memory: %v", phdr, err)
@@ -464,7 +468,7 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info el
// base address big enough to fit all segments, so we first create a
// mapping for the total size just to find a region that is big enough.
//
- // It is safe to unmap it immediately with racing with another mapping
+ // It is safe to unmap it immediately without racing with another mapping
// because we are the only one in control of the MemoryManager.
//
// Note that the vaddr of the first PT_LOAD segment is ignored when
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
index f6f1ae762..089d1635b 100644
--- a/pkg/sentry/loader/loader.go
+++ b/pkg/sentry/loader/loader.go
@@ -308,6 +308,9 @@ func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, r
arch.AuxEntry{linux.AT_EUID, usermem.Addr(c.EffectiveKUID.In(c.UserNamespace).OrOverflow())},
arch.AuxEntry{linux.AT_GID, usermem.Addr(c.RealKGID.In(c.UserNamespace).OrOverflow())},
arch.AuxEntry{linux.AT_EGID, usermem.Addr(c.EffectiveKGID.In(c.UserNamespace).OrOverflow())},
+ // The conditions that require AT_SECURE = 1 never arise. See
+ // kernel.Task.updateCredsForExecLocked.
+ arch.AuxEntry{linux.AT_SECURE, 0},
arch.AuxEntry{linux.AT_CLKTCK, linux.CLOCKS_PER_SEC},
arch.AuxEntry{linux.AT_EXECFN, execfn},
arch.AuxEntry{linux.AT_RANDOM, random},
diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD
index 29c14ec56..9687e7e76 100644
--- a/pkg/sentry/memmap/BUILD
+++ b/pkg/sentry/memmap/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "mappable_range",
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index 072745a08..b35c8c673 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "file_refcount_set",
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
index 858f895f2..3fd904c67 100644
--- a/pkg/sentry/pgalloc/BUILD
+++ b/pkg/sentry/pgalloc/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "evictable_range",
diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD
index eeb634644..b6d008dbe 100644
--- a/pkg/sentry/platform/interrupt/BUILD
+++ b/pkg/sentry/platform/interrupt/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index ad8b95744..31fa48ec5 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -54,6 +55,7 @@ go_test(
],
embed = [":kvm"],
tags = [
+ "manual",
"nogotsan",
"requires-kvm",
],
diff --git a/pkg/sentry/platform/kvm/testutil/BUILD b/pkg/sentry/platform/kvm/testutil/BUILD
index 77a449a8b..b0e45f159 100644
--- a/pkg/sentry/platform/kvm/testutil/BUILD
+++ b/pkg/sentry/platform/kvm/testutil/BUILD
@@ -9,6 +9,8 @@ go_library(
"testutil.go",
"testutil_amd64.go",
"testutil_amd64.s",
+ "testutil_arm64.go",
+ "testutil_arm64.s",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil",
visibility = ["//pkg/sentry/platform/kvm:__pkg__"],
diff --git a/pkg/sentry/platform/kvm/testutil/testutil.go b/pkg/sentry/platform/kvm/testutil/testutil.go
index 6cf2359a3..5c1efa0fd 100644
--- a/pkg/sentry/platform/kvm/testutil/testutil.go
+++ b/pkg/sentry/platform/kvm/testutil/testutil.go
@@ -41,9 +41,6 @@ func TwiddleRegsFault()
// TwiddleRegsSyscall twiddles registers then executes a syscall.
func TwiddleRegsSyscall()
-// TwiddleSegments reads segments into known registers.
-func TwiddleSegments()
-
// FloatingPointWorks is a floating point test.
//
// It returns true or false.
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go
index 203d71528..4c108abbf 100644
--- a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go
+++ b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go
@@ -21,6 +21,9 @@ import (
"syscall"
)
+// TwiddleSegments reads segments into known registers.
+func TwiddleSegments()
+
// SetTestTarget sets the rip appropriately.
func SetTestTarget(regs *syscall.PtraceRegs, fn func()) {
regs.Rip = uint64(reflect.ValueOf(fn).Pointer())
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
new file mode 100644
index 000000000..40b2e4acc
--- /dev/null
+++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
@@ -0,0 +1,59 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package testutil
+
+import (
+ "fmt"
+ "reflect"
+ "syscall"
+)
+
+// SetTestTarget sets the PC appropriately.
+func SetTestTarget(regs *syscall.PtraceRegs, fn func()) {
+ regs.Pc = uint64(reflect.ValueOf(fn).Pointer())
+}
+
+// SetTouchTarget sets R8 appropriately.
+func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) {
+ if target != nil {
+ regs.Regs[8] = uint64(reflect.ValueOf(target).Pointer())
+ } else {
+ regs.Regs[8] = 0
+ }
+}
+
+// RewindSyscall rewinds a syscall PC.
+func RewindSyscall(regs *syscall.PtraceRegs) {
+ regs.Pc -= 4
+}
+
+// SetTestRegs initializes registers to known values.
+func SetTestRegs(regs *syscall.PtraceRegs) {
+ for i := 0; i <= 30; i++ {
+ regs.Regs[i] = uint64(i) + 1
+ }
+}
+
+// CheckTestRegs checks that registers were twiddled per TwiddleRegs.
+func CheckTestRegs(regs *syscall.PtraceRegs, full bool) (err error) {
+ for i := 0; i <= 30; i++ {
+ if need := ^uint64(i + 1); regs.Regs[i] != need {
+ err = addRegisterMismatch(err, fmt.Sprintf("R%d", i), regs.Regs[i], need)
+ }
+ }
+ return
+}
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
new file mode 100644
index 000000000..2cd28b2d2
--- /dev/null
+++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
@@ -0,0 +1,91 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+// testutil_arm64.s provides ARM64 test functions.
+
+#include "funcdata.h"
+#include "textflag.h"
+
+#define SYS_GETPID 172
+
+// This function simulates the getpid syscall.
+TEXT ·Getpid(SB),NOSPLIT,$0
+ NO_LOCAL_POINTERS
+ MOVD $SYS_GETPID, R8
+ SVC
+ RET
+
+TEXT ·Touch(SB),NOSPLIT,$0
+start:
+ MOVD 0(R8), R1
+ MOVD $SYS_GETPID, R8 // getpid
+ SVC
+ B start
+
+TEXT ·HaltLoop(SB),NOSPLIT,$0
+start:
+ HLT
+ B start
+
+// This function simulates a loop of syscalls.
+TEXT ·SyscallLoop(SB),NOSPLIT,$0
+start:
+ SVC
+ B start
+
+TEXT ·SpinLoop(SB),NOSPLIT,$0
+start:
+ B start
+
+// MVN: bitwise logical NOT
+// This case simulates an application that modified R0-R30.
+#define TWIDDLE_REGS() \
+ MVN R0, R0; \
+ MVN R1, R1; \
+ MVN R2, R2; \
+ MVN R3, R3; \
+ MVN R4, R4; \
+ MVN R5, R5; \
+ MVN R6, R6; \
+ MVN R7, R7; \
+ MVN R8, R8; \
+ MVN R9, R9; \
+ MVN R10, R10; \
+ MVN R11, R11; \
+ MVN R12, R12; \
+ MVN R13, R13; \
+ MVN R14, R14; \
+ MVN R15, R15; \
+ MVN R16, R16; \
+ MVN R17, R17; \
+ MVN R18_PLATFORM, R18_PLATFORM; \
+ MVN R19, R19; \
+ MVN R20, R20; \
+ MVN R21, R21; \
+ MVN R22, R22; \
+ MVN R23, R23; \
+ MVN R24, R24; \
+ MVN R25, R25; \
+ MVN R26, R26; \
+ MVN R27, R27; \
+ MVN g, g; \
+ MVN R29, R29; \
+ MVN R30, R30;
+
+TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0
+ TWIDDLE_REGS()
+ SVC
+ RET // never reached
diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
index 47957bb3b..72c7ec564 100644
--- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go
+++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
@@ -154,3 +154,19 @@ func (t *thread) clone() (*thread, error) {
cpu: ^uint32(0),
}, nil
}
+
+// getEventMessage retrieves a message about the ptrace event that just happened.
+func (t *thread) getEventMessage() (uintptr, error) {
+ var msg uintptr
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_GETEVENTMSG,
+ uintptr(t.tid),
+ 0,
+ uintptr(unsafe.Pointer(&msg)),
+ 0, 0)
+ if errno != 0 {
+ return msg, errno
+ }
+ return msg, nil
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 6bf7cd097..9f0ecfbe4 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -267,7 +267,7 @@ func (s *subprocess) newThread() *thread {
// attach attaches to the thread.
func (t *thread) attach() {
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_ATTACH, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_ATTACH, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("unable to attach: %v", errno))
}
@@ -355,7 +355,8 @@ func (t *thread) wait(outcome waitOutcome) syscall.Signal {
}
if stopSig == syscall.SIGTRAP {
if status.TrapCause() == syscall.PTRACE_EVENT_EXIT {
- t.dumpAndPanic("wait failed: the process exited")
+ msg, err := t.getEventMessage()
+ t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err))
}
// Re-encode the trap cause the way it's expected.
return stopSig | syscall.Signal(status.TrapCause()<<8)
@@ -416,7 +417,7 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
for {
// Execute the syscall instruction.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
@@ -426,12 +427,15 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
break
} else {
// Some other signal caused a thread stop; ignore.
+ if sig != syscall.SIGSTOP && sig != syscall.SIGCHLD {
+ log.Warningf("The thread %d:%d has been interrupted by %d", t.tgid, t.tid, sig)
+ }
continue
}
}
// Complete the actual system call.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
@@ -522,17 +526,17 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
for {
// Start running until the next system call.
if isSingleStepping(regs) {
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall6(
syscall.SYS_PTRACE,
syscall.PTRACE_SYSEMU_SINGLESTEP,
- uintptr(t.tid), 0); errno != 0 {
+ uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
}
} else {
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall6(
syscall.SYS_PTRACE,
syscall.PTRACE_SYSEMU,
- uintptr(t.tid), 0); errno != 0 {
+ uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
}
}
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
index f09b0b3d0..c075b5f91 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -53,7 +53,7 @@ func probeSeccomp() bool {
for {
// Attempt an emulation.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
@@ -266,7 +266,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
// Enable cpuid-faulting; this may fail on older kernels or hardware,
// so we just disregard the result. Host CPUID will be enabled.
- syscall.RawSyscall(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0)
+ syscall.RawSyscall6(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0, 0, 0, 0)
// Call the stub; should not return.
stubCall(stubStart, ppid)
diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD
index 3b95af617..ea090b686 100644
--- a/pkg/sentry/platform/ring0/pagetables/BUILD
+++ b/pkg/sentry/platform/ring0/pagetables/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/platform/safecopy/BUILD b/pkg/sentry/platform/safecopy/BUILD
index 924d8a6d6..6769cd0a5 100644
--- a/pkg/sentry/platform/safecopy/BUILD
+++ b/pkg/sentry/platform/safecopy/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/safemem/BUILD b/pkg/sentry/safemem/BUILD
index fd6dc8e6e..884020f7b 100644
--- a/pkg/sentry/safemem/BUILD
+++ b/pkg/sentry/safemem/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go
index eace3766d..c303435d5 100644
--- a/pkg/sentry/sighandling/sighandling_unsafe.go
+++ b/pkg/sentry/sighandling/sighandling_unsafe.go
@@ -23,7 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
)
-// TODO(b/34161764): Move to pkg/abi/linux along with definitions in
+// FIXME(gvisor.dev/issue/214): Move to pkg/abi/linux along with definitions in
// pkg/sentry/arch.
type sigaction struct {
handler uintptr
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 45ebb2a0e..7da68384e 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -21,6 +21,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/time",
+ "//pkg/sentry/safemem",
"//pkg/sentry/socket",
"//pkg/sentry/socket/netlink/port",
"//pkg/sentry/socket/unix",
diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD
index 9e2e12799..445080aa4 100644
--- a/pkg/sentry/socket/netlink/port/BUILD
+++ b/pkg/sentry/socket/netlink/port/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "port",
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index d0aab293d..b2732ca29 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -28,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink/port"
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
@@ -416,6 +417,24 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have
Peek: flags&linux.MSG_PEEK != 0,
}
+ // If MSG_TRUNC is set with a zero byte destination then we still need
+ // to read the message and discard it, or in the case where MSG_PEEK is
+ // set, leave it be. In both cases the full message length must be
+ // returned. However, the memory manager for the destination will not read
+ // the endpoint if the destination is zero length.
+ //
+ // In order for the endpoint to be read when the destination size is zero,
+ // we must cause a read of the endpoint by using a separate fake zero
+ // length block sequence and calling the EndpointReader directly.
+ if trunc && dst.Addrs.NumBytes() == 0 {
+ // Perform a read to a zero byte block sequence. We can ignore the
+ // original destination since it was zero bytes. The length returned by
+ // ReadToBlocks is ignored and we return the full message length to comply
+ // with MSG_TRUNC.
+ _, err := r.ReadToBlocks(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(make([]byte, 0))))
+ return int(r.MsgSize), linux.MSG_TRUNC, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
+ }
+
if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
var mflags int
if n < int64(r.MsgSize) {
@@ -499,6 +518,9 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error
PortID: uint32(ms.PortID),
})
+ // Add the dump_done_errno payload.
+ m.Put(int64(0))
+
_, notify, err := s.connection.Send([][]byte{m.Finalize()}, transport.ControlMessages{}, tcpip.FullAddress{})
if err != nil && err != syserr.ErrWouldBlock {
return err
diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/netstack/BUILD
index e927821e1..60523f79a 100644
--- a/pkg/sentry/socket/epsocket/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -3,15 +3,15 @@ package(licenses = ["notice"])
load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
- name = "epsocket",
+ name = "netstack",
srcs = [
"device.go",
- "epsocket.go",
+ "netstack.go",
"provider.go",
"save_restore.go",
"stack.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/epsocket",
+ importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netstack",
visibility = [
"//pkg/sentry:internal",
],
diff --git a/pkg/sentry/socket/epsocket/device.go b/pkg/sentry/socket/netstack/device.go
index 85484d5b1..fbeb89fb8 100644
--- a/pkg/sentry/socket/epsocket/device.go
+++ b/pkg/sentry/socket/netstack/device.go
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package epsocket
+package netstack
import "gvisor.dev/gvisor/pkg/sentry/device"
-// epsocketDevice is the endpoint socket virtual device.
-var epsocketDevice = device.NewAnonDevice()
+// netstackDevice is the endpoint socket virtual device.
+var netstackDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/netstack/netstack.go
index 635042263..0ae573b45 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package epsocket provides an implementation of the socket.Socket interface
+// Package netstack provides an implementation of the socket.Socket interface
// that is backed by a tcpip.Endpoint.
//
// It does not depend on any particular endpoint implementation, and thus can
@@ -22,17 +22,20 @@
// Lock ordering: netstack => mm: ioSequencePayload copies user memory inside
// tcpip.Endpoint.Write(). Netstack is allowed to (and does) hold locks during
// this operation.
-package epsocket
+package netstack
import (
"bytes"
+ "io"
"math"
+ "reflect"
"sync"
"syscall"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -52,6 +55,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -133,11 +137,13 @@ var Metrics = tcpip.Stats{
},
},
IP: tcpip.IPStats{
- PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."),
- InvalidAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."),
- PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."),
- PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."),
- OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."),
+ PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."),
+ InvalidAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."),
+ PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."),
+ PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."),
+ OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."),
+ MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."),
+ MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."),
},
TCP: tcpip.TCPStats{
ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."),
@@ -151,6 +157,7 @@ var Metrics = tcpip.Stats{
ValidSegmentsReceived: mustCreateMetric("/netstack/tcp/valid_segments_received", "Number of TCP segments received that the transport layer successfully parsed."),
InvalidSegmentsReceived: mustCreateMetric("/netstack/tcp/invalid_segments_received", "Number of TCP segments received that the transport layer could not parse."),
SegmentsSent: mustCreateMetric("/netstack/tcp/segments_sent", "Number of TCP segments sent."),
+ SegmentSendErrors: mustCreateMetric("/netstack/tcp/segment_send_errors", "Number of TCP segments failed to be sent."),
ResetsSent: mustCreateMetric("/netstack/tcp/resets_sent", "Number of TCP resets sent."),
ResetsReceived: mustCreateMetric("/netstack/tcp/resets_received", "Number of TCP resets received."),
Retransmits: mustCreateMetric("/netstack/tcp/retransmits", "Number of TCP segments retransmitted."),
@@ -166,13 +173,18 @@ var Metrics = tcpip.Stats{
UnknownPortErrors: mustCreateMetric("/netstack/udp/unknown_port_errors", "Number of incoming UDP datagrams dropped because they did not have a known destination port."),
ReceiveBufferErrors: mustCreateMetric("/netstack/udp/receive_buffer_errors", "Number of incoming UDP datagrams dropped due to the receiving buffer being in an invalid state."),
MalformedPacketsReceived: mustCreateMetric("/netstack/udp/malformed_packets_received", "Number of incoming UDP datagrams dropped due to the UDP header being in a malformed state."),
- PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent via sendUDP."),
+ PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."),
+ PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."),
},
}
+// DefaultTTL is Linux's default TTL. All network protocols in all stacks used
+// with this package must have this value set as their default TTL.
+const DefaultTTL = 64
+
const sizeOfInt32 int = 4
-var errStackType = syserr.New("expected but did not receive an epsocket.Stack", linux.EINVAL)
+var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL)
// ntohs converts a 16-bit number from network byte order to host byte order. It
// assumes that the host is little endian.
@@ -205,6 +217,10 @@ type commonEndpoint interface {
// transport.Endpoint.SetSockOpt.
SetSockOpt(interface{}) *tcpip.Error
+ // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
+ // transport.Endpoint.SetSockOptInt.
+ SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+
// GetSockOpt implements tcpip.Endpoint.GetSockOpt and
// transport.Endpoint.GetSockOpt.
GetSockOpt(interface{}) *tcpip.Error
@@ -224,7 +240,6 @@ type SocketOperations struct {
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileNoFsync `state:"nosave"`
fsutil.FileNoMMap `state:"nosave"`
- fsutil.FileNoSplice `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
socket.SendReceiveTimeout
*waiter.Queue
@@ -255,8 +270,8 @@ type SocketOperations struct {
// valid when timestampValid is true. It is protected by readMu.
timestampNS int64
- // sockOptInq corresponds to TCP_INQ. It is implemented on the epsocket
- // level, because it takes into account data from readView.
+ // sockOptInq corresponds to TCP_INQ. It is implemented at this level
+ // because it takes into account data from readView.
sockOptInq bool
}
@@ -268,7 +283,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue
}
}
- dirent := socket.NewDirent(t, epsocketDevice)
+ dirent := socket.NewDirent(t, netstackDevice)
defer dirent.DecRef()
return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, &SocketOperations{
Queue: queue,
@@ -409,17 +424,60 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
return int64(n), nil
}
-// ioSequencePayload implements tcpip.Payload. It copies user memory bytes on demand
-// based on the requested size.
+// WriteTo implements fs.FileOperations.WriteTo.
+func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
+ s.readMu.Lock()
+
+ // Copy as much data as possible.
+ done := int64(0)
+ for count > 0 {
+ // This may return a blocking error.
+ if err := s.fetchReadView(); err != nil {
+ s.readMu.Unlock()
+ return done, err.ToError()
+ }
+
+ // Write to the underlying file.
+ n, err := dst.Write(s.readView)
+ done += int64(n)
+ count -= int64(n)
+ if dup {
+ // That's all we support for dup. This is generally
+ // supported by any Linux system calls, but the
+ // expectation is that now a caller will call read to
+ // actually remove these bytes from the socket.
+ break
+ }
+
+ // Drop that part of the view.
+ s.readView.TrimFront(n)
+ if err != nil {
+ s.readMu.Unlock()
+ return done, err
+ }
+ }
+
+ s.readMu.Unlock()
+ return done, nil
+}
+
+// ioSequencePayload implements tcpip.Payload.
+//
+// It copies user memory bytes on demand based on the requested size.
type ioSequencePayload struct {
ctx context.Context
src usermem.IOSequence
}
-// Get implements tcpip.Payload.
-func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) {
- if size > i.Size() {
- size = i.Size()
+// FullPayload implements tcpip.Payloader.FullPayload
+func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) {
+ return i.Payload(int(i.src.NumBytes()))
+}
+
+// Payload implements tcpip.Payloader.Payload.
+func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) {
+ if max := int(i.src.NumBytes()); size > max {
+ size = max
}
v := buffer.NewView(size)
if _, err := i.src.CopyIn(i.ctx, v); err != nil {
@@ -428,11 +486,6 @@ func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) {
return v, nil
}
-// Size implements tcpip.Payload.
-func (i *ioSequencePayload) Size() int {
- return int(i.src.NumBytes())
-}
-
// DropFirst drops the first n bytes from underlying src.
func (i *ioSequencePayload) DropFirst(n int) {
i.src = i.src.DropFirst(int(n))
@@ -466,6 +519,78 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
return int64(n), nil
}
+// readerPayload implements tcpip.Payloader.
+//
+// It allocates a view and reads from a reader on-demand, based on available
+// capacity in the endpoint.
+type readerPayload struct {
+ ctx context.Context
+ r io.Reader
+ count int64
+ err error
+}
+
+// FullPayload implements tcpip.Payloader.FullPayload.
+func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) {
+ return r.Payload(int(r.count))
+}
+
+// Payload implements tcpip.Payloader.Payload.
+func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) {
+ if size > int(r.count) {
+ size = int(r.count)
+ }
+ v := buffer.NewView(size)
+ n, err := r.r.Read(v)
+ if n > 0 {
+ // We ignore the error here. It may re-occur on subsequent
+ // reads, but for now we can enqueue some amount of data.
+ r.count -= int64(n)
+ return v[:n], nil
+ }
+ if err == syserror.ErrWouldBlock {
+ return nil, tcpip.ErrWouldBlock
+ } else if err != nil {
+ r.err = err // Save for propagation.
+ return nil, tcpip.ErrBadAddress
+ }
+
+ // There is no data and no error. Return an error, which will propagate
+ // r.err, which will be nil. This is the desired result: (0, nil).
+ return nil, tcpip.ErrBadAddress
+}
+
+// ReadFrom implements fs.FileOperations.ReadFrom.
+func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
+ f := &readerPayload{ctx: ctx, r: r, count: count}
+ n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{
+ // Reads may be destructive but should be very fast,
+ // so we can't release the lock while copying data.
+ Atomic: true,
+ })
+ if err == tcpip.ErrWouldBlock {
+ return 0, syserror.ErrWouldBlock
+ }
+
+ if resCh != nil {
+ t := ctx.(*kernel.Task)
+ if err := t.Block(resCh); err != nil {
+ return 0, syserr.FromError(err).ToError()
+ }
+
+ n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{
+ Atomic: true, // See above.
+ })
+ }
+ if err == tcpip.ErrWouldBlock {
+ return n, syserror.ErrWouldBlock
+ } else if err != nil {
+ return int64(n), f.err // Propagate error.
+ }
+
+ return int64(n), nil
+}
+
// Readiness returns a mask of ready events for socket s.
func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
r := s.Endpoint.Readiness(mask)
@@ -643,7 +768,7 @@ func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
// tcpip.Endpoint.
func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
- // implemented specifically for epsocket.SocketOperations rather than
+ // implemented specifically for netstack.SocketOperations rather than
// commonEndpoint. commonEndpoint should be extended to support socket
// options where the implementation is not shared, as unix sockets need
// their own support for SO_TIMESTAMP.
@@ -716,7 +841,7 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int,
return getSockOptIPv6(t, ep, name, outLen)
case linux.SOL_IP:
- return getSockOptIP(t, ep, name, outLen)
+ return getSockOptIP(t, ep, name, outLen, family)
case linux.SOL_UDP,
linux.SOL_ICMPV6,
@@ -774,8 +899,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var size tcpip.SendBufferSizeOption
- if err := ep.GetSockOpt(&size); err != nil {
+ size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -790,8 +915,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var size tcpip.ReceiveBufferSizeOption
- if err := ep.GetSockOpt(&size); err != nil {
+ size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -825,6 +950,19 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return int32(v), nil
+ case linux.SO_BINDTODEVICE:
+ var v tcpip.BindToDeviceOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ if len(v) == 0 {
+ return []byte{}, nil
+ }
+ if outLen < linux.IFNAMSIZ {
+ return nil, syserr.ErrInvalidArgument
+ }
+ return append([]byte(v), 0), nil
+
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
@@ -1039,6 +1177,25 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
+ case linux.IPV6_TCLASS:
+ // Length handling for parity with Linux.
+ if outLen == 0 {
+ return make([]byte, 0), nil
+ }
+ var v tcpip.IPv6TrafficClassOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ uintv := uint32(v)
+ // Linux truncates the output binary to outLen.
+ ib := binary.Marshal(nil, usermem.ByteOrder, &uintv)
+ // Handle cases where outLen is less than sizeOfInt32.
+ if len(ib) > outLen {
+ ib = ib[:outLen]
+ }
+ return ib, nil
+
default:
emitUnimplementedEventIPv6(t, name)
}
@@ -1046,8 +1203,25 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
}
// getSockOptIP implements GetSockOpt when level is SOL_IP.
-func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) {
switch name {
+ case linux.IP_TTL:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TTLOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ // Fill in the default value, if needed.
+ if v == 0 {
+ v = DefaultTTL
+ }
+
+ return int32(v), nil
+
case linux.IP_MULTICAST_TTL:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
@@ -1089,6 +1263,20 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfac
}
return int32(0), nil
+ case linux.IP_TOS:
+ // Length handling for parity with Linux.
+ if outLen == 0 {
+ return []byte(nil), nil
+ }
+ var v tcpip.IPv4TOSOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ if outLen < sizeOfInt32 {
+ return uint8(v), nil
+ }
+ return int32(v), nil
+
default:
emitUnimplementedEventIP(t, name)
}
@@ -1099,7 +1287,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfac
// tcpip.Endpoint.
func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
- // implemented specifically for epsocket.SocketOperations rather than
+ // implemented specifically for netstack.SocketOperations rather than
// commonEndpoint. commonEndpoint should be extended to support socket
// options where the implementation is not shared, as unix sockets need
// their own support for SO_TIMESTAMP.
@@ -1162,7 +1350,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.SendBufferSizeOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v)))
case linux.SO_RCVBUF:
if len(optVal) < sizeOfInt32 {
@@ -1170,7 +1358,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v)))
case linux.SO_REUSEADDR:
if len(optVal) < sizeOfInt32 {
@@ -1188,6 +1376,13 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
v := usermem.ByteOrder.Uint32(optVal)
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v)))
+ case linux.SO_BINDTODEVICE:
+ n := bytes.IndexByte(optVal, 0)
+ if n == -1 {
+ n = len(optVal)
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n])))
+
case linux.SO_BROADCAST:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -1380,6 +1575,19 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
t.Kernel().EmitUnimplementedEvent(t)
+ case linux.IPV6_TCLASS:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+ if v < -1 || v > 255 {
+ return syserr.ErrInvalidArgument
+ }
+ if v == -1 {
+ v = 0
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v)))
+
default:
emitUnimplementedEventIPv6(t, name)
}
@@ -1511,6 +1719,30 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
t.Kernel().EmitUnimplementedEvent(t)
return syserr.ErrInvalidArgument
+ case linux.IP_TTL:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ // -1 means default TTL.
+ if v == -1 {
+ v = 0
+ } else if v < 1 || v > 255 {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TTLOption(v)))
+
+ case linux.IP_TOS:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v)))
+
case linux.IP_ADD_SOURCE_MEMBERSHIP,
linux.IP_BIND_ADDRESS_NO_PORT,
linux.IP_BLOCK_SOURCE,
@@ -1534,9 +1766,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
linux.IP_RECVTOS,
linux.IP_RECVTTL,
linux.IP_RETOPTS,
- linux.IP_TOS,
linux.IP_TRANSPARENT,
- linux.IP_TTL,
linux.IP_UNBLOCK_SOURCE,
linux.IP_UNICAST_IF,
linux.IP_XFRM_POLICY,
@@ -2057,7 +2287,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
n, _, err = s.Endpoint.Write(v, opts)
}
dontWait := flags&linux.MSG_DONTWAIT != 0
- if err == nil && (n >= int64(v.Size()) || dontWait) {
+ if err == nil && (n >= v.src.NumBytes() || dontWait) {
// Complete write.
return int(n), nil
}
@@ -2082,7 +2312,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
return 0, syserr.TranslateNetstackError(err)
}
- if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock {
+ if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock {
return int(total), nil
}
@@ -2098,10 +2328,11 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
// Ioctl implements fs.FileOperations.Ioctl.
func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
- // SIOCGSTAMP is implemented by epsocket rather than all commonEndpoint
+ // SIOCGSTAMP is implemented by netstack rather than all commonEndpoint
// sockets.
// TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP.
- if int(args[1].Int()) == syscall.SIOCGSTAMP {
+ switch args[1].Int() {
+ case syscall.SIOCGSTAMP:
s.readMu.Lock()
defer s.readMu.Unlock()
if !s.timestampValid {
@@ -2113,6 +2344,25 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
AddressSpaceActive: true,
})
return 0, err
+
+ case linux.TIOCINQ:
+ v, terr := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if terr != nil {
+ return 0, syserr.TranslateNetstackError(terr).ToError()
+ }
+
+ // Add bytes removed from the endpoint but not yet sent to the caller.
+ v += len(s.readView)
+
+ if v > math.MaxInt32 {
+ v = math.MaxInt32
+ }
+
+ // Copy result to user-space.
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
}
return Ioctl(ctx, s.Endpoint, io, args)
@@ -2184,9 +2434,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
return 0, err
case linux.TIOCOUTQ:
- var v tcpip.SendQueueSizeOption
- if err := ep.GetSockOpt(&v); err != nil {
- return 0, syserr.TranslateNetstackError(err).ToError()
+ v, terr := ep.GetSockOptInt(tcpip.SendQueueSizeOption)
+ if terr != nil {
+ return 0, syserr.TranslateNetstackError(terr).ToError()
}
if v > math.MaxInt32 {
@@ -2381,7 +2631,7 @@ func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error {
// Flag values and meanings are described in greater detail in netdevice(7) in
// the SIOCGIFFLAGS section.
func interfaceStatusFlags(stack inet.Stack, name string) (uint32, *syserr.Error) {
- // epsocket should only ever be passed an epsocket.Stack.
+ // We should only ever be passed a netstack.Stack.
epstack, ok := stack.(*Stack)
if !ok {
return 0, errStackType
@@ -2421,7 +2671,8 @@ func (s *SocketOperations) State() uint32 {
return 0
}
- if !s.isPacketBased() {
+ switch {
+ case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP:
// TCP socket.
switch tcp.EndpointState(s.Endpoint.State()) {
case tcp.StateEstablished:
@@ -2450,9 +2701,26 @@ func (s *SocketOperations) State() uint32 {
// Internal or unknown state.
return 0
}
+ case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP:
+ // UDP socket.
+ switch udp.EndpointState(s.Endpoint.State()) {
+ case udp.StateInitial, udp.StateBound, udp.StateClosed:
+ return linux.TCP_CLOSE
+ case udp.StateConnected:
+ return linux.TCP_ESTABLISHED
+ default:
+ return 0
+ }
+ case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6:
+ // TODO(b/112063468): Export states for ICMP sockets.
+ case s.skType == linux.SOCK_RAW:
+ // TODO(b/112063468): Export states for raw sockets.
+ default:
+ // Unknown transport protocol, how did we make this socket?
+ log.Warningf("Unknown transport protocol for an existing socket: family=%v, type=%v, protocol=%v, internal type %v", s.family, s.skType, s.protocol, reflect.TypeOf(s.Endpoint).Elem())
+ return 0
}
- // TODO(b/112063468): Export states for UDP, ICMP, and raw sockets.
return 0
}
diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/netstack/provider.go
index 421f93dc4..357a664cc 100644
--- a/pkg/sentry/socket/epsocket/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package epsocket
+package netstack
import (
"syscall"
@@ -65,7 +65,7 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
// Raw sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(ctx)
if !creds.HasCapability(linux.CAP_NET_RAW) {
- return 0, true, syserr.ErrPermissionDenied
+ return 0, true, syserr.ErrNotPermitted
}
switch protocol {
diff --git a/pkg/sentry/socket/epsocket/save_restore.go b/pkg/sentry/socket/netstack/save_restore.go
index f7b8c10cc..c7aaf722a 100644
--- a/pkg/sentry/socket/epsocket/save_restore.go
+++ b/pkg/sentry/socket/netstack/save_restore.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package epsocket
+package netstack
import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/netstack/stack.go
index 7cf7ff735..fda0156e5 100644
--- a/pkg/sentry/socket/epsocket/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package epsocket
+package netstack
import (
"gvisor.dev/gvisor/pkg/abi/linux"
diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD
index 5061dcbde..3a6baa308 100644
--- a/pkg/sentry/socket/rpcinet/BUILD
+++ b/pkg/sentry/socket/rpcinet/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -49,6 +50,14 @@ proto_library(
],
)
+cc_proto_library(
+ name = "syscall_rpc_cc_proto",
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [":syscall_rpc_proto"],
+)
+
go_proto_library(
name = "syscall_rpc_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto",
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index da9977fde..830f4da10 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -24,7 +24,7 @@ go_library(
"//pkg/sentry/safemem",
"//pkg/sentry/socket",
"//pkg/sentry/socket/control",
- "//pkg/sentry/socket/epsocket",
+ "//pkg/sentry/socket/netstack",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usermem",
"//pkg/syserr",
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
index 1c71609e2..e27b1c714 100644
--- a/pkg/sentry/socket/unix/transport/queue.go
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -161,7 +161,8 @@ func (q *queue) Dequeue() (e *message, notify bool, err *syserr.Error) {
if q.dataList.Front() == nil {
err := syserr.ErrWouldBlock
if q.closed {
- if err = syserr.ErrClosedForReceive; q.unread {
+ err = syserr.ErrClosedForReceive
+ if q.unread {
err = syserr.ErrConnectionReset
}
}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index a103b1aab..529a7a7a9 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -175,6 +175,10 @@ type Endpoint interface {
// types.
SetSockOpt(opt interface{}) *tcpip.Error
+ // SetSockOptInt sets a socket option for simple cases when a value has
+ // the int type.
+ SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// tcpip.*Option types.
GetSockOpt(opt interface{}) *tcpip.Error
@@ -847,6 +851,10 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -862,65 +870,63 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
return -1, tcpip.ErrQueueSizeNotSupported
}
return v, nil
- default:
- return -1, tcpip.ErrUnknownProtocolOption
- }
-}
-
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
- case tcpip.ErrorOption:
- return nil
- case *tcpip.SendQueueSizeOption:
+ case tcpip.SendQueueSizeOption:
e.Lock()
if !e.Connected() {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize())
+ v := e.connected.SendQueuedSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
- }
- *o = qs
- return nil
-
- case *tcpip.PasscredOption:
- if e.Passcred() {
- *o = tcpip.PasscredOption(1)
- } else {
- *o = tcpip.PasscredOption(0)
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
}
- return nil
+ return int(v), nil
- case *tcpip.SendBufferSizeOption:
+ case tcpip.SendBufferSizeOption:
e.Lock()
if !e.Connected() {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize())
+ v := e.connected.SendMaxQueueSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
}
- *o = qs
- return nil
+ return int(v), nil
- case *tcpip.ReceiveBufferSizeOption:
+ case tcpip.ReceiveBufferSizeOption:
e.Lock()
if e.receiver == nil {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize())
+ v := e.receiver.RecvMaxQueueSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
+ }
+ return int(v), nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.PasscredOption:
+ if e.Passcred() {
+ *o = tcpip.PasscredOption(1)
+ } else {
+ *o = tcpip.PasscredOption(0)
}
- *o = qs
return nil
case *tcpip.KeepaliveEnabledOption:
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 8e4f06a22..1aaae8487 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -31,7 +31,7 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/control"
- "gvisor.dev/gvisor/pkg/sentry/socket/epsocket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserr"
@@ -40,8 +40,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-// SocketOperations is a Unix socket. It is similar to an epsocket, except it
-// is backed by a transport.Endpoint instead of a tcpip.Endpoint.
+// SocketOperations is a Unix socket. It is similar to a netstack socket,
+// except it is backed by a transport.Endpoint instead of a tcpip.Endpoint.
//
// +stateify savable
type SocketOperations struct {
@@ -116,7 +116,7 @@ func (s *SocketOperations) Endpoint() transport.Endpoint {
// extractPath extracts and validates the address.
func extractPath(sockaddr []byte) (string, *syserr.Error) {
- addr, _, err := epsocket.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */)
+ addr, _, err := netstack.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */)
if err != nil {
return "", err
}
@@ -143,7 +143,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32,
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := epsocket.ConvertAddress(linux.AF_UNIX, addr)
+ a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
return a, l, nil
}
@@ -155,19 +155,19 @@ func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32,
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := epsocket.ConvertAddress(linux.AF_UNIX, addr)
+ a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
return a, l, nil
}
// Ioctl implements fs.FileOperations.Ioctl.
func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
- return epsocket.Ioctl(ctx, s.ep, io, args)
+ return netstack.Ioctl(ctx, s.ep, io, args)
}
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
- return epsocket.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
+ return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
}
// Listen implements the linux syscall listen(2) for sockets backed by
@@ -474,13 +474,13 @@ func (s *SocketOperations) EventUnregister(e *waiter.Entry) {
// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
// a transport.Endpoint.
func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
- return epsocket.SetSockOpt(t, s, s.ep, level, name, optVal)
+ return netstack.SetSockOpt(t, s, s.ep, level, name, optVal)
}
// Shutdown implements the linux syscall shutdown(2) for sockets backed by
// a transport.Endpoint.
func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
- f, err := epsocket.ConvertShutdown(how)
+ f, err := netstack.ConvertShutdown(how)
if err != nil {
return err
}
@@ -546,7 +546,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
var from linux.SockAddr
var fromLen uint32
if r.From != nil && len([]byte(r.From.Addr)) != 0 {
- from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
+ from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
}
if r.ControlTrunc {
@@ -581,7 +581,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
var from linux.SockAddr
var fromLen uint32
if r.From != nil {
- from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
+ from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
}
if r.ControlTrunc {
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 445d25010..72ebf766d 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -31,8 +32,8 @@ go_library(
"//pkg/sentry/arch",
"//pkg/sentry/kernel",
"//pkg/sentry/socket/control",
- "//pkg/sentry/socket/epsocket",
"//pkg/sentry/socket/netlink",
+ "//pkg/sentry/socket/netstack",
"//pkg/sentry/syscalls/linux",
"//pkg/sentry/usermem",
],
@@ -44,6 +45,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+cc_proto_library(
+ name = "strace_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":strace_proto"],
+)
+
go_proto_library(
name = "strace_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto",
diff --git a/pkg/sentry/strace/linux64.go b/pkg/sentry/strace/linux64.go
index 3650fd6e1..5d57b75af 100644
--- a/pkg/sentry/strace/linux64.go
+++ b/pkg/sentry/strace/linux64.go
@@ -335,4 +335,5 @@ var linuxAMD64 = SyscallMap{
315: makeSyscallInfo("sched_getattr", Hex, Hex, Hex),
316: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex),
317: makeSyscallInfo("seccomp", Hex, Hex, Hex),
+ 332: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex),
}
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index f779186ad..94334f6d2 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -23,8 +23,8 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket/control"
- "gvisor.dev/gvisor/pkg/sentry/socket/epsocket"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
"gvisor.dev/gvisor/pkg/sentry/usermem"
)
@@ -332,7 +332,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string {
switch family {
case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX:
- fa, _, err := epsocket.AddressAndFamily(int(family), b, true /* strict */)
+ fa, _, err := netstack.AddressAndFamily(int(family), b, true /* strict */)
if err != nil {
return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err)
}
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 33a40b9c6..cf2a56bed 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -8,6 +8,8 @@ go_library(
"error.go",
"flags.go",
"linux64.go",
+ "linux64_amd64.go",
+ "linux64_arm64.go",
"sigset.go",
"sys_aio.go",
"sys_capability.go",
@@ -74,6 +76,7 @@ go_library(
"//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/sched",
"//pkg/sentry/kernel/shm",
+ "//pkg/sentry/kernel/signalfd",
"//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
"//pkg/sentry/memmap",
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index ed996ba51..b64c49ff5 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -15,377 +15,8 @@
// Package linux provides syscall tables for amd64 Linux.
package linux
-import (
- "gvisor.dev/gvisor/pkg/abi"
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/syscalls"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/syserror"
+const (
+ _LINUX_SYSNAME = "Linux"
+ _LINUX_RELEASE = "4.4"
+ _LINUX_VERSION = "#1 SMP Sun Jan 10 15:06:54 PST 2016"
)
-
-// AUDIT_ARCH_X86_64 identifies the Linux syscall API on AMD64, and is taken
-// from <linux/audit.h>.
-const _AUDIT_ARCH_X86_64 = 0xc000003e
-
-// AMD64 is a table of Linux amd64 syscall API with the corresponding syscall
-// numbers from Linux 4.4.
-var AMD64 = &kernel.SyscallTable{
- OS: abi.Linux,
- Arch: arch.AMD64,
- Version: kernel.Version{
- // Version 4.4 is chosen as a stable, longterm version of Linux, which
- // guides the interface provided by this syscall table. The build
- // version is that for a clean build with default kernel config, at 5
- // minutes after v4.4 was tagged.
- Sysname: "Linux",
- Release: "4.4",
- Version: "#1 SMP Sun Jan 10 15:06:54 PST 2016",
- },
- AuditNumber: _AUDIT_ARCH_X86_64,
- Table: map[uintptr]kernel.Syscall{
- 0: syscalls.Supported("read", Read),
- 1: syscalls.Supported("write", Write),
- 2: syscalls.PartiallySupported("open", Open, "Options O_DIRECT, O_NOATIME, O_PATH, O_TMPFILE, O_SYNC are not supported.", nil),
- 3: syscalls.Supported("close", Close),
- 4: syscalls.Supported("stat", Stat),
- 5: syscalls.Supported("fstat", Fstat),
- 6: syscalls.Supported("lstat", Lstat),
- 7: syscalls.Supported("poll", Poll),
- 8: syscalls.Supported("lseek", Lseek),
- 9: syscalls.PartiallySupported("mmap", Mmap, "Generally supported with exceptions. Options MAP_FIXED_NOREPLACE, MAP_SHARED_VALIDATE, MAP_SYNC MAP_GROWSDOWN, MAP_HUGETLB are not supported.", nil),
- 10: syscalls.Supported("mprotect", Mprotect),
- 11: syscalls.Supported("munmap", Munmap),
- 12: syscalls.Supported("brk", Brk),
- 13: syscalls.Supported("rt_sigaction", RtSigaction),
- 14: syscalls.Supported("rt_sigprocmask", RtSigprocmask),
- 15: syscalls.Supported("rt_sigreturn", RtSigreturn),
- 16: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil),
- 17: syscalls.Supported("pread64", Pread64),
- 18: syscalls.Supported("pwrite64", Pwrite64),
- 19: syscalls.Supported("readv", Readv),
- 20: syscalls.Supported("writev", Writev),
- 21: syscalls.Supported("access", Access),
- 22: syscalls.Supported("pipe", Pipe),
- 23: syscalls.Supported("select", Select),
- 24: syscalls.Supported("sched_yield", SchedYield),
- 25: syscalls.Supported("mremap", Mremap),
- 26: syscalls.PartiallySupported("msync", Msync, "Full data flush is not guaranteed at this time.", nil),
- 27: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil),
- 28: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil),
- 29: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil),
- 30: syscalls.PartiallySupported("shmat", Shmat, "Option SHM_RND is not supported.", nil),
- 31: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil),
- 32: syscalls.Supported("dup", Dup),
- 33: syscalls.Supported("dup2", Dup2),
- 34: syscalls.Supported("pause", Pause),
- 35: syscalls.Supported("nanosleep", Nanosleep),
- 36: syscalls.Supported("getitimer", Getitimer),
- 37: syscalls.Supported("alarm", Alarm),
- 38: syscalls.Supported("setitimer", Setitimer),
- 39: syscalls.Supported("getpid", Getpid),
- 40: syscalls.Supported("sendfile", Sendfile),
- 41: syscalls.PartiallySupported("socket", Socket, "Limited support for AF_NETLINK, NETLINK_ROUTE sockets. Limited support for SOCK_RAW.", nil),
- 42: syscalls.Supported("connect", Connect),
- 43: syscalls.Supported("accept", Accept),
- 44: syscalls.Supported("sendto", SendTo),
- 45: syscalls.Supported("recvfrom", RecvFrom),
- 46: syscalls.Supported("sendmsg", SendMsg),
- 47: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil),
- 48: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil),
- 49: syscalls.PartiallySupported("bind", Bind, "Autobind for abstract Unix sockets is not supported.", nil),
- 50: syscalls.Supported("listen", Listen),
- 51: syscalls.Supported("getsockname", GetSockName),
- 52: syscalls.Supported("getpeername", GetPeerName),
- 53: syscalls.Supported("socketpair", SocketPair),
- 54: syscalls.PartiallySupported("setsockopt", SetSockOpt, "Not all socket options are supported.", nil),
- 55: syscalls.PartiallySupported("getsockopt", GetSockOpt, "Not all socket options are supported.", nil),
- 56: syscalls.PartiallySupported("clone", Clone, "Mount namespace (CLONE_NEWNS) not supported. Options CLONE_PARENT, CLONE_SYSVSEM not supported.", nil),
- 57: syscalls.Supported("fork", Fork),
- 58: syscalls.Supported("vfork", Vfork),
- 59: syscalls.Supported("execve", Execve),
- 60: syscalls.Supported("exit", Exit),
- 61: syscalls.Supported("wait4", Wait4),
- 62: syscalls.Supported("kill", Kill),
- 63: syscalls.Supported("uname", Uname),
- 64: syscalls.Supported("semget", Semget),
- 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
- 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
- 67: syscalls.Supported("shmdt", Shmdt),
- 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 70: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 71: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil),
- 73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil),
- 74: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil),
- 75: syscalls.PartiallySupported("fdatasync", Fdatasync, "Full data flush is not guaranteed at this time.", nil),
- 76: syscalls.Supported("truncate", Truncate),
- 77: syscalls.Supported("ftruncate", Ftruncate),
- 78: syscalls.Supported("getdents", Getdents),
- 79: syscalls.Supported("getcwd", Getcwd),
- 80: syscalls.Supported("chdir", Chdir),
- 81: syscalls.Supported("fchdir", Fchdir),
- 82: syscalls.Supported("rename", Rename),
- 83: syscalls.Supported("mkdir", Mkdir),
- 84: syscalls.Supported("rmdir", Rmdir),
- 85: syscalls.Supported("creat", Creat),
- 86: syscalls.Supported("link", Link),
- 87: syscalls.Supported("unlink", Unlink),
- 88: syscalls.Supported("symlink", Symlink),
- 89: syscalls.Supported("readlink", Readlink),
- 90: syscalls.Supported("chmod", Chmod),
- 91: syscalls.PartiallySupported("fchmod", Fchmod, "Options S_ISUID and S_ISGID not supported.", nil),
- 92: syscalls.Supported("chown", Chown),
- 93: syscalls.Supported("fchown", Fchown),
- 94: syscalls.Supported("lchown", Lchown),
- 95: syscalls.Supported("umask", Umask),
- 96: syscalls.Supported("gettimeofday", Gettimeofday),
- 97: syscalls.Supported("getrlimit", Getrlimit),
- 98: syscalls.PartiallySupported("getrusage", Getrusage, "Fields ru_maxrss, ru_minflt, ru_majflt, ru_inblock, ru_oublock are not supported. Fields ru_utime and ru_stime have low precision.", nil),
- 99: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil),
- 100: syscalls.Supported("times", Times),
- 101: syscalls.PartiallySupported("ptrace", Ptrace, "Options PTRACE_PEEKSIGINFO, PTRACE_SECCOMP_GET_FILTER not supported.", nil),
- 102: syscalls.Supported("getuid", Getuid),
- 103: syscalls.PartiallySupported("syslog", Syslog, "Outputs a dummy message for security reasons.", nil),
- 104: syscalls.Supported("getgid", Getgid),
- 105: syscalls.Supported("setuid", Setuid),
- 106: syscalls.Supported("setgid", Setgid),
- 107: syscalls.Supported("geteuid", Geteuid),
- 108: syscalls.Supported("getegid", Getegid),
- 109: syscalls.Supported("setpgid", Setpgid),
- 110: syscalls.Supported("getppid", Getppid),
- 111: syscalls.Supported("getpgrp", Getpgrp),
- 112: syscalls.Supported("setsid", Setsid),
- 113: syscalls.Supported("setreuid", Setreuid),
- 114: syscalls.Supported("setregid", Setregid),
- 115: syscalls.Supported("getgroups", Getgroups),
- 116: syscalls.Supported("setgroups", Setgroups),
- 117: syscalls.Supported("setresuid", Setresuid),
- 118: syscalls.Supported("getresuid", Getresuid),
- 119: syscalls.Supported("setresgid", Setresgid),
- 120: syscalls.Supported("getresgid", Getresgid),
- 121: syscalls.Supported("getpgid", Getpgid),
- 122: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
- 123: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
- 124: syscalls.Supported("getsid", Getsid),
- 125: syscalls.Supported("capget", Capget),
- 126: syscalls.Supported("capset", Capset),
- 127: syscalls.Supported("rt_sigpending", RtSigpending),
- 128: syscalls.Supported("rt_sigtimedwait", RtSigtimedwait),
- 129: syscalls.Supported("rt_sigqueueinfo", RtSigqueueinfo),
- 130: syscalls.Supported("rt_sigsuspend", RtSigsuspend),
- 131: syscalls.Supported("sigaltstack", Sigaltstack),
- 132: syscalls.Supported("utime", Utime),
- 133: syscalls.PartiallySupported("mknod", Mknod, "Device creation is not generally supported. Only regular file and FIFO creation are supported.", nil),
- 134: syscalls.Error("uselib", syserror.ENOSYS, "Obsolete", nil),
- 135: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil),
- 136: syscalls.ErrorWithEvent("ustat", syserror.ENOSYS, "Needs filesystem support.", nil),
- 137: syscalls.PartiallySupported("statfs", Statfs, "Depends on the backing file system implementation.", nil),
- 138: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil),
- 139: syscalls.ErrorWithEvent("sysfs", syserror.ENOSYS, "", []string{"gvisor.dev/issue/165"}),
- 140: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil),
- 141: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil),
- 142: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil),
- 143: syscalls.PartiallySupported("sched_getparam", SchedGetparam, "Stub implementation.", nil),
- 144: syscalls.PartiallySupported("sched_setscheduler", SchedSetscheduler, "Stub implementation.", nil),
- 145: syscalls.PartiallySupported("sched_getscheduler", SchedGetscheduler, "Stub implementation.", nil),
- 146: syscalls.PartiallySupported("sched_get_priority_max", SchedGetPriorityMax, "Stub implementation.", nil),
- 147: syscalls.PartiallySupported("sched_get_priority_min", SchedGetPriorityMin, "Stub implementation.", nil),
- 148: syscalls.ErrorWithEvent("sched_rr_get_interval", syserror.EPERM, "", nil),
- 149: syscalls.PartiallySupported("mlock", Mlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
- 150: syscalls.PartiallySupported("munlock", Munlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
- 151: syscalls.PartiallySupported("mlockall", Mlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
- 152: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
- 153: syscalls.CapError("vhangup", linux.CAP_SYS_TTY_CONFIG, "", nil),
- 154: syscalls.Error("modify_ldt", syserror.EPERM, "", nil),
- 155: syscalls.Error("pivot_root", syserror.EPERM, "", nil),
- 156: syscalls.Error("sysctl", syserror.EPERM, "Deprecated. Use /proc/sys instead.", nil),
- 157: syscalls.PartiallySupported("prctl", Prctl, "Not all options are supported.", nil),
- 158: syscalls.PartiallySupported("arch_prctl", ArchPrctl, "Options ARCH_GET_GS, ARCH_SET_GS not supported.", nil),
- 159: syscalls.CapError("adjtimex", linux.CAP_SYS_TIME, "", nil),
- 160: syscalls.PartiallySupported("setrlimit", Setrlimit, "Not all rlimits are enforced.", nil),
- 161: syscalls.Supported("chroot", Chroot),
- 162: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil),
- 163: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil),
- 164: syscalls.CapError("settimeofday", linux.CAP_SYS_TIME, "", nil),
- 165: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil),
- 166: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil),
- 167: syscalls.CapError("swapon", linux.CAP_SYS_ADMIN, "", nil),
- 168: syscalls.CapError("swapoff", linux.CAP_SYS_ADMIN, "", nil),
- 169: syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil),
- 170: syscalls.Supported("sethostname", Sethostname),
- 171: syscalls.Supported("setdomainname", Setdomainname),
- 172: syscalls.CapError("iopl", linux.CAP_SYS_RAWIO, "", nil),
- 173: syscalls.CapError("ioperm", linux.CAP_SYS_RAWIO, "", nil),
- 174: syscalls.CapError("create_module", linux.CAP_SYS_MODULE, "", nil),
- 175: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil),
- 176: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil),
- 177: syscalls.Error("get_kernel_syms", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil),
- 178: syscalls.Error("query_module", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil),
- 179: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations
- 180: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil),
- 181: syscalls.Error("getpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil),
- 182: syscalls.Error("putpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil),
- 183: syscalls.Error("afs_syscall", syserror.ENOSYS, "Not implemented in Linux.", nil),
- 184: syscalls.Error("tuxcall", syserror.ENOSYS, "Not implemented in Linux.", nil),
- 185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil),
- 186: syscalls.Supported("gettid", Gettid),
- 187: syscalls.ErrorWithEvent("readahead", syserror.ENOSYS, "", []string{"gvisor.dev/issue/261"}), // TODO(b/29351341)
- 188: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 191: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 192: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 193: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 194: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 195: syscalls.ErrorWithEvent("llistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 196: syscalls.ErrorWithEvent("flistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 197: syscalls.ErrorWithEvent("removexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 198: syscalls.ErrorWithEvent("lremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 199: syscalls.ErrorWithEvent("fremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 200: syscalls.Supported("tkill", Tkill),
- 201: syscalls.Supported("time", Time),
- 202: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil),
- 203: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil),
- 204: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil),
- 205: syscalls.Error("set_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil),
- 206: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
- 207: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
- 208: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
- 209: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
- 210: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
- 211: syscalls.Error("get_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil),
- 212: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil),
- 213: syscalls.Supported("epoll_create", EpollCreate),
- 214: syscalls.ErrorWithEvent("epoll_ctl_old", syserror.ENOSYS, "Deprecated.", nil),
- 215: syscalls.ErrorWithEvent("epoll_wait_old", syserror.ENOSYS, "Deprecated.", nil),
- 216: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil),
- 217: syscalls.Supported("getdents64", Getdents64),
- 218: syscalls.Supported("set_tid_address", SetTidAddress),
- 219: syscalls.Supported("restart_syscall", RestartSyscall),
- 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920)
- 221: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil),
- 222: syscalls.Supported("timer_create", TimerCreate),
- 223: syscalls.Supported("timer_settime", TimerSettime),
- 224: syscalls.Supported("timer_gettime", TimerGettime),
- 225: syscalls.Supported("timer_getoverrun", TimerGetoverrun),
- 226: syscalls.Supported("timer_delete", TimerDelete),
- 227: syscalls.Supported("clock_settime", ClockSettime),
- 228: syscalls.Supported("clock_gettime", ClockGettime),
- 229: syscalls.Supported("clock_getres", ClockGetres),
- 230: syscalls.Supported("clock_nanosleep", ClockNanosleep),
- 231: syscalls.Supported("exit_group", ExitGroup),
- 232: syscalls.Supported("epoll_wait", EpollWait),
- 233: syscalls.Supported("epoll_ctl", EpollCtl),
- 234: syscalls.Supported("tgkill", Tgkill),
- 235: syscalls.Supported("utimes", Utimes),
- 236: syscalls.Error("vserver", syserror.ENOSYS, "Not implemented by Linux", nil),
- 237: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}),
- 238: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil),
- 239: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil),
- 240: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 241: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 242: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 243: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 244: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 245: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 246: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil),
- 247: syscalls.Supported("waitid", Waitid),
- 248: syscalls.Error("add_key", syserror.EACCES, "Not available to user.", nil),
- 249: syscalls.Error("request_key", syserror.EACCES, "Not available to user.", nil),
- 250: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil),
- 251: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
- 252: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
- 253: syscalls.PartiallySupported("inotify_init", InotifyInit, "inotify events are only available inside the sandbox.", nil),
- 254: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil),
- 255: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil),
- 256: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil),
- 257: syscalls.Supported("openat", Openat),
- 258: syscalls.Supported("mkdirat", Mkdirat),
- 259: syscalls.Supported("mknodat", Mknodat),
- 260: syscalls.Supported("fchownat", Fchownat),
- 261: syscalls.Supported("futimesat", Futimesat),
- 262: syscalls.Supported("fstatat", Fstatat),
- 263: syscalls.Supported("unlinkat", Unlinkat),
- 264: syscalls.Supported("renameat", Renameat),
- 265: syscalls.Supported("linkat", Linkat),
- 266: syscalls.Supported("symlinkat", Symlinkat),
- 267: syscalls.Supported("readlinkat", Readlinkat),
- 268: syscalls.Supported("fchmodat", Fchmodat),
- 269: syscalls.Supported("faccessat", Faccessat),
- 270: syscalls.Supported("pselect", Pselect),
- 271: syscalls.Supported("ppoll", Ppoll),
- 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil),
- 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil),
- 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil),
- 275: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
- 276: syscalls.ErrorWithEvent("tee", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
- 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil),
- 278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
- 279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly)
- 280: syscalls.Supported("utimensat", Utimensat),
- 281: syscalls.Supported("epoll_pwait", EpollPwait),
- 282: syscalls.ErrorWithEvent("signalfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426)
- 283: syscalls.Supported("timerfd_create", TimerfdCreate),
- 284: syscalls.Supported("eventfd", Eventfd),
- 285: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil),
- 286: syscalls.Supported("timerfd_settime", TimerfdSettime),
- 287: syscalls.Supported("timerfd_gettime", TimerfdGettime),
- 288: syscalls.Supported("accept4", Accept4),
- 289: syscalls.ErrorWithEvent("signalfd4", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426)
- 290: syscalls.Supported("eventfd2", Eventfd2),
- 291: syscalls.Supported("epoll_create1", EpollCreate1),
- 292: syscalls.Supported("dup3", Dup3),
- 293: syscalls.Supported("pipe2", Pipe2),
- 294: syscalls.Supported("inotify_init1", InotifyInit1),
- 295: syscalls.Supported("preadv", Preadv),
- 296: syscalls.Supported("pwritev", Pwritev),
- 297: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo),
- 298: syscalls.ErrorWithEvent("perf_event_open", syserror.ENODEV, "No support for perf counters", nil),
- 299: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil),
- 300: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
- 301: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
- 302: syscalls.Supported("prlimit64", Prlimit64),
- 303: syscalls.Error("name_to_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
- 304: syscalls.Error("open_by_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
- 305: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil),
- 306: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil),
- 307: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil),
- 308: syscalls.ErrorWithEvent("setns", syserror.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995)
- 309: syscalls.Supported("getcpu", Getcpu),
- 310: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
- 311: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
- 312: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil),
- 313: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil),
- 314: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
- 315: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
- 316: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772)
- 317: syscalls.Supported("seccomp", Seccomp),
- 318: syscalls.Supported("getrandom", GetRandom),
- 319: syscalls.Supported("memfd_create", MemfdCreate),
- 320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil),
- 321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
- 322: syscalls.ErrorWithEvent("execveat", syserror.ENOSYS, "", []string{"gvisor.dev/issue/265"}), // TODO(b/118901836)
- 323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
- 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897)
- 325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
-
- // Syscalls after 325 are "backports" from versions of Linux after 4.4.
- 326: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil),
- 327: syscalls.Supported("preadv2", Preadv2),
- 328: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil),
- 332: syscalls.Supported("statx", Statx),
- },
-
- Emulate: map[usermem.Addr]uintptr{
- 0xffffffffff600000: 96, // vsyscall gettimeofday(2)
- 0xffffffffff600400: 201, // vsyscall time(2)
- 0xffffffffff600800: 309, // vsyscall getcpu(2)
- },
- Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
- t.Kernel().EmitUnimplementedEvent(t)
- return 0, syserror.ENOSYS
- },
-}
diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go
new file mode 100644
index 000000000..e215ac049
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/linux64_amd64.go
@@ -0,0 +1,386 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/syscalls"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// AMD64 is the Linux amd64 syscall table: Table maps each syscall number (as
+// numbered in Linux 4.4, plus the later additions marked below) to its entry.
+var AMD64 = &kernel.SyscallTable{
+ OS: abi.Linux,
+ Arch: arch.AMD64,
+ Version: kernel.Version{
+ // Version 4.4 is chosen as a stable, longterm version of Linux, which
+ // guides the interface provided by this syscall table. The build
+ // version is that for a clean build with default kernel config, at 5
+ // minutes after v4.4 was tagged.
+ Sysname: _LINUX_SYSNAME,
+ Release: _LINUX_RELEASE,
+ Version: _LINUX_VERSION,
+ },
+ AuditNumber: linux.AUDIT_ARCH_X86_64,
+ Table: map[uintptr]kernel.Syscall{ // Keyed by syscall number.
+ 0: syscalls.Supported("read", Read),
+ 1: syscalls.Supported("write", Write),
+ 2: syscalls.PartiallySupported("open", Open, "Options O_DIRECT, O_NOATIME, O_PATH, O_TMPFILE, O_SYNC are not supported.", nil),
+ 3: syscalls.Supported("close", Close),
+ 4: syscalls.Supported("stat", Stat),
+ 5: syscalls.Supported("fstat", Fstat),
+ 6: syscalls.Supported("lstat", Lstat),
+ 7: syscalls.Supported("poll", Poll),
+ 8: syscalls.Supported("lseek", Lseek),
+ 9: syscalls.PartiallySupported("mmap", Mmap, "Generally supported with exceptions. Options MAP_FIXED_NOREPLACE, MAP_SHARED_VALIDATE, MAP_SYNC MAP_GROWSDOWN, MAP_HUGETLB are not supported.", nil),
+ 10: syscalls.Supported("mprotect", Mprotect),
+ 11: syscalls.Supported("munmap", Munmap),
+ 12: syscalls.Supported("brk", Brk),
+ 13: syscalls.Supported("rt_sigaction", RtSigaction),
+ 14: syscalls.Supported("rt_sigprocmask", RtSigprocmask),
+ 15: syscalls.Supported("rt_sigreturn", RtSigreturn),
+ 16: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil),
+ 17: syscalls.Supported("pread64", Pread64),
+ 18: syscalls.Supported("pwrite64", Pwrite64),
+ 19: syscalls.Supported("readv", Readv),
+ 20: syscalls.Supported("writev", Writev),
+ 21: syscalls.Supported("access", Access),
+ 22: syscalls.Supported("pipe", Pipe),
+ 23: syscalls.Supported("select", Select),
+ 24: syscalls.Supported("sched_yield", SchedYield),
+ 25: syscalls.Supported("mremap", Mremap),
+ 26: syscalls.PartiallySupported("msync", Msync, "Full data flush is not guaranteed at this time.", nil),
+ 27: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil),
+ 28: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil),
+ 29: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil),
+ 30: syscalls.PartiallySupported("shmat", Shmat, "Option SHM_RND is not supported.", nil),
+ 31: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil),
+ 32: syscalls.Supported("dup", Dup),
+ 33: syscalls.Supported("dup2", Dup2),
+ 34: syscalls.Supported("pause", Pause),
+ 35: syscalls.Supported("nanosleep", Nanosleep),
+ 36: syscalls.Supported("getitimer", Getitimer),
+ 37: syscalls.Supported("alarm", Alarm),
+ 38: syscalls.Supported("setitimer", Setitimer),
+ 39: syscalls.Supported("getpid", Getpid),
+ 40: syscalls.Supported("sendfile", Sendfile),
+ 41: syscalls.PartiallySupported("socket", Socket, "Limited support for AF_NETLINK, NETLINK_ROUTE sockets. Limited support for SOCK_RAW.", nil),
+ 42: syscalls.Supported("connect", Connect),
+ 43: syscalls.Supported("accept", Accept),
+ 44: syscalls.Supported("sendto", SendTo),
+ 45: syscalls.Supported("recvfrom", RecvFrom),
+ 46: syscalls.Supported("sendmsg", SendMsg),
+ 47: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil),
+ 48: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil),
+ 49: syscalls.PartiallySupported("bind", Bind, "Autobind for abstract Unix sockets is not supported.", nil),
+ 50: syscalls.Supported("listen", Listen),
+ 51: syscalls.Supported("getsockname", GetSockName),
+ 52: syscalls.Supported("getpeername", GetPeerName),
+ 53: syscalls.Supported("socketpair", SocketPair),
+ 54: syscalls.PartiallySupported("setsockopt", SetSockOpt, "Not all socket options are supported.", nil),
+ 55: syscalls.PartiallySupported("getsockopt", GetSockOpt, "Not all socket options are supported.", nil),
+ 56: syscalls.PartiallySupported("clone", Clone, "Mount namespace (CLONE_NEWNS) not supported. Options CLONE_PARENT, CLONE_SYSVSEM not supported.", nil),
+ 57: syscalls.Supported("fork", Fork),
+ 58: syscalls.Supported("vfork", Vfork),
+ 59: syscalls.Supported("execve", Execve),
+ 60: syscalls.Supported("exit", Exit),
+ 61: syscalls.Supported("wait4", Wait4),
+ 62: syscalls.Supported("kill", Kill),
+ 63: syscalls.Supported("uname", Uname),
+ 64: syscalls.Supported("semget", Semget),
+ 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
+ 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
+ 67: syscalls.Supported("shmdt", Shmdt),
+ 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 70: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 71: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil),
+ 73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil),
+ 74: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil),
+ 75: syscalls.PartiallySupported("fdatasync", Fdatasync, "Full data flush is not guaranteed at this time.", nil),
+ 76: syscalls.Supported("truncate", Truncate),
+ 77: syscalls.Supported("ftruncate", Ftruncate),
+ 78: syscalls.Supported("getdents", Getdents),
+ 79: syscalls.Supported("getcwd", Getcwd),
+ 80: syscalls.Supported("chdir", Chdir),
+ 81: syscalls.Supported("fchdir", Fchdir),
+ 82: syscalls.Supported("rename", Rename),
+ 83: syscalls.Supported("mkdir", Mkdir),
+ 84: syscalls.Supported("rmdir", Rmdir),
+ 85: syscalls.Supported("creat", Creat),
+ 86: syscalls.Supported("link", Link),
+ 87: syscalls.Supported("unlink", Unlink),
+ 88: syscalls.Supported("symlink", Symlink),
+ 89: syscalls.Supported("readlink", Readlink),
+ 90: syscalls.Supported("chmod", Chmod),
+ 91: syscalls.PartiallySupported("fchmod", Fchmod, "Options S_ISUID and S_ISGID not supported.", nil),
+ 92: syscalls.Supported("chown", Chown),
+ 93: syscalls.Supported("fchown", Fchown),
+ 94: syscalls.Supported("lchown", Lchown),
+ 95: syscalls.Supported("umask", Umask),
+ 96: syscalls.Supported("gettimeofday", Gettimeofday),
+ 97: syscalls.Supported("getrlimit", Getrlimit),
+ 98: syscalls.PartiallySupported("getrusage", Getrusage, "Fields ru_maxrss, ru_minflt, ru_majflt, ru_inblock, ru_oublock are not supported. Fields ru_utime and ru_stime have low precision.", nil),
+ 99: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil),
+ 100: syscalls.Supported("times", Times),
+ 101: syscalls.PartiallySupported("ptrace", Ptrace, "Options PTRACE_PEEKSIGINFO, PTRACE_SECCOMP_GET_FILTER not supported.", nil),
+ 102: syscalls.Supported("getuid", Getuid),
+ 103: syscalls.PartiallySupported("syslog", Syslog, "Outputs a dummy message for security reasons.", nil),
+ 104: syscalls.Supported("getgid", Getgid),
+ 105: syscalls.Supported("setuid", Setuid),
+ 106: syscalls.Supported("setgid", Setgid),
+ 107: syscalls.Supported("geteuid", Geteuid),
+ 108: syscalls.Supported("getegid", Getegid),
+ 109: syscalls.Supported("setpgid", Setpgid),
+ 110: syscalls.Supported("getppid", Getppid),
+ 111: syscalls.Supported("getpgrp", Getpgrp),
+ 112: syscalls.Supported("setsid", Setsid),
+ 113: syscalls.Supported("setreuid", Setreuid),
+ 114: syscalls.Supported("setregid", Setregid),
+ 115: syscalls.Supported("getgroups", Getgroups),
+ 116: syscalls.Supported("setgroups", Setgroups),
+ 117: syscalls.Supported("setresuid", Setresuid),
+ 118: syscalls.Supported("getresuid", Getresuid),
+ 119: syscalls.Supported("setresgid", Setresgid),
+ 120: syscalls.Supported("getresgid", Getresgid),
+ 121: syscalls.Supported("getpgid", Getpgid),
+ 122: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
+ 123: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
+ 124: syscalls.Supported("getsid", Getsid),
+ 125: syscalls.Supported("capget", Capget),
+ 126: syscalls.Supported("capset", Capset),
+ 127: syscalls.Supported("rt_sigpending", RtSigpending),
+ 128: syscalls.Supported("rt_sigtimedwait", RtSigtimedwait),
+ 129: syscalls.Supported("rt_sigqueueinfo", RtSigqueueinfo),
+ 130: syscalls.Supported("rt_sigsuspend", RtSigsuspend),
+ 131: syscalls.Supported("sigaltstack", Sigaltstack),
+ 132: syscalls.Supported("utime", Utime),
+ 133: syscalls.PartiallySupported("mknod", Mknod, "Device creation is not generally supported. Only regular file and FIFO creation are supported.", nil),
+ 134: syscalls.Error("uselib", syserror.ENOSYS, "Obsolete", nil),
+ 135: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil),
+ 136: syscalls.ErrorWithEvent("ustat", syserror.ENOSYS, "Needs filesystem support.", nil),
+ 137: syscalls.PartiallySupported("statfs", Statfs, "Depends on the backing file system implementation.", nil),
+ 138: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil),
+ 139: syscalls.ErrorWithEvent("sysfs", syserror.ENOSYS, "", []string{"gvisor.dev/issue/165"}),
+ 140: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil),
+ 141: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil),
+ 142: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil),
+ 143: syscalls.PartiallySupported("sched_getparam", SchedGetparam, "Stub implementation.", nil),
+ 144: syscalls.PartiallySupported("sched_setscheduler", SchedSetscheduler, "Stub implementation.", nil),
+ 145: syscalls.PartiallySupported("sched_getscheduler", SchedGetscheduler, "Stub implementation.", nil),
+ 146: syscalls.PartiallySupported("sched_get_priority_max", SchedGetPriorityMax, "Stub implementation.", nil),
+ 147: syscalls.PartiallySupported("sched_get_priority_min", SchedGetPriorityMin, "Stub implementation.", nil),
+ 148: syscalls.ErrorWithEvent("sched_rr_get_interval", syserror.EPERM, "", nil),
+ 149: syscalls.PartiallySupported("mlock", Mlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 150: syscalls.PartiallySupported("munlock", Munlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 151: syscalls.PartiallySupported("mlockall", Mlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 152: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 153: syscalls.CapError("vhangup", linux.CAP_SYS_TTY_CONFIG, "", nil),
+ 154: syscalls.Error("modify_ldt", syserror.EPERM, "", nil),
+ 155: syscalls.Error("pivot_root", syserror.EPERM, "", nil),
+ 156: syscalls.Error("sysctl", syserror.EPERM, "Deprecated. Use /proc/sys instead.", nil),
+ 157: syscalls.PartiallySupported("prctl", Prctl, "Not all options are supported.", nil),
+ 158: syscalls.PartiallySupported("arch_prctl", ArchPrctl, "Options ARCH_GET_GS, ARCH_SET_GS not supported.", nil),
+ 159: syscalls.CapError("adjtimex", linux.CAP_SYS_TIME, "", nil),
+ 160: syscalls.PartiallySupported("setrlimit", Setrlimit, "Not all rlimits are enforced.", nil),
+ 161: syscalls.Supported("chroot", Chroot),
+ 162: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil),
+ 163: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil),
+ 164: syscalls.CapError("settimeofday", linux.CAP_SYS_TIME, "", nil),
+ 165: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil),
+ 166: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil),
+ 167: syscalls.CapError("swapon", linux.CAP_SYS_ADMIN, "", nil),
+ 168: syscalls.CapError("swapoff", linux.CAP_SYS_ADMIN, "", nil),
+ 169: syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil),
+ 170: syscalls.Supported("sethostname", Sethostname),
+ 171: syscalls.Supported("setdomainname", Setdomainname),
+ 172: syscalls.CapError("iopl", linux.CAP_SYS_RAWIO, "", nil),
+ 173: syscalls.CapError("ioperm", linux.CAP_SYS_RAWIO, "", nil),
+ 174: syscalls.CapError("create_module", linux.CAP_SYS_MODULE, "", nil),
+ 175: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil),
+ 176: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil),
+ 177: syscalls.Error("get_kernel_syms", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil),
+ 178: syscalls.Error("query_module", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil),
+ 179: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations
+ 180: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil),
+ 181: syscalls.Error("getpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil),
+ 182: syscalls.Error("putpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil),
+ 183: syscalls.Error("afs_syscall", syserror.ENOSYS, "Not implemented in Linux.", nil),
+ 184: syscalls.Error("tuxcall", syserror.ENOSYS, "Not implemented in Linux.", nil),
+ 185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil),
+ 186: syscalls.Supported("gettid", Gettid),
+ 187: syscalls.Supported("readahead", Readahead),
+ 188: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 191: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 192: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 193: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 194: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 195: syscalls.ErrorWithEvent("llistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 196: syscalls.ErrorWithEvent("flistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 197: syscalls.ErrorWithEvent("removexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 198: syscalls.ErrorWithEvent("lremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 199: syscalls.ErrorWithEvent("fremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 200: syscalls.Supported("tkill", Tkill),
+ 201: syscalls.Supported("time", Time),
+ 202: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil),
+ 203: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil),
+ 204: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil),
+ 205: syscalls.Error("set_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil),
+ 206: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 207: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 208: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 209: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 210: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 211: syscalls.Error("get_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil),
+ 212: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil),
+ 213: syscalls.Supported("epoll_create", EpollCreate),
+ 214: syscalls.ErrorWithEvent("epoll_ctl_old", syserror.ENOSYS, "Deprecated.", nil),
+ 215: syscalls.ErrorWithEvent("epoll_wait_old", syserror.ENOSYS, "Deprecated.", nil),
+ 216: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil),
+ 217: syscalls.Supported("getdents64", Getdents64),
+ 218: syscalls.Supported("set_tid_address", SetTidAddress),
+ 219: syscalls.Supported("restart_syscall", RestartSyscall),
+ 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920)
+ 221: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil),
+ 222: syscalls.Supported("timer_create", TimerCreate),
+ 223: syscalls.Supported("timer_settime", TimerSettime),
+ 224: syscalls.Supported("timer_gettime", TimerGettime),
+ 225: syscalls.Supported("timer_getoverrun", TimerGetoverrun),
+ 226: syscalls.Supported("timer_delete", TimerDelete),
+ 227: syscalls.Supported("clock_settime", ClockSettime),
+ 228: syscalls.Supported("clock_gettime", ClockGettime),
+ 229: syscalls.Supported("clock_getres", ClockGetres),
+ 230: syscalls.Supported("clock_nanosleep", ClockNanosleep),
+ 231: syscalls.Supported("exit_group", ExitGroup),
+ 232: syscalls.Supported("epoll_wait", EpollWait),
+ 233: syscalls.Supported("epoll_ctl", EpollCtl),
+ 234: syscalls.Supported("tgkill", Tgkill),
+ 235: syscalls.Supported("utimes", Utimes),
+ 236: syscalls.Error("vserver", syserror.ENOSYS, "Not implemented by Linux", nil),
+ 237: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}),
+ 238: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil),
+ 239: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil),
+ 240: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 241: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 242: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 243: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 244: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 245: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 246: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil),
+ 247: syscalls.Supported("waitid", Waitid),
+ 248: syscalls.Error("add_key", syserror.EACCES, "Not available to user.", nil),
+ 249: syscalls.Error("request_key", syserror.EACCES, "Not available to user.", nil),
+ 250: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil),
+ 251: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
+ 252: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
+ 253: syscalls.PartiallySupported("inotify_init", InotifyInit, "inotify events are only available inside the sandbox.", nil),
+ 254: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil),
+ 255: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil),
+ 256: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil),
+ 257: syscalls.Supported("openat", Openat),
+ 258: syscalls.Supported("mkdirat", Mkdirat),
+ 259: syscalls.Supported("mknodat", Mknodat),
+ 260: syscalls.Supported("fchownat", Fchownat),
+ 261: syscalls.Supported("futimesat", Futimesat),
+ 262: syscalls.Supported("fstatat", Fstatat),
+ 263: syscalls.Supported("unlinkat", Unlinkat),
+ 264: syscalls.Supported("renameat", Renameat),
+ 265: syscalls.Supported("linkat", Linkat),
+ 266: syscalls.Supported("symlinkat", Symlinkat),
+ 267: syscalls.Supported("readlinkat", Readlinkat),
+ 268: syscalls.Supported("fchmodat", Fchmodat),
+ 269: syscalls.Supported("faccessat", Faccessat),
+ 270: syscalls.Supported("pselect", Pselect),
+ 271: syscalls.Supported("ppoll", Ppoll),
+ 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil),
+ 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil),
+ 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil),
+ 275: syscalls.Supported("splice", Splice),
+ 276: syscalls.Supported("tee", Tee),
+ 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil),
+ 278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly)
+ 280: syscalls.Supported("utimensat", Utimensat),
+ 281: syscalls.Supported("epoll_pwait", EpollPwait),
+ 282: syscalls.PartiallySupported("signalfd", Signalfd, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}),
+ 283: syscalls.Supported("timerfd_create", TimerfdCreate),
+ 284: syscalls.Supported("eventfd", Eventfd),
+ 285: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil),
+ 286: syscalls.Supported("timerfd_settime", TimerfdSettime),
+ 287: syscalls.Supported("timerfd_gettime", TimerfdGettime),
+ 288: syscalls.Supported("accept4", Accept4),
+ 289: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}),
+ 290: syscalls.Supported("eventfd2", Eventfd2),
+ 291: syscalls.Supported("epoll_create1", EpollCreate1),
+ 292: syscalls.Supported("dup3", Dup3),
+ 293: syscalls.Supported("pipe2", Pipe2),
+ 294: syscalls.Supported("inotify_init1", InotifyInit1),
+ 295: syscalls.Supported("preadv", Preadv),
+ 296: syscalls.Supported("pwritev", Pwritev),
+ 297: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo),
+ 298: syscalls.ErrorWithEvent("perf_event_open", syserror.ENODEV, "No support for perf counters", nil),
+ 299: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil),
+ 300: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
+ 301: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
+ 302: syscalls.Supported("prlimit64", Prlimit64),
+ 303: syscalls.Error("name_to_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
+ 304: syscalls.Error("open_by_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
+ 305: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil),
+ 306: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil),
+ 307: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil),
+ 308: syscalls.ErrorWithEvent("setns", syserror.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995)
+ 309: syscalls.Supported("getcpu", Getcpu),
+ 310: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
+ 311: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
+ 312: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil),
+ 313: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil),
+ 314: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
+ 315: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
+ 316: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772)
+ 317: syscalls.Supported("seccomp", Seccomp),
+ 318: syscalls.Supported("getrandom", GetRandom),
+ 319: syscalls.Supported("memfd_create", MemfdCreate),
+ 320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil),
+ 321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
+ 322: syscalls.ErrorWithEvent("execveat", syserror.ENOSYS, "", []string{"gvisor.dev/issue/265"}), // TODO(b/118901836)
+ 323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
+ 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897)
+ 325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+
+ // Syscalls after 325 are "backports" from versions of Linux after 4.4.
+ 326: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil),
+ 327: syscalls.Supported("preadv2", Preadv2),
+ 328: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil),
+ 332: syscalls.Supported("statx", Statx), // Numbers 329-331 have no entry in this table.
+ },
+
+ Emulate: map[usermem.Addr]uintptr{ // Keyed by address; each value is the Table number of the syscall named in its comment.
+ 0xffffffffff600000: 96, // vsyscall gettimeofday(2)
+ 0xffffffffff600400: 201, // vsyscall time(2)
+ 0xffffffffff600800: 309, // vsyscall getcpu(2)
+ },
+ Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { // Emits an unimplemented event for the task and returns ENOSYS; presumably the fallback for numbers absent from Table (confirm in kernel.SyscallTable).
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, syserror.ENOSYS
+ },
+}
diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go
new file mode 100644
index 000000000..1d3b63020
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/linux64_arm64.go
@@ -0,0 +1,313 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/syscalls"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// ARM64 is a table of Linux arm64 syscall API with the corresponding syscall numbers from Linux 4.4. Numbers without an entry in Table (25, 43, 45, 79, 222, 223, 244-259, 288-290) are presumably dispatched to Missing,
+// which emits an unimplemented-syscall event and returns ENOSYS. NOTE(review): confirm that these gaps are intentional.
+var ARM64 = &kernel.SyscallTable{
+ OS: abi.Linux,
+ Arch: arch.ARM64,
+ Version: kernel.Version{
+ Sysname: _LINUX_SYSNAME,
+ Release: _LINUX_RELEASE,
+ Version: _LINUX_VERSION,
+ },
+ AuditNumber: linux.AUDIT_ARCH_AARCH64,
+ Table: map[uintptr]kernel.Syscall{
+ 0: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 1: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 2: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 3: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 4: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 5: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 6: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 7: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 8: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 9: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 10: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 11: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 12: syscalls.ErrorWithEvent("llistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 13: syscalls.ErrorWithEvent("flistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 14: syscalls.ErrorWithEvent("removexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 15: syscalls.ErrorWithEvent("lremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 16: syscalls.ErrorWithEvent("fremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 17: syscalls.Supported("getcwd", Getcwd),
+ 18: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil),
+ 19: syscalls.Supported("eventfd2", Eventfd2),
+ 20: syscalls.Supported("epoll_create1", EpollCreate1),
+ 21: syscalls.Supported("epoll_ctl", EpollCtl),
+ 22: syscalls.Supported("epoll_pwait", EpollPwait),
+ 23: syscalls.Supported("dup", Dup),
+ 24: syscalls.Supported("dup3", Dup3),
+ 26: syscalls.Supported("inotify_init1", InotifyInit1),
+ 27: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil),
+ 28: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil),
+ 29: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil),
+ 30: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
+ 31: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
+ 32: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil),
+ 33: syscalls.Supported("mknodat", Mknodat),
+ 34: syscalls.Supported("mkdirat", Mkdirat),
+ 35: syscalls.Supported("unlinkat", Unlinkat),
+ 36: syscalls.Supported("symlinkat", Symlinkat),
+ 37: syscalls.Supported("linkat", Linkat),
+ 38: syscalls.Supported("renameat", Renameat),
+ 39: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil),
+ 40: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil),
+ 41: syscalls.Error("pivot_root", syserror.EPERM, "", nil),
+ 42: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil),
+ 44: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil),
+ 46: syscalls.Supported("ftruncate", Ftruncate),
+ 47: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil),
+ 48: syscalls.Supported("faccessat", Faccessat),
+ 49: syscalls.Supported("chdir", Chdir),
+ 50: syscalls.Supported("fchdir", Fchdir),
+ 51: syscalls.Supported("chroot", Chroot),
+ 52: syscalls.PartiallySupported("fchmod", Fchmod, "Options S_ISUID and S_ISGID not supported.", nil),
+ 53: syscalls.Supported("fchmodat", Fchmodat),
+ 54: syscalls.Supported("fchownat", Fchownat),
+ 55: syscalls.Supported("fchown", Fchown),
+ 56: syscalls.Supported("openat", Openat),
+ 57: syscalls.Supported("close", Close),
+ 58: syscalls.CapError("vhangup", linux.CAP_SYS_TTY_CONFIG, "", nil),
+ 59: syscalls.Supported("pipe2", Pipe2),
+ 60: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations
+ 61: syscalls.Supported("getdents64", Getdents64),
+ 62: syscalls.Supported("lseek", Lseek),
+ 63: syscalls.Supported("read", Read),
+ 64: syscalls.Supported("write", Write),
+ 65: syscalls.Supported("readv", Readv),
+ 66: syscalls.Supported("writev", Writev),
+ 67: syscalls.Supported("pread64", Pread64),
+ 68: syscalls.Supported("pwrite64", Pwrite64),
+ 69: syscalls.Supported("preadv", Preadv),
+ 70: syscalls.Supported("pwritev", Pwritev),
+ 71: syscalls.Supported("sendfile", Sendfile),
+ 72: syscalls.Supported("pselect", Pselect),
+ 73: syscalls.Supported("ppoll", Ppoll),
+ 74: syscalls.ErrorWithEvent("signalfd4", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426) NOTE(review): a Signalfd4 handler exists in sys_signal.go but is not wired up here - confirm.
+ 75: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 76: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 77: syscalls.ErrorWithEvent("tee", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 78: syscalls.Supported("readlinkat", Readlinkat),
+ 80: syscalls.Supported("fstat", Fstat),
+ 81: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil),
+ 82: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil),
+ 83: syscalls.PartiallySupported("fdatasync", Fdatasync, "Full data flush is not guaranteed at this time.", nil),
+ 84: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil),
+ 85: syscalls.Supported("timerfd_create", TimerfdCreate),
+ 86: syscalls.Supported("timerfd_settime", TimerfdSettime),
+ 87: syscalls.Supported("timerfd_gettime", TimerfdGettime),
+ 88: syscalls.Supported("utimensat", Utimensat),
+ 89: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil),
+ 90: syscalls.Supported("capget", Capget),
+ 91: syscalls.Supported("capset", Capset),
+ 92: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil),
+ 93: syscalls.Supported("exit", Exit),
+ 94: syscalls.Supported("exit_group", ExitGroup),
+ 95: syscalls.Supported("waitid", Waitid),
+ 96: syscalls.Supported("set_tid_address", SetTidAddress),
+ 97: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil),
+ 98: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil),
+ 99: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil),
+ 100: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil),
+ 101: syscalls.Supported("nanosleep", Nanosleep),
+ 102: syscalls.Supported("getitimer", Getitimer),
+ 103: syscalls.Supported("setitimer", Setitimer),
+ 104: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil),
+ 105: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil),
+ 106: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil),
+ 107: syscalls.Supported("timer_create", TimerCreate),
+ 108: syscalls.Supported("timer_gettime", TimerGettime),
+ 109: syscalls.Supported("timer_getoverrun", TimerGetoverrun),
+ 110: syscalls.Supported("timer_settime", TimerSettime),
+ 111: syscalls.Supported("timer_delete", TimerDelete),
+ 112: syscalls.Supported("clock_settime", ClockSettime),
+ 113: syscalls.Supported("clock_gettime", ClockGettime),
+ 114: syscalls.Supported("clock_getres", ClockGetres),
+ 115: syscalls.Supported("clock_nanosleep", ClockNanosleep),
+ 116: syscalls.PartiallySupported("syslog", Syslog, "Outputs a dummy message for security reasons.", nil),
+ 117: syscalls.PartiallySupported("ptrace", Ptrace, "Options PTRACE_PEEKSIGINFO, PTRACE_SECCOMP_GET_FILTER not supported.", nil),
+ 118: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil),
+ 119: syscalls.PartiallySupported("sched_setscheduler", SchedSetscheduler, "Stub implementation.", nil),
+ 120: syscalls.PartiallySupported("sched_getscheduler", SchedGetscheduler, "Stub implementation.", nil),
+ 121: syscalls.PartiallySupported("sched_getparam", SchedGetparam, "Stub implementation.", nil),
+ 122: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil),
+ 123: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil),
+ 124: syscalls.Supported("sched_yield", SchedYield),
+ 125: syscalls.PartiallySupported("sched_get_priority_max", SchedGetPriorityMax, "Stub implementation.", nil),
+ 126: syscalls.PartiallySupported("sched_get_priority_min", SchedGetPriorityMin, "Stub implementation.", nil),
+ 127: syscalls.ErrorWithEvent("sched_rr_get_interval", syserror.EPERM, "", nil),
+ 128: syscalls.Supported("restart_syscall", RestartSyscall),
+ 129: syscalls.Supported("kill", Kill),
+ 130: syscalls.Supported("tkill", Tkill),
+ 131: syscalls.Supported("tgkill", Tgkill),
+ 132: syscalls.Supported("sigaltstack", Sigaltstack),
+ 133: syscalls.Supported("rt_sigsuspend", RtSigsuspend),
+ 134: syscalls.Supported("rt_sigaction", RtSigaction),
+ 135: syscalls.Supported("rt_sigprocmask", RtSigprocmask),
+ 136: syscalls.Supported("rt_sigpending", RtSigpending),
+ 137: syscalls.Supported("rt_sigtimedwait", RtSigtimedwait),
+ 138: syscalls.Supported("rt_sigqueueinfo", RtSigqueueinfo),
+ 139: syscalls.Supported("rt_sigreturn", RtSigreturn),
+ 140: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil),
+ 141: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil),
+ 142: syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil),
+ 143: syscalls.Supported("setregid", Setregid),
+ 144: syscalls.Supported("setgid", Setgid),
+ 145: syscalls.Supported("setreuid", Setreuid),
+ 146: syscalls.Supported("setuid", Setuid),
+ 147: syscalls.Supported("setresuid", Setresuid),
+ 148: syscalls.Supported("getresuid", Getresuid),
+ 149: syscalls.Supported("setresgid", Setresgid),
+ 150: syscalls.Supported("getresgid", Getresgid),
+ 151: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
+ 152: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
+ 153: syscalls.Supported("times", Times),
+ 154: syscalls.Supported("setpgid", Setpgid),
+ 155: syscalls.Supported("getpgid", Getpgid),
+ 156: syscalls.Supported("getsid", Getsid),
+ 157: syscalls.Supported("setsid", Setsid),
+ 158: syscalls.Supported("getgroups", Getgroups),
+ 159: syscalls.Supported("setgroups", Setgroups),
+ 160: syscalls.Supported("uname", Uname),
+ 161: syscalls.Supported("sethostname", Sethostname),
+ 162: syscalls.Supported("setdomainname", Setdomainname),
+ 163: syscalls.Supported("getrlimit", Getrlimit),
+ 164: syscalls.PartiallySupported("setrlimit", Setrlimit, "Not all rlimits are enforced.", nil),
+ 165: syscalls.PartiallySupported("getrusage", Getrusage, "Fields ru_maxrss, ru_minflt, ru_majflt, ru_inblock, ru_oublock are not supported. Fields ru_utime and ru_stime have low precision.", nil),
+ 166: syscalls.Supported("umask", Umask),
+ 167: syscalls.PartiallySupported("prctl", Prctl, "Not all options are supported.", nil),
+ 168: syscalls.Supported("getcpu", Getcpu),
+ 169: syscalls.Supported("gettimeofday", Gettimeofday),
+ 170: syscalls.CapError("settimeofday", linux.CAP_SYS_TIME, "", nil),
+ 171: syscalls.CapError("adjtimex", linux.CAP_SYS_TIME, "", nil),
+ 172: syscalls.Supported("getpid", Getpid),
+ 173: syscalls.Supported("getppid", Getppid),
+ 174: syscalls.Supported("getuid", Getuid),
+ 175: syscalls.Supported("geteuid", Geteuid),
+ 176: syscalls.Supported("getgid", Getgid),
+ 177: syscalls.Supported("getegid", Getegid),
+ 178: syscalls.Supported("gettid", Gettid),
+ 179: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil),
+ 180: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 181: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 182: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 183: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 184: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 185: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 186: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 187: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 190: syscalls.Supported("semget", Semget),
+ 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
+ 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920)
+ 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
+ 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil),
+ 195: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil),
+ 196: syscalls.PartiallySupported("shmat", Shmat, "Option SHM_RND is not supported.", nil),
+ 197: syscalls.Supported("shmdt", Shmdt),
+ 198: syscalls.PartiallySupported("socket", Socket, "Limited support for AF_NETLINK, NETLINK_ROUTE sockets. Limited support for SOCK_RAW.", nil),
+ 199: syscalls.Supported("socketpair", SocketPair),
+ 200: syscalls.PartiallySupported("bind", Bind, "Autobind for abstract Unix sockets is not supported.", nil),
+ 201: syscalls.Supported("listen", Listen),
+ 202: syscalls.Supported("accept", Accept),
+ 203: syscalls.Supported("connect", Connect),
+ 204: syscalls.Supported("getsockname", GetSockName),
+ 205: syscalls.Supported("getpeername", GetPeerName),
+ 206: syscalls.Supported("sendto", SendTo),
+ 207: syscalls.Supported("recvfrom", RecvFrom),
+ 208: syscalls.PartiallySupported("setsockopt", SetSockOpt, "Not all socket options are supported.", nil),
+ 209: syscalls.PartiallySupported("getsockopt", GetSockOpt, "Not all socket options are supported.", nil),
+ 210: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil),
+ 211: syscalls.Supported("sendmsg", SendMsg),
+ 212: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil),
+ 213: syscalls.ErrorWithEvent("readahead", syserror.ENOSYS, "", []string{"gvisor.dev/issue/261"}), // TODO(b/29351341) NOTE(review): a Readahead handler exists in sys_read.go but is not wired up here - confirm.
+ 214: syscalls.Supported("brk", Brk),
+ 215: syscalls.Supported("munmap", Munmap),
+ 216: syscalls.Supported("mremap", Mremap),
+ 217: syscalls.Error("add_key", syserror.EACCES, "Not available to user.", nil),
+ 218: syscalls.Error("request_key", syserror.EACCES, "Not available to user.", nil),
+ 219: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil),
+ 220: syscalls.PartiallySupported("clone", Clone, "Mount namespace (CLONE_NEWNS) not supported. Options CLONE_PARENT, CLONE_SYSVSEM not supported.", nil),
+ 221: syscalls.Supported("execve", Execve),
+ 224: syscalls.CapError("swapon", linux.CAP_SYS_ADMIN, "", nil),
+ 225: syscalls.CapError("swapoff", linux.CAP_SYS_ADMIN, "", nil),
+ 226: syscalls.Supported("mprotect", Mprotect),
+ 227: syscalls.PartiallySupported("msync", Msync, "Full data flush is not guaranteed at this time.", nil),
+ 228: syscalls.PartiallySupported("mlock", Mlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 229: syscalls.PartiallySupported("munlock", Munlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 230: syscalls.PartiallySupported("mlockall", Mlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 231: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 232: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil),
+ 233: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil),
+ 234: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil),
+ 235: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}),
+ 236: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil),
+ 237: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil),
+ 238: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil),
+ 239: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly)
+ 240: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo),
+ 241: syscalls.ErrorWithEvent("perf_event_open", syserror.ENODEV, "No support for perf counters", nil),
+ 242: syscalls.Supported("accept4", Accept4),
+ 243: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil),
+ 260: syscalls.Supported("wait4", Wait4),
+ 261: syscalls.Supported("prlimit64", Prlimit64),
+ 262: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
+ 263: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
+ 264: syscalls.Error("name_to_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
+ 265: syscalls.Error("open_by_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
+ 266: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil),
+ 267: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil),
+ 268: syscalls.ErrorWithEvent("setns", syserror.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995)
+ 269: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil),
+ 270: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
+ 271: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
+ 272: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil),
+ 273: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil),
+ 274: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
+ 275: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
+ 276: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772)
+ 277: syscalls.Supported("seccomp", Seccomp),
+ 278: syscalls.Supported("getrandom", GetRandom),
+ 279: syscalls.Supported("memfd_create", MemfdCreate),
+ 280: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
+ 281: syscalls.ErrorWithEvent("execveat", syserror.ENOSYS, "", []string{"gvisor.dev/issue/265"}), // TODO(b/118901836)
+ 282: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
+ 283: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897)
+ 284: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 285: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil),
+ 286: syscalls.Supported("preadv2", Preadv2),
+ 287: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil),
+ 291: syscalls.Supported("statx", Statx),
+ },
+ Emulate: map[usermem.Addr]uintptr{}, // empty: no addresses are mapped to syscall numbers in this table
+
+ Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, syserror.ENOSYS
+ },
+}
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 2e00a91ce..b9a8e3e21 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -1423,9 +1423,6 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
if err != nil {
return err
}
- if dirPath {
- return syserror.ENOENT
- }
return fileOpAt(t, dirFD, path, func(root *fs.Dirent, d *fs.Dirent, name string, _ uint) error {
if !fs.IsDir(d.Inode.StableAttr) {
@@ -1436,7 +1433,7 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
return err
}
- return d.Remove(t, root, name)
+ return d.Remove(t, root, name, dirPath)
})
}
diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go
index 3ab54271c..cd31e0649 100644
--- a/pkg/sentry/syscalls/linux/sys_read.go
+++ b/pkg/sentry/syscalls/linux/sys_read.go
@@ -72,6 +72,39 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "read", file)
}
+// Readahead implements readahead(2). The arguments are validated (EBADF for an unknown or unreadable fd, EINVAL for a bad size or offset), but no data is prefetched: a call that passes validation still returns EINVAL.
+func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ size := args[2].SizeT()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the file is readable; a descriptor without the Read flag is reported as EBADF.
+ if !file.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Check that the size is valid: reject any size that becomes negative when converted to int.
+ if int(size) < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the offset is legitimate: negative offsets are rejected.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Return EINVAL; if the underlying file type does not support readahead,
+ // then Linux will return EINVAL to indicate as much. In the future, we
+ // may extend this function to actually support readahead hints.
+ return 0, nil, syserror.EINVAL
+}
+
// Pread64 implements linux syscall pread64(2).
func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go
index 0104a94c0..fb6efd5d8 100644
--- a/pkg/sentry/syscalls/linux/sys_signal.go
+++ b/pkg/sentry/syscalls/linux/sys_signal.go
@@ -20,7 +20,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -506,3 +509,77 @@ func RestartSyscall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
t.Debugf("Restart block missing in restart_syscall(2). Did ptrace inject a return value of ERESTART_RESTARTBLOCK?")
return 0, nil, syserror.EINTR
}
+
+// sharedSignalfd is the common implementation behind Signalfd and Signalfd4: with fd == -1 it creates a new signalfd for the given mask and returns its descriptor; otherwise it replaces the mask of the existing signalfd at fd and returns 0.
+func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) {
+ // Copy in the signal mask.
+ mask, err := copyInSigSet(t, sigset, sigsetsize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Always check for valid flags, even if not creating. Only SFD_NONBLOCK and SFD_CLOEXEC are accepted; any other bit yields EINVAL.
+ if flags&^(linux.SFD_NONBLOCK|linux.SFD_CLOEXEC) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Is this a change to an existing signalfd?
+ //
+ // The spec indicates that this should adjust the mask.
+ if fd != -1 {
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Is this a signalfd? If so, only the mask is updated; flags are validated above but not applied to the existing file.
+ if s, ok := file.FileOperations.(*signalfd.SignalOperations); ok {
+ s.SetMask(mask)
+ return 0, nil, nil
+ }
+
+ // Not a signalfd: fd refers to some other kind of file.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Create a new file.
+ file, err := signalfd.New(t, mask)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef() // runs on success too; assumes NewFDFrom keeps its own reference - TODO confirm
+
+ // Set appropriate flags: SFD_NONBLOCK becomes the file's NonBlocking flag.
+ file.SetFlags(fs.SettableFileFlags{
+ NonBlocking: flags&linux.SFD_NONBLOCK != 0,
+ })
+
+ // Create a new descriptor; SFD_CLOEXEC becomes the descriptor's CloseOnExec flag.
+ fd, err = t.NewFDFrom(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.SFD_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Done: return the new descriptor number.
+ return uintptr(fd), nil, nil
+}
+
+// Signalfd implements the linux syscall signalfd(2). It reads fd, the sigset pointer and the sigset size from the first three arguments and calls sharedSignalfd with flags set to 0.
+func Signalfd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ sigset := args[1].Pointer()
+ sigsetsize := args[2].SizeT()
+ return sharedSignalfd(t, fd, sigset, sigsetsize, 0)
+}
+
+// Signalfd4 implements the linux syscall signalfd4(2). It takes the same first three arguments as Signalfd plus a fourth flags argument, which sharedSignalfd validates against SFD_NONBLOCK and SFD_CLOEXEC.
+func Signalfd4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ sigset := args[1].Pointer()
+ sigsetsize := args[2].SizeT()
+ flags := args[3].Int()
+ return sharedSignalfd(t, fd, sigset, sigsetsize, flags)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 3bac4d90d..b5a72ce63 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -531,7 +531,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, syserror.ENOTSOCK
}
- if optLen <= 0 {
+ if optLen < 0 {
return 0, nil, syserror.EINVAL
}
if optLen > maxOptLen {
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index 17e3dde1f..dd3a5807f 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -29,9 +29,8 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB
total int64
n int64
err error
- ch chan struct{}
- inW bool
- outW bool
+ inCh chan struct{}
+ outCh chan struct{}
)
for opts.Length > 0 {
n, err = fs.Splice(t, outFile, inFile, opts)
@@ -43,35 +42,33 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB
break
}
- // Are we a registered waiter?
- if ch == nil {
- ch = make(chan struct{}, 1)
- }
- if !inW && !inFile.Flags().NonBlocking {
- w, _ := waiter.NewChannelEntry(ch)
- inFile.EventRegister(&w, EventMaskRead)
- defer inFile.EventUnregister(&w)
- inW = true // Registered.
- } else if !outW && !outFile.Flags().NonBlocking {
- w, _ := waiter.NewChannelEntry(ch)
- outFile.EventRegister(&w, EventMaskWrite)
- defer outFile.EventUnregister(&w)
- outW = true // Registered.
- }
-
- // Was anything registered? If no, everything is non-blocking.
- if !inW && !outW {
- break
- }
-
- if (!inW || inFile.Readiness(EventMaskRead) != 0) && (!outW || outFile.Readiness(EventMaskWrite) != 0) {
- // Something became ready, try again without blocking.
- continue
+ // Note that the blocking behavior here is a bit different than the
+ // normal pattern. Because we need to have both data to read and data
+ // to write simultaneously, we actually explicitly block on both of
+ // these cases in turn before returning to the splice operation.
+ if inFile.Readiness(EventMaskRead) == 0 {
+ if inCh == nil {
+ inCh = make(chan struct{}, 1)
+ inW, _ := waiter.NewChannelEntry(inCh)
+ inFile.EventRegister(&inW, EventMaskRead)
+ defer inFile.EventUnregister(&inW)
+ continue // Need to refresh readiness.
+ }
+ if err = t.Block(inCh); err != nil {
+ break
+ }
}
-
- // Block until there's data.
- if err = t.Block(ch); err != nil {
- break
+ if outFile.Readiness(EventMaskWrite) == 0 {
+ if outCh == nil {
+ outCh = make(chan struct{}, 1)
+ outW, _ := waiter.NewChannelEntry(outCh)
+ outFile.EventRegister(&outW, EventMaskWrite)
+ defer outFile.EventUnregister(&outW)
+ continue // Need to refresh readiness.
+ }
+ if err = t.Block(outCh); err != nil {
+ break
+ }
}
}
@@ -91,22 +88,29 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
// Get files.
+ inFile := t.GetFile(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+
+ if !inFile.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
outFile := t.GetFile(outFD)
if outFile == nil {
return 0, nil, syserror.EBADF
}
defer outFile.DecRef()
- inFile := t.GetFile(inFD)
- if inFile == nil {
+ if !outFile.Flags().Write {
return 0, nil, syserror.EBADF
}
- defer inFile.DecRef()
- // Verify that the outfile Append flag is not set. Note that fs.Splice
- // itself validates that the output file is writable.
+ // Verify that the outfile Append flag is not set.
if outFile.Flags().Append {
- return 0, nil, syserror.EBADF
+ return 0, nil, syserror.EINVAL
}
// Verify that we have a regular infile. This is a requirement; the
@@ -142,7 +146,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
Length: count,
SrcOffset: true,
SrcStart: offset,
- }, false)
+ }, outFile.Flags().NonBlocking)
// Copy out the new offset.
if _, err := t.CopyOut(offsetAddr, n+offset); err != nil {
@@ -152,12 +156,17 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Send data using splice.
n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{
Length: count,
- }, false)
+ }, outFile.Flags().NonBlocking)
+ }
+
+ // Sendfile can't lose any data because inFD is always a regular file.
+ if n != 0 {
+ err = nil
}
// We can only pass a single file to handleIOError, so pick inFile
// arbitrarily. This is used only for debugging purposes.
- return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "sendfile", inFile)
+ return uintptr(n), nil, handleIOError(t, false, err, kernel.ERESTARTSYS, "sendfile", inFile)
}
// Splice implements splice(2).
@@ -174,12 +183,6 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.EINVAL
}
- // Only non-blocking is meaningful. Note that unlike in Linux, this
- // flag is applied consistently. We will have either fully blocking or
- // non-blocking behavior below, regardless of the underlying files
- // being spliced to. It's unclear if this is a bug or not yet.
- nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0
-
// Get files.
outFile := t.GetFile(outFD)
if outFile == nil {
@@ -193,6 +196,13 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
defer inFile.DecRef()
+ // The operation is non-blocking if anything is non-blocking.
+ //
+ // N.B. This is a rather simplistic heuristic that avoids some
+ // poor edge case behavior since the exact semantics here are
+ // underspecified and vary between versions of Linux itself.
+ nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0)
+
// Construct our options.
//
// Note that exactly one of the underlying buffers must be a pipe. We
@@ -240,17 +250,17 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if inOffset != 0 || outOffset != 0 {
return 0, nil, syserror.ESPIPE
}
- default:
- return 0, nil, syserror.EINVAL
- }
- // We may not refer to the same pipe; otherwise it's a continuous loop.
- if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID {
+ // We may not refer to the same pipe; otherwise it's a continuous loop.
+ if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID {
+ return 0, nil, syserror.EINVAL
+ }
+ default:
return 0, nil, syserror.EINVAL
}
// Splice data.
- n, err := doSplice(t, outFile, inFile, opts, nonBlocking)
+ n, err := doSplice(t, outFile, inFile, opts, nonBlock)
// See above; inFile is chosen arbitrarily here.
return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "splice", inFile)
@@ -268,9 +278,6 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
return 0, nil, syserror.EINVAL
}
- // Only non-blocking is meaningful.
- nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0
-
// Get files.
outFile := t.GetFile(outFD)
if outFile == nil {
@@ -294,12 +301,20 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
return 0, nil, syserror.EINVAL
}
+ // The operation is non-blocking if anything is non-blocking.
+ nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0)
+
// Splice data.
n, err := doSplice(t, outFile, inFile, fs.SpliceOpts{
Length: count,
Dup: true,
- }, nonBlocking)
+ }, nonBlock)
+
+ // Tee doesn't change the state of inFD, so it can't lose any data.
+ if n != 0 {
+ err = nil
+ }
// See above; inFile is chosen arbitrarily here.
- return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "tee", inFile)
+ return uintptr(n), nil, handleIOError(t, false, err, kernel.ERESTARTSYS, "tee", inFile)
}
diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go
index 4b3f043a2..b887fa9d7 100644
--- a/pkg/sentry/syscalls/linux/sys_time.go
+++ b/pkg/sentry/syscalls/linux/sys_time.go
@@ -15,6 +15,7 @@
package linux
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -228,41 +229,35 @@ func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem use
timer.Destroy()
- var remaining time.Duration
- // Did we just block for the entire duration?
- if err == syserror.ETIMEDOUT {
- remaining = 0
- } else {
- remaining = dur - after.Sub(start)
+ switch err {
+ case syserror.ETIMEDOUT:
+ // Slept for entire timeout.
+ return nil
+ case syserror.ErrInterrupted:
+ // Interrupted.
+ remaining := dur - after.Sub(start)
if remaining < 0 {
remaining = time.Duration(0)
}
- }
- // Copy out remaining time.
- if err != nil && rem != usermem.Addr(0) {
- timeleft := linux.NsecToTimespec(remaining.Nanoseconds())
- if err := copyTimespecOut(t, rem, &timeleft); err != nil {
- return err
+ // Copy out remaining time.
+ if rem != 0 {
+ timeleft := linux.NsecToTimespec(remaining.Nanoseconds())
+ if err := copyTimespecOut(t, rem, &timeleft); err != nil {
+ return err
+ }
}
- }
-
- // Did we just block for the entire duration?
- if err == syserror.ETIMEDOUT {
- return nil
- }
- // If interrupted, arrange for a restart with the remaining duration.
- if err == syserror.ErrInterrupted {
+ // Arrange for a restart with the remaining duration.
t.SetSyscallRestartBlock(&clockNanosleepRestartBlock{
c: c,
duration: remaining,
rem: rem,
})
return kernel.ERESTART_RESTARTBLOCK
+ default:
+ panic(fmt.Sprintf("Impossible BlockWithTimer error %v", err))
}
-
- return err
}
// Nanosleep implements linux syscall Nanosleep(2).
diff --git a/pkg/sentry/syscalls/linux/sys_utsname.go b/pkg/sentry/syscalls/linux/sys_utsname.go
index 271ace08e..748e8dd8d 100644
--- a/pkg/sentry/syscalls/linux/sys_utsname.go
+++ b/pkg/sentry/syscalls/linux/sys_utsname.go
@@ -79,11 +79,11 @@ func Sethostname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
return 0, nil, syserror.EINVAL
}
- name, err := t.CopyInString(nameAddr, int(size))
- if err != nil {
+ name := make([]byte, size)
+ if _, err := t.CopyInBytes(nameAddr, name); err != nil {
return 0, nil, err
}
- utsns.SetHostName(name)
+ utsns.SetHostName(string(name))
return 0, nil, nil
}
diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD
index 8aa6a3017..beb43ba13 100644
--- a/pkg/sentry/time/BUILD
+++ b/pkg/sentry/time/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/unimpl/BUILD b/pkg/sentry/unimpl/BUILD
index b69603da3..fc7614fff 100644
--- a/pkg/sentry/unimpl/BUILD
+++ b/pkg/sentry/unimpl/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -10,6 +11,12 @@ proto_library(
deps = ["//pkg/sentry/arch:registers_proto"],
)
+cc_proto_library(
+ name = "unimplemented_syscall_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":unimplemented_syscall_proto"],
+)
+
go_proto_library(
name = "unimplemented_syscall_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto",
diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go
index f4326706a..d6ef644d8 100644
--- a/pkg/sentry/usage/memory.go
+++ b/pkg/sentry/usage/memory.go
@@ -277,8 +277,3 @@ func TotalMemory(memSize, used uint64) uint64 {
}
return memSize
}
-
-// IncrementalMappedAccounting controls whether host mapped memory is accounted
-// incrementally during map translation. This may be modified during early
-// initialization, and is read-only afterward.
-var IncrementalMappedAccounting = false
diff --git a/pkg/sentry/usermem/BUILD b/pkg/sentry/usermem/BUILD
index a5b4206bb..cc5d25762 100644
--- a/pkg/sentry/usermem/BUILD
+++ b/pkg/sentry/usermem/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "addr_range",
diff --git a/pkg/sentry/usermem/usermem.go b/pkg/sentry/usermem/usermem.go
index 6eced660a..7b1f312b1 100644
--- a/pkg/sentry/usermem/usermem.go
+++ b/pkg/sentry/usermem/usermem.go
@@ -16,6 +16,7 @@
package usermem
import (
+ "bytes"
"errors"
"io"
"strconv"
@@ -270,11 +271,10 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
n, err := uio.CopyIn(ctx, addr, buf[done:done+readlen], opts)
// Look for the terminating zero byte, which may have occurred before
// hitting err.
- for i, c := range buf[done : done+n] {
- if c == 0 {
- return stringFromImmutableBytes(buf[:done+i]), nil
- }
+ if i := bytes.IndexByte(buf[done:done+n], byte(0)); i >= 0 {
+ return stringFromImmutableBytes(buf[:done+i]), nil
}
+
done += n
if err != nil {
return stringFromImmutableBytes(buf[:done]), err
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 0f247bf77..eff4b44f6 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 86bde7fb3..3a9665800 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -102,7 +102,7 @@ type FileDescriptionImpl interface {
// OnClose is called when a file descriptor representing the
// FileDescription is closed. Note that returning a non-nil error does not
// prevent the file descriptor from being closed.
- OnClose() error
+ OnClose(ctx context.Context) error
// StatusFlags returns file description status flags, as for
// fcntl(F_GETFL).
@@ -180,7 +180,7 @@ type FileDescriptionImpl interface {
// ConfigureMMap mutates opts to implement mmap(2) for the file. Most
// implementations that support memory mapping can call
// GenericConfigureMMap with the appropriate memmap.Mappable.
- ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error
+ ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error
// Ioctl implements the ioctl(2) syscall.
Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error)
@@ -199,8 +199,11 @@ type Dirent struct {
// Ino is the inode number.
Ino uint64
- // Off is this Dirent's offset.
- Off int64
+ // NextOff is the offset of the *next* Dirent in the directory; that is,
+ // FileDescription.Seek(NextOff, SEEK_SET) (as called by seekdir(3)) will
+ // cause the next call to FileDescription.IterDirents() to yield the next
+ // Dirent. (The offset of the first Dirent in a directory is always 0.)
+ NextOff int64
}
// IterDirentsCallback receives Dirents from FileDescriptionImpl.IterDirents.
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index ba230da72..4fbad7840 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -45,7 +45,7 @@ type FileDescriptionDefaultImpl struct{}
// OnClose implements FileDescriptionImpl.OnClose analogously to
// file_operations::flush == NULL in Linux.
-func (FileDescriptionDefaultImpl) OnClose() error {
+func (FileDescriptionDefaultImpl) OnClose(ctx context.Context) error {
return nil
}
@@ -117,7 +117,7 @@ func (FileDescriptionDefaultImpl) Sync(ctx context.Context) error {
// ConfigureMMap implements FileDescriptionImpl.ConfigureMMap analogously to
// file_operations::mmap == NULL in Linux.
-func (FileDescriptionDefaultImpl) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error {
+func (FileDescriptionDefaultImpl) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
return syserror.ENODEV
}
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index 187e5410c..3aa73d911 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -31,14 +31,14 @@ type GetDentryOptions struct {
// FilesystemImpl.MkdirAt().
type MkdirOptions struct {
// Mode is the file mode bits for the created directory.
- Mode uint16
+ Mode linux.FileMode
}
// MknodOptions contains options to VirtualFilesystem.MknodAt() and
// FilesystemImpl.MknodAt().
type MknodOptions struct {
// Mode is the file type and mode bits for the created file.
- Mode uint16
+ Mode linux.FileMode
// If Mode specifies a character or block device special file, DevMajor and
// DevMinor are the major and minor device numbers for the created device.
@@ -61,7 +61,7 @@ type OpenOptions struct {
// If FilesystemImpl.OpenAt() creates a file, Mode is the file mode for the
// created file.
- Mode uint16
+ Mode linux.FileMode
}
// ReadOptions contains options to FileDescription.PRead(),
diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD
index 00665c939..a23c86fb1 100644
--- a/pkg/sleep/BUILD
+++ b/pkg/sleep/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -6,6 +7,7 @@ go_library(
name = "sleep",
srcs = [
"commit_amd64.s",
+ "commit_arm64.s",
"commit_asm.go",
"commit_noasm.go",
"sleep_unsafe.go",
diff --git a/pkg/sleep/commit_arm64.s b/pkg/sleep/commit_arm64.s
new file mode 100644
index 000000000..d0ef15b20
--- /dev/null
+++ b/pkg/sleep/commit_arm64.s
@@ -0,0 +1,38 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+#define preparingG 1
+
+// See commit_noasm.go for a description of commitSleep.
+//
+// func commitSleep(g uintptr, waitingG *uintptr) bool
+TEXT ·commitSleep(SB),NOSPLIT,$0-24
+ MOVD waitingG+8(FP), R0
+ MOVD $preparingG, R1
+ MOVD G+0(FP), R2
+
+ // Store the G in waitingG if it's still preparingG. If it's anything
+ // else it means a waker has aborted the sleep.
+again:
+ LDAXR (R0), R3
+ CMP R1, R3
+ BNE ok
+ STLXR R2, (R0), R3
+ CBNZ R3, again
+ok:
+ CSET EQ, R0
+ MOVB R0, ret+16(FP)
+ RET
diff --git a/pkg/sleep/commit_asm.go b/pkg/sleep/commit_asm.go
index 35e2cc337..75728a97d 100644
--- a/pkg/sleep/commit_asm.go
+++ b/pkg/sleep/commit_asm.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64
+// +build amd64 arm64
package sleep
diff --git a/pkg/sleep/commit_noasm.go b/pkg/sleep/commit_noasm.go
index 686b1da3d..3af447fb9 100644
--- a/pkg/sleep/commit_noasm.go
+++ b/pkg/sleep/commit_noasm.go
@@ -13,7 +13,7 @@
// limitations under the License.
// +build !race
-// +build !amd64
+// +build !amd64,!arm64
package sleep
diff --git a/pkg/state/BUILD b/pkg/state/BUILD
index c0f3c658d..329904457 100644
--- a/pkg/state/BUILD
+++ b/pkg/state/BUILD
@@ -1,5 +1,6 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD
index e70f4a79f..8a865d229 100644
--- a/pkg/state/statefile/BUILD
+++ b/pkg/state/statefile/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD
index b149f9e02..bd3f9fd28 100644
--- a/pkg/syserror/BUILD
+++ b/pkg/syserror/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index df37c7d5a..3c2b2b5ea 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -1,6 +1,7 @@
-package(licenses = ["notice"])
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+package(licenses = ["notice"])
go_library(
name = "tcpip",
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
index 0d2637ee4..78df5a0b1 100644
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 672f026b2..8ced960bb 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -60,7 +60,10 @@ func TestTimeouts(t *testing.T) {
func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
// Create the stack and add a NIC.
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()},
+ })
if err := s.CreateNIC(NICID, loopback.New()); err != nil {
return nil, err
diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD
index 3301967fb..d6c31bfa2 100644
--- a/pkg/tcpip/buffer/BUILD
+++ b/pkg/tcpip/buffer/BUILD
@@ -1,6 +1,7 @@
-package(licenses = ["notice"])
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+package(licenses = ["notice"])
go_library(
name = "buffer",
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index afcabd51d..096ad71ab 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -586,3 +586,103 @@ func Payload(want []byte) TransportChecker {
}
}
}
+
+// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and
+// potentially additional ICMPv4 header fields.
+func ICMPv4(checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ last := h[len(h)-1]
+
+ if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber {
+ t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber)
+ }
+
+ icmp := header.ICMPv4(last.Payload())
+ for _, f := range checkers {
+ f(t, icmp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// ICMPv4Type creates a checker that checks the ICMPv4 Type field.
+func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ }
+ if got := icmpv4.Type(); got != want {
+ t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Code creates a checker that checks the ICMPv4 Code field.
+func ICMPv4Code(want byte) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ }
+ if got := icmpv4.Code(); got != want {
+ t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ }
+ }
+}
+
+// ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and
+// potentially additional ICMPv6 header fields.
+func ICMPv6(checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ last := h[len(h)-1]
+
+ if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber {
+ t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber)
+ }
+
+ icmp := header.ICMPv6(last.Payload())
+ for _, f := range checkers {
+ f(t, icmp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// ICMPv6Type creates a checker that checks the ICMPv6 Type field.
+func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+ icmpv6, ok := h.(header.ICMPv6)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ }
+ if got := icmpv6.Type(); got != want {
+ t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ }
+ }
+}
+
+// ICMPv6Code creates a checker that checks the ICMPv6 Code field.
+func ICMPv6Code(want byte) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+ icmpv6, ok := h.(header.ICMPv6)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ }
+ if got := icmpv6.Code(); got != want {
+ t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD
index 29b30be9c..0c5c20cea 100644
--- a/pkg/tcpip/hash/jenkins/BUILD
+++ b/pkg/tcpip/hash/jenkins/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index 76ef02f13..a255231a3 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -1,6 +1,7 @@
-package(licenses = ["notice"])
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+package(licenses = ["notice"])
go_library(
name = "header",
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index c52c0d851..0cac6c0a5 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -18,6 +18,7 @@ import (
"encoding/binary"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
)
// ICMPv4 represents an ICMPv4 header stored in a byte array.
@@ -25,13 +26,29 @@ type ICMPv4 []byte
const (
// ICMPv4PayloadOffset defines the start of ICMP payload.
- ICMPv4PayloadOffset = 4
+ ICMPv4PayloadOffset = 8
// ICMPv4MinimumSize is the minimum size of a valid ICMP packet.
ICMPv4MinimumSize = 8
// ICMPv4ProtocolNumber is the ICMP transport protocol number.
ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1
+
+ // icmpv4ChecksumOffset is the offset of the checksum field
+ // in an ICMPv4 message.
+ icmpv4ChecksumOffset = 2
+
+ // icmpv4MTUOffset is the offset of the MTU field
+ // in an ICMPv4FragmentationNeeded message.
+ icmpv4MTUOffset = 6
+
+ // icmpv4IdentOffset is the offset of the ident field
+ // in an ICMPv4EchoRequest/Reply message.
+ icmpv4IdentOffset = 4
+
+ // icmpv4SequenceOffset is the offset of the sequence field
+ // in an ICMPv4EchoRequest/Reply message.
+ icmpv4SequenceOffset = 6
)
// ICMPv4Type is the ICMP type field described in RFC 792.
@@ -72,12 +89,12 @@ func (b ICMPv4) SetCode(c byte) { b[1] = c }
// Checksum is the ICMP checksum field.
func (b ICMPv4) Checksum() uint16 {
- return binary.BigEndian.Uint16(b[2:])
+ return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:])
}
// SetChecksum sets the ICMP checksum field.
func (b ICMPv4) SetChecksum(checksum uint16) {
- binary.BigEndian.PutUint16(b[2:], checksum)
+ binary.BigEndian.PutUint16(b[icmpv4ChecksumOffset:], checksum)
}
// SourcePort implements Transport.SourcePort.
@@ -102,3 +119,51 @@ func (ICMPv4) SetDestinationPort(uint16) {
func (b ICMPv4) Payload() []byte {
return b[ICMPv4PayloadOffset:]
}
+
+// MTU retrieves the MTU field from an ICMPv4 message.
+func (b ICMPv4) MTU() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv4MTUOffset:])
+}
+
+// SetMTU sets the MTU field of an ICMPv4 message.
+func (b ICMPv4) SetMTU(mtu uint16) {
+ binary.BigEndian.PutUint16(b[icmpv4MTUOffset:], mtu)
+}
+
+// Ident retrieves the Ident field from an ICMPv4 message.
+func (b ICMPv4) Ident() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv4IdentOffset:])
+}
+
+// SetIdent sets the Ident field of an ICMPv4 message.
+func (b ICMPv4) SetIdent(ident uint16) {
+ binary.BigEndian.PutUint16(b[icmpv4IdentOffset:], ident)
+}
+
+// Sequence retrieves the Sequence field from an ICMPv4 message.
+func (b ICMPv4) Sequence() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv4SequenceOffset:])
+}
+
+// SetSequence sets the Sequence field of an ICMPv4 message.
+func (b ICMPv4) SetSequence(sequence uint16) {
+ binary.BigEndian.PutUint16(b[icmpv4SequenceOffset:], sequence)
+}
+
+// ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header,
+// and payload.
+func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 {
+ // Calculate the checksum over the payload views.
+ xsum := uint16(0)
+ for _, v := range vv.Views() {
+ xsum = Checksum(v, xsum)
+ }
+
+ // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
+ h2, h3 := h[2], h[3]
+ h[2], h[3] = 0, 0
+ xsum = ^Checksum(h, xsum)
+ h[2], h[3] = h2, h3
+
+ return xsum
+}
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index 3cc57e234..1125a7d14 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -18,6 +18,7 @@ import (
"encoding/binary"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
)
// ICMPv6 represents an ICMPv6 header stored in a byte array.
@@ -25,14 +26,18 @@ type ICMPv6 []byte
const (
// ICMPv6MinimumSize is the minimum size of a valid ICMP packet.
- ICMPv6MinimumSize = 4
+ ICMPv6MinimumSize = 8
+
+ // ICMPv6PayloadOffset is the offset of the payload in an
+ // ICMP packet.
+ ICMPv6PayloadOffset = 8
// ICMPv6ProtocolNumber is the ICMP transport protocol number.
ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58
// ICMPv6NeighborSolicitMinimumSize is the minimum size of a
// neighbor solicitation packet.
- ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 4 + 16
+ ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 16
// ICMPv6NeighborAdvertSize is size of a neighbor advertisement.
ICMPv6NeighborAdvertSize = 32
@@ -42,11 +47,27 @@ const (
// ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP
// destination unreachable packet.
- ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize + 4
+ ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize
// ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP
// packet-too-big packet.
- ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + 4
+ ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize
+
+ // icmpv6ChecksumOffset is the offset of the checksum field
+ // in an ICMPv6 message.
+ icmpv6ChecksumOffset = 2
+
+ // icmpv6MTUOffset is the offset of the MTU field in an ICMPv6
+ // PacketTooBig message.
+ icmpv6MTUOffset = 4
+
+ // icmpv6IdentOffset is the offset of the ident field
+ // in an ICMPv6 Echo Request/Reply message.
+ icmpv6IdentOffset = 4
+
+ // icmpv6SequenceOffset is the offset of the sequence field
+ // in an ICMPv6 Echo Request/Reply message.
+ icmpv6SequenceOffset = 6
)
// ICMPv6Type is the ICMP type field described in RFC 4443 and friends.
@@ -89,12 +110,12 @@ func (b ICMPv6) SetCode(c byte) { b[1] = c }
// Checksum is the ICMP checksum field.
func (b ICMPv6) Checksum() uint16 {
- return binary.BigEndian.Uint16(b[2:])
+ return binary.BigEndian.Uint16(b[icmpv6ChecksumOffset:])
}
// SetChecksum calculates and sets the ICMP checksum field.
func (b ICMPv6) SetChecksum(checksum uint16) {
- binary.BigEndian.PutUint16(b[2:], checksum)
+ binary.BigEndian.PutUint16(b[icmpv6ChecksumOffset:], checksum)
}
// SourcePort implements Transport.SourcePort.
@@ -115,7 +136,60 @@ func (ICMPv6) SetSourcePort(uint16) {
func (ICMPv6) SetDestinationPort(uint16) {
}
+// MTU retrieves the MTU field from an ICMPv6 message.
+func (b ICMPv6) MTU() uint32 {
+ return binary.BigEndian.Uint32(b[icmpv6MTUOffset:])
+}
+
+// SetMTU sets the MTU field of an ICMPv6 message.
+func (b ICMPv6) SetMTU(mtu uint32) {
+ binary.BigEndian.PutUint32(b[icmpv6MTUOffset:], mtu)
+}
+
+// Ident retrieves the Ident field from an ICMPv6 message.
+func (b ICMPv6) Ident() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv6IdentOffset:])
+}
+
+// SetIdent sets the Ident field of an ICMPv6 message.
+func (b ICMPv6) SetIdent(ident uint16) {
+ binary.BigEndian.PutUint16(b[icmpv6IdentOffset:], ident)
+}
+
+// Sequence retrieves the Sequence field from an ICMPv6 message.
+func (b ICMPv6) Sequence() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv6SequenceOffset:])
+}
+
+// SetSequence sets the Sequence field of an ICMPv6 message.
+func (b ICMPv6) SetSequence(sequence uint16) {
+ binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence)
+}
+
// Payload implements Transport.Payload.
func (b ICMPv6) Payload() []byte {
- return b[ICMPv6MinimumSize:]
+ return b[ICMPv6PayloadOffset:]
+}
+
+// ICMPv6Checksum calculates the ICMP checksum over the provided ICMP header,
+// IPv6 src/dst addresses and the payload.
+func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 {
+ // Calculate the IPv6 pseudo-header upper-layer checksum.
+ xsum := Checksum([]byte(src), 0)
+ xsum = Checksum([]byte(dst), xsum)
+ var upperLayerLength [4]byte
+ binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size()))
+ xsum = Checksum(upperLayerLength[:], xsum)
+ xsum = Checksum([]byte{0, 0, 0, uint8(ICMPv6ProtocolNumber)}, xsum)
+ for _, v := range vv.Views() {
+ xsum = Checksum(v, xsum)
+ }
+
+ // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
+ h2, h3 := h[2], h[3]
+ h[2], h[3] = 0, 0
+ xsum = ^Checksum(h, xsum)
+ h[2], h[3] = h2, h3
+
+ return xsum
}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 17fc9c68e..e5360e7c1 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -21,16 +21,18 @@ import (
)
const (
- versIHL = 0
- tos = 1
- totalLen = 2
- id = 4
- flagsFO = 6
- ttl = 8
- protocol = 9
- checksum = 10
- srcAddr = 12
- dstAddr = 16
+ versIHL = 0
+ tos = 1
+ // IPv4TotalLenOffset is the offset of the total length field in the
+ // IPv4 header.
+ IPv4TotalLenOffset = 2
+ id = 4
+ flagsFO = 6
+ ttl = 8
+ protocol = 9
+ checksum = 10
+ srcAddr = 12
+ dstAddr = 16
)
// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
@@ -103,6 +105,11 @@ const (
// IPv4Any is the non-routable IPv4 "any" meta address.
IPv4Any tcpip.Address = "\x00\x00\x00\x00"
+
+ // IPv4MinimumProcessableDatagramSize is the minimum size of an IP
+ // packet that every IPv4 capable host must be able to
+ // process/reassemble.
+ IPv4MinimumProcessableDatagramSize = 576
)
// Flags that may be set in an IPv4 packet.
@@ -163,7 +170,7 @@ func (b IPv4) FragmentOffset() uint16 {
// TotalLength returns the "total length" field of the ipv4 header.
func (b IPv4) TotalLength() uint16 {
- return binary.BigEndian.Uint16(b[totalLen:])
+ return binary.BigEndian.Uint16(b[IPv4TotalLenOffset:])
}
// Checksum returns the checksum field of the ipv4 header.
@@ -209,7 +216,7 @@ func (b IPv4) SetTOS(v uint8, _ uint32) {
// SetTotalLength sets the "total length" field of the ipv4 header.
func (b IPv4) SetTotalLength(totalLength uint16) {
- binary.BigEndian.PutUint16(b[totalLen:], totalLength)
+ binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength)
}
// SetChecksum sets the checksum field of the ipv4 header.
@@ -265,7 +272,7 @@ func (b IPv4) Encode(i *IPv4Fields) {
// packets are produced.
func (b IPv4) EncodePartial(partialChecksum, totalLength uint16) {
b.SetTotalLength(totalLength)
- checksum := Checksum(b[totalLen:totalLen+2], partialChecksum)
+ checksum := Checksum(b[IPv4TotalLenOffset:IPv4TotalLenOffset+2], partialChecksum)
b.SetChecksum(^checksum)
}
@@ -277,7 +284,7 @@ func (b IPv4) IsValid(pktSize int) bool {
hlen := int(b.HeaderLength())
tlen := int(b.TotalLength())
- if hlen > tlen || tlen > pktSize {
+ if hlen < IPv4MinimumSize || hlen > tlen || tlen > pktSize {
return false
}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 31be42ce0..9d3abc0e4 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -22,12 +22,14 @@ import (
)
const (
- versTCFL = 0
- payloadLen = 4
- nextHdr = 6
- hopLimit = 7
- v6SrcAddr = 8
- v6DstAddr = 24
+ versTCFL = 0
+ // IPv6PayloadLenOffset is the offset of the PayloadLength field in
+ // the IPv6 header.
+ IPv6PayloadLenOffset = 4
+ nextHdr = 6
+ hopLimit = 7
+ v6SrcAddr = 8
+ v6DstAddr = v6SrcAddr + IPv6AddressSize
)
// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
@@ -74,6 +76,13 @@ const (
// IPv6Version is the version of the ipv6 protocol.
IPv6Version = 6
+ // IPv6AllNodesMulticastAddress is a link-local multicast group that
+ // all IPv6 nodes MUST join, as per RFC 4291, section 2.8. Packets
+ // destined to this address will reach all nodes on a link.
+ //
+ // The address is ff02::1.
+ IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+
// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
// section 5.
IPv6MinimumMTU = 1280
@@ -94,7 +103,7 @@ var IPv6EmptySubnet = func() tcpip.Subnet {
// PayloadLength returns the value of the "payload length" field of the ipv6
// header.
func (b IPv6) PayloadLength() uint16 {
- return binary.BigEndian.Uint16(b[payloadLen:])
+ return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:])
}
// HopLimit returns the value of the "hop limit" field of the ipv6 header.
@@ -119,13 +128,13 @@ func (b IPv6) Payload() []byte {
// SourceAddress returns the "source address" field of the ipv6 header.
func (b IPv6) SourceAddress() tcpip.Address {
- return tcpip.Address(b[v6SrcAddr : v6SrcAddr+IPv6AddressSize])
+ return tcpip.Address(b[v6SrcAddr:][:IPv6AddressSize])
}
// DestinationAddress returns the "destination address" field of the ipv6
// header.
func (b IPv6) DestinationAddress() tcpip.Address {
- return tcpip.Address(b[v6DstAddr : v6DstAddr+IPv6AddressSize])
+ return tcpip.Address(b[v6DstAddr:][:IPv6AddressSize])
}
// Checksum implements Network.Checksum. Given that IPv6 doesn't have a
@@ -148,18 +157,18 @@ func (b IPv6) SetTOS(t uint8, l uint32) {
// SetPayloadLength sets the "payload length" field of the ipv6 header.
func (b IPv6) SetPayloadLength(payloadLength uint16) {
- binary.BigEndian.PutUint16(b[payloadLen:], payloadLength)
+ binary.BigEndian.PutUint16(b[IPv6PayloadLenOffset:], payloadLength)
}
// SetSourceAddress sets the "source address" field of the ipv6 header.
func (b IPv6) SetSourceAddress(addr tcpip.Address) {
- copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], addr)
+ copy(b[v6SrcAddr:][:IPv6AddressSize], addr)
}
// SetDestinationAddress sets the "destination address" field of the ipv6
// header.
func (b IPv6) SetDestinationAddress(addr tcpip.Address) {
- copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], addr)
+ copy(b[v6DstAddr:][:IPv6AddressSize], addr)
}
// SetNextHeader sets the value of the "next header" field of the ipv6 header.
@@ -178,8 +187,8 @@ func (b IPv6) Encode(i *IPv6Fields) {
b.SetPayloadLength(i.PayloadLength)
b[nextHdr] = i.NextHeader
b[hopLimit] = i.HopLimit
- copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], i.SrcAddr)
- copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], i.DstAddr)
+ b.SetSourceAddress(i.SrcAddr)
+ b.SetDestinationAddress(i.DstAddr)
}
// IsValid performs basic validation on the packet.
@@ -219,6 +228,24 @@ func IsV6MulticastAddress(addr tcpip.Address) bool {
return addr[0] == 0xff
}
+// IsV6UnicastAddress determines if the provided address is a valid IPv6
+// unicast (and specified) address. That is, IsV6UnicastAddress returns
+// true if addr contains IPv6AddressSize bytes, is not the unspecified
+// address and is not a multicast address.
+func IsV6UnicastAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+
+ // Must not be unspecified
+ if addr == IPv6Any {
+ return false
+ }
+
+ // Return true if addr is not a multicast address.
+ return addr[0] != 0xff
+}
+
// SolicitedNodeAddr computes the solicited-node multicast address. This is
// used for NDP. Described in RFC 4291. The argument must be a full-length IPv6
// address.
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index c1f454805..74412c894 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -27,6 +27,11 @@ const (
udpChecksum = 6
)
+const (
+ // UDPMaximumPacketSize is the largest possible UDP packet.
+ UDPMaximumPacketSize = 0xffff
+)
+
// UDPFields contains the fields of a UDP packet. It is used to describe the
// fields of a packet that needs to be encoded.
type UDPFields struct {
diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD
index 3fc14bacd..cc5f531e2 100644
--- a/pkg/tcpip/iptables/BUILD
+++ b/pkg/tcpip/iptables/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "iptables",
srcs = [
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index c40744b8e..18adb2085 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -44,14 +44,12 @@ type Endpoint struct {
}
// New creates a new channel endpoint.
-func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) (tcpip.LinkEndpointID, *Endpoint) {
- e := &Endpoint{
+func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
+ return &Endpoint{
C: make(chan PacketInfo, size),
mtu: mtu,
linkAddr: linkAddr,
}
-
- return stack.RegisterLinkEndpoint(e), e
}
// Drain removes all outbound packets from the channel and counts them.
@@ -135,3 +133,6 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return nil
}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*Endpoint) Wait() {}
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index d786d8fdf..8fa9e3984 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -8,8 +9,8 @@ go_library(
"endpoint.go",
"endpoint_unsafe.go",
"mmap.go",
- "mmap_amd64.go",
- "mmap_amd64_unsafe.go",
+ "mmap_stub.go",
+ "mmap_unsafe.go",
"packet_dispatchers.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased",
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 77f988b9f..f80ac3435 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -41,6 +41,7 @@ package fdbased
import (
"fmt"
+ "sync"
"syscall"
"golang.org/x/sys/unix"
@@ -81,6 +82,19 @@ const (
PacketMMap
)
+func (p PacketDispatchMode) String() string {
+ switch p {
+ case Readv:
+ return "Readv"
+ case RecvMMsg:
+ return "RecvMMsg"
+ case PacketMMap:
+ return "PacketMMap"
+ default:
+ return fmt.Sprintf("unknown packet dispatch mode %v", p)
+ }
+}
+
type endpoint struct {
// fds is the set of file descriptors each identifying one inbound/outbound
// channel. The endpoint will dispatch from all inbound channels as well as
@@ -114,6 +128,9 @@ type endpoint struct {
// gsoMaxSize is the maximum GSO packet size. It is zero if GSO is
// disabled.
gsoMaxSize uint32
+
+ // wg keeps track of running goroutines.
+ wg sync.WaitGroup
}
// Options specify the details about the fd-based endpoint to be created.
@@ -161,11 +178,20 @@ type Options struct {
RXChecksumOffload bool
}
+// fanoutID is used for AF_PACKET based endpoints to enable PACKET_FANOUT
+// support in the host kernel. This allows us to use multiple FD's to receive
+// from the same underlying NIC. The fanoutID needs to be the same for a given
+// set of FD's that point to the same NIC. Trying to set the PACKET_FANOUT
+// option for an FD with a fanoutID already in use by another FD for a different
+// NIC will return an EINVAL.
+var fanoutID = 1
+
// New creates a new fd-based endpoint.
//
// Makes fd non-blocking, but does not take ownership of fd, which must remain
-// open for the lifetime of the returned endpoint.
-func New(opts *Options) (tcpip.LinkEndpointID, error) {
+// open for the lifetime of the returned endpoint (until after the endpoint has
+// stopped being used and Wait returns).
+func New(opts *Options) (stack.LinkEndpoint, error) {
caps := stack.LinkEndpointCapabilities(0)
if opts.RXChecksumOffload {
caps |= stack.CapabilityRXChecksumOffload
@@ -190,7 +216,7 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) {
}
if len(opts.FDs) == 0 {
- return 0, fmt.Errorf("opts.FD is empty, at least one FD must be specified")
+ return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified")
}
e := &endpoint{
@@ -207,12 +233,12 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) {
for i := 0; i < len(e.fds); i++ {
fd := e.fds[i]
if err := syscall.SetNonblock(fd, true); err != nil {
- return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err)
+ return nil, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err)
}
isSocket, err := isSocketFD(fd)
if err != nil {
- return 0, err
+ return nil, err
}
if isSocket {
if opts.GSOMaxSize != 0 {
@@ -222,12 +248,16 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) {
}
inboundDispatcher, err := createInboundDispatcher(e, fd, isSocket)
if err != nil {
- return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err)
+ return nil, fmt.Errorf("createInboundDispatcher(...) = %v", err)
}
e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher)
}
- return stack.RegisterLinkEndpoint(e), nil
+ // Increment fanoutID to ensure that we don't re-use the same fanoutID for
+ // the next endpoint.
+ fanoutID++
+
+ return e, nil
}
func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher, error) {
@@ -247,7 +277,6 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher
case *unix.SockaddrLinklayer:
// enable PACKET_FANOUT mode is the underlying socket is
// of type AF_PACKET.
- const fanoutID = 1
const fanoutType = 0x8000 // PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG
fanoutArg := fanoutID | fanoutType<<16
if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil {
@@ -290,7 +319,11 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
// saved, they stop sending outgoing packets and all incoming packets
// are rejected.
for i := range e.inboundDispatchers {
- go e.dispatchLoop(e.inboundDispatchers[i]) // S/R-SAFE: See above.
+ e.wg.Add(1)
+ go func(i int) { // S/R-SAFE: See above.
+ e.dispatchLoop(e.inboundDispatchers[i])
+ e.wg.Done()
+ }(i)
}
}
@@ -320,6 +353,12 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return e.addr
}
+// Wait implements stack.LinkEndpoint.Wait. It waits for the endpoint to stop
+// reading from its FD.
+func (e *endpoint) Wait() {
+ e.wg.Wait()
+}
+
// virtioNetHdr is declared in linux/virtio_net.h.
type virtioNetHdr struct {
flags uint8
@@ -435,14 +474,12 @@ func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buf
}
// NewInjectable creates a new fd-based InjectableEndpoint.
-func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) (tcpip.LinkEndpointID, *InjectableEndpoint) {
+func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) *InjectableEndpoint {
syscall.SetNonblock(fd, true)
- e := &InjectableEndpoint{endpoint: endpoint{
+ return &InjectableEndpoint{endpoint: endpoint{
fds: []int{fd},
mtu: mtu,
caps: capabilities,
}}
-
- return stack.RegisterLinkEndpoint(e), e
}
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index e305252d6..04406bc9a 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -68,11 +68,10 @@ func newContext(t *testing.T, opt *Options) *context {
}
opt.FDs = []int{fds[1]}
- epID, err := New(opt)
+ ep, err := New(opt)
if err != nil {
t.Fatalf("Failed to create FD endpoint: %v", err)
}
- ep := stack.FindLinkEndpoint(epID).(*endpoint)
c := &context{
t: t,
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index 2dca173c2..8bfeb97e4 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -12,12 +12,183 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build !linux !amd64
+// +build linux,amd64 linux,arm64
package fdbased
-// Stubbed out version for non-linux/non-amd64 platforms.
+import (
+ "encoding/binary"
+ "syscall"
-func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
- return nil, nil
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+)
+
+const (
+ tPacketAlignment = uintptr(16)
+ tpStatusKernel = 0
+ tpStatusUser = 1
+ tpStatusCopy = 2
+ tpStatusLosing = 4
+)
+
+// We overallocate the frame size to accommodate space for the
+// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding.
+//
+// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB
+//
+// NOTE:
+// Frames need to be aligned at 16 byte boundaries.
+// BlockSize needs to be page aligned.
+//
+// For details see PACKET_MMAP setting constraints in
+// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
+const (
+ tpFrameSize = 65536 + 128
+ tpBlockSize = tpFrameSize * 32
+ tpBlockNR = 1
+ tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize
+)
+
+// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct
+// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>.
+func tPacketAlign(v uintptr) uintptr {
+ return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1))
+}
+
+// tPacketReq is the tpacket_req structure as described in
+// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
+type tPacketReq struct {
+ tpBlockSize uint32
+ tpBlockNR uint32
+ tpFrameSize uint32
+ tpFrameNR uint32
+}
+
+// tPacketHdr is tpacket_hdr structure as described in <linux/if_packet.h>
+type tPacketHdr []byte
+
+const (
+ tpStatusOffset = 0
+ tpLenOffset = 8
+ tpSnapLenOffset = 12
+ tpMacOffset = 16
+ tpNetOffset = 18
+ tpSecOffset = 20
+ tpUSecOffset = 24
+)
+
+func (t tPacketHdr) tpLen() uint32 {
+ return binary.LittleEndian.Uint32(t[tpLenOffset:])
+}
+
+func (t tPacketHdr) tpSnapLen() uint32 {
+ return binary.LittleEndian.Uint32(t[tpSnapLenOffset:])
+}
+
+func (t tPacketHdr) tpMac() uint16 {
+ return binary.LittleEndian.Uint16(t[tpMacOffset:])
+}
+
+func (t tPacketHdr) tpNet() uint16 {
+ return binary.LittleEndian.Uint16(t[tpNetOffset:])
+}
+
+func (t tPacketHdr) tpSec() uint32 {
+ return binary.LittleEndian.Uint32(t[tpSecOffset:])
+}
+
+func (t tPacketHdr) tpUSec() uint32 {
+ return binary.LittleEndian.Uint32(t[tpUSecOffset:])
+}
+
+func (t tPacketHdr) Payload() []byte {
+ return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()]
+}
+
+// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets.
+// See: mmap_unsafe.go for implementation details.
+type packetMMapDispatcher struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // e is the endpoint this dispatcher is attached to.
+ e *endpoint
+
+ // ringBuffer is only used when PacketMMap dispatcher is used and points
+ // to the start of the mmapped PACKET_RX_RING buffer.
+ ringBuffer []byte
+
+ // ringOffset is the current offset into the ring buffer where the next
+ // inbound packet will be placed by the kernel.
+ ringOffset int
+}
+
+func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) {
+ hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
+ for hdr.tpStatus()&tpStatusUser == 0 {
+ event := rawfile.PollEvent{
+ FD: int32(d.fd),
+ Events: unix.POLLIN | unix.POLLERR,
+ }
+ if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 {
+ if errno == syscall.EINTR {
+ continue
+ }
+ return nil, rawfile.TranslateErrno(errno)
+ }
+ if hdr.tpStatus()&tpStatusCopy != 0 {
+ // This frame is truncated so skip it after flipping the
+ // buffer to the kernel.
+ hdr.setTPStatus(tpStatusKernel)
+ d.ringOffset = (d.ringOffset + 1) % tpFrameNR
+ hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:])
+ continue
+ }
+ }
+
+ // Copy out the packet from the mmapped frame to a locally owned buffer.
+ pkt := make([]byte, hdr.tpSnapLen())
+ copy(pkt, hdr.Payload())
+ // Release packet to kernel.
+ hdr.setTPStatus(tpStatusKernel)
+ d.ringOffset = (d.ringOffset + 1) % tpFrameNR
+ return pkt, nil
+}
+
+// dispatch reads packets from an mmapped ring buffer and dispatches them to the
+// network stack.
+func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
+ pkt, err := d.readMMappedPacket()
+ if err != nil {
+ return false, err
+ }
+ var (
+ p tcpip.NetworkProtocolNumber
+ remote, local tcpip.LinkAddress
+ )
+ if d.e.hdrSize > 0 {
+ eth := header.Ethernet(pkt)
+ p = eth.Type()
+ remote = eth.SourceAddress()
+ local = eth.DestinationAddress()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ switch header.IPVersion(pkt) {
+ case header.IPv4Version:
+ p = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ p = header.IPv6ProtocolNumber
+ default:
+ return true, nil
+ }
+ }
+
+ pkt = pkt[d.e.hdrSize:]
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}))
+ return true, nil
}
diff --git a/pkg/tcpip/link/fdbased/mmap_amd64.go b/pkg/tcpip/link/fdbased/mmap_amd64.go
deleted file mode 100644
index 029f86a18..000000000
--- a/pkg/tcpip/link/fdbased/mmap_amd64.go
+++ /dev/null
@@ -1,194 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build linux,amd64
-
-package fdbased
-
-import (
- "encoding/binary"
- "syscall"
-
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
-)
-
-const (
- tPacketAlignment = uintptr(16)
- tpStatusKernel = 0
- tpStatusUser = 1
- tpStatusCopy = 2
- tpStatusLosing = 4
-)
-
-// We overallocate the frame size to accommodate space for the
-// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding.
-//
-// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB
-//
-// NOTE:
-// Frames need to be aligned at 16 byte boundaries.
-// BlockSize needs to be page aligned.
-//
-// For details see PACKET_MMAP setting constraints in
-// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
-const (
- tpFrameSize = 65536 + 128
- tpBlockSize = tpFrameSize * 32
- tpBlockNR = 1
- tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize
-)
-
-// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct
-// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>.
-func tPacketAlign(v uintptr) uintptr {
- return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1))
-}
-
-// tPacketReq is the tpacket_req structure as described in
-// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
-type tPacketReq struct {
- tpBlockSize uint32
- tpBlockNR uint32
- tpFrameSize uint32
- tpFrameNR uint32
-}
-
-// tPacketHdr is tpacket_hdr structure as described in <linux/if_packet.h>
-type tPacketHdr []byte
-
-const (
- tpStatusOffset = 0
- tpLenOffset = 8
- tpSnapLenOffset = 12
- tpMacOffset = 16
- tpNetOffset = 18
- tpSecOffset = 20
- tpUSecOffset = 24
-)
-
-func (t tPacketHdr) tpLen() uint32 {
- return binary.LittleEndian.Uint32(t[tpLenOffset:])
-}
-
-func (t tPacketHdr) tpSnapLen() uint32 {
- return binary.LittleEndian.Uint32(t[tpSnapLenOffset:])
-}
-
-func (t tPacketHdr) tpMac() uint16 {
- return binary.LittleEndian.Uint16(t[tpMacOffset:])
-}
-
-func (t tPacketHdr) tpNet() uint16 {
- return binary.LittleEndian.Uint16(t[tpNetOffset:])
-}
-
-func (t tPacketHdr) tpSec() uint32 {
- return binary.LittleEndian.Uint32(t[tpSecOffset:])
-}
-
-func (t tPacketHdr) tpUSec() uint32 {
- return binary.LittleEndian.Uint32(t[tpUSecOffset:])
-}
-
-func (t tPacketHdr) Payload() []byte {
- return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()]
-}
-
-// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets.
-// See: mmap_amd64_unsafe.go for implementation details.
-type packetMMapDispatcher struct {
- // fd is the file descriptor used to send and receive packets.
- fd int
-
- // e is the endpoint this dispatcher is attached to.
- e *endpoint
-
- // ringBuffer is only used when PacketMMap dispatcher is used and points
- // to the start of the mmapped PACKET_RX_RING buffer.
- ringBuffer []byte
-
- // ringOffset is the current offset into the ring buffer where the next
- // inbound packet will be placed by the kernel.
- ringOffset int
-}
-
-func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) {
- hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
- for hdr.tpStatus()&tpStatusUser == 0 {
- event := rawfile.PollEvent{
- FD: int32(d.fd),
- Events: unix.POLLIN | unix.POLLERR,
- }
- if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 {
- if errno == syscall.EINTR {
- continue
- }
- return nil, rawfile.TranslateErrno(errno)
- }
- if hdr.tpStatus()&tpStatusCopy != 0 {
- // This frame is truncated so skip it after flipping the
- // buffer to the kernel.
- hdr.setTPStatus(tpStatusKernel)
- d.ringOffset = (d.ringOffset + 1) % tpFrameNR
- hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:])
- continue
- }
- }
-
- // Copy out the packet from the mmapped frame to a locally owned buffer.
- pkt := make([]byte, hdr.tpSnapLen())
- copy(pkt, hdr.Payload())
- // Release packet to kernel.
- hdr.setTPStatus(tpStatusKernel)
- d.ringOffset = (d.ringOffset + 1) % tpFrameNR
- return pkt, nil
-}
-
-// dispatch reads packets from an mmaped ring buffer and dispatches them to the
-// network stack.
-func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
- pkt, err := d.readMMappedPacket()
- if err != nil {
- return false, err
- }
- var (
- p tcpip.NetworkProtocolNumber
- remote, local tcpip.LinkAddress
- )
- if d.e.hdrSize > 0 {
- eth := header.Ethernet(pkt)
- p = eth.Type()
- remote = eth.SourceAddress()
- local = eth.DestinationAddress()
- } else {
- // We don't get any indication of what the packet is, so try to guess
- // if it's an IPv4 or IPv6 packet.
- switch header.IPVersion(pkt) {
- case header.IPv4Version:
- p = header.IPv4ProtocolNumber
- case header.IPv6Version:
- p = header.IPv6ProtocolNumber
- default:
- return true, nil
- }
- }
-
- pkt = pkt[d.e.hdrSize:]
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}))
- return true, nil
-}
diff --git a/test/runtimes/runtimes.go b/pkg/tcpip/link/fdbased/mmap_stub.go
index 2568e07fe..67be52d67 100644
--- a/test/runtimes/runtimes.go
+++ b/pkg/tcpip/link/fdbased/mmap_stub.go
@@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package runtimes provides language tests for runsc runtimes.
-// Each test calls docker commands to start up a container for each supported runtime,
-// and tests that its respective language tests are behaving as expected, like
-// connecting to a port or looking at the output. The container is killed and deleted
-// at the end.
-package runtimes
+// +build !linux !amd64,!arm64
+
+package fdbased
+
+// Stubbed out version for non-linux/non-amd64/non-arm64 platforms.
+
+func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ return nil, nil
+}
diff --git a/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go b/pkg/tcpip/link/fdbased/mmap_unsafe.go
index 47cb1d1cc..3894185ae 100644
--- a/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go
+++ b/pkg/tcpip/link/fdbased/mmap_unsafe.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build linux,amd64
+// +build linux,amd64 linux,arm64
package fdbased
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index ab6a53988..b36629d2c 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -32,8 +32,8 @@ type endpoint struct {
// New creates a new loopback endpoint. This link-layer endpoint just turns
// outbound packets into inbound packets.
-func New() tcpip.LinkEndpointID {
- return stack.RegisterLinkEndpoint(&endpoint{})
+func New() stack.LinkEndpoint {
+ return &endpoint{}
}
// Attach implements stack.LinkEndpoint.Attach. It just saves the stack network-
@@ -85,3 +85,6 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependa
return nil
}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*endpoint) Wait() {}
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
index ea12ef1ac..1bab380b0 100644
--- a/pkg/tcpip/link/muxed/BUILD
+++ b/pkg/tcpip/link/muxed/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index a577a3d52..7c946101d 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -104,10 +104,16 @@ func (m *InjectableEndpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *
return endpoint.WriteRawPacket(dest, packet)
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (m *InjectableEndpoint) Wait() {
+ for _, ep := range m.routes {
+ ep.Wait()
+ }
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
-func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) (tcpip.LinkEndpointID, *InjectableEndpoint) {
- e := &InjectableEndpoint{
+func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
+ return &InjectableEndpoint{
routes: routes,
}
- return stack.RegisterLinkEndpoint(e), e
}
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
index 174b9330f..3086fec00 100644
--- a/pkg/tcpip/link/muxed/injectable_test.go
+++ b/pkg/tcpip/link/muxed/injectable_test.go
@@ -87,8 +87,8 @@ func makeTestInjectableEndpoint(t *testing.T) (*InjectableEndpoint, *os.File, tc
if err != nil {
t.Fatal("Failed to create socket pair:", err)
}
- _, underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone)
+ underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone)
routes := map[tcpip.Address]stack.InjectableLinkEndpoint{dstIP: underlyingEndpoint}
- _, endpoint := NewInjectableEndpoint(routes)
+ endpoint := NewInjectableEndpoint(routes)
return endpoint, os.NewFile(uintptr(pair[0]), "test route end"), dstIP
}
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 6e3a7a9d7..2e8bc772a 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -6,8 +6,9 @@ go_library(
name = "rawfile",
srcs = [
"blockingpoll_amd64.s",
- "blockingpoll_amd64_unsafe.go",
- "blockingpoll_unsafe.go",
+ "blockingpoll_arm64.s",
+ "blockingpoll_noyield_unsafe.go",
+ "blockingpoll_yield_unsafe.go",
"errors.go",
"rawfile_unsafe.go",
],
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
new file mode 100644
index 000000000..b62888b93
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// BlockingPoll makes the ppoll() syscall while calling the version of
+// entersyscall that relinquishes the P so that other Gs can run. This is meant
+// to be called in cases when the syscall is expected to block.
+//
+// func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (n int, err syscall.Errno)
+TEXT ·BlockingPoll(SB),NOSPLIT,$0-40
+ BL ·callEntersyscallblock(SB)
+ MOVD fds+0(FP), R0
+ MOVD nfds+8(FP), R1
+ MOVD timeout+16(FP), R2
+ MOVD $0x0, R3 // sigmask parameter which isn't used here
+ MOVD $0x49, R8 // SYS_PPOLL
+ SVC
+ CMP $0xfffffffffffff001, R0
+ BLS ok
+ MOVD $-1, R1
+ MOVD R1, n+24(FP)
+ NEG R0, R0
+ MOVD R0, err+32(FP)
+ BL ·callExitsyscall(SB)
+ RET
+ok:
+ MOVD R0, n+24(FP)
+ MOVD $0, err+32(FP)
+ BL ·callExitsyscall(SB)
+ RET
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
index 84dc0e918..621ab8d29 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build linux,!amd64
+// +build linux,!amd64,!arm64
package rawfile
@@ -22,7 +22,7 @@ import (
)
// BlockingPoll is just a stub function that forwards to the ppoll() system call
-// on non-amd64 platforms.
+// on non-amd64 and non-arm64 platforms.
func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) {
n, _, e := syscall.Syscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(fds)),
uintptr(nfds), uintptr(unsafe.Pointer(timeout)), 0, 0, 0)
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
index 47039a446..dda3b10a6 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build linux,amd64
+// +build linux,amd64 linux,arm64
// +build go1.12
// +build !go1.14
@@ -25,6 +25,12 @@ import (
_ "unsafe" // for go:linkname
)
+// BlockingPoll on amd64/arm64 makes the ppoll() syscall while calling the
+// version of entersyscall that relinquishes the P so that other Gs can
+// run. This is meant to be called in cases when the syscall is expected to
+// block. On non amd64/arm64 platforms it just forwards to the ppoll() system
+// call.
+//
//go:noescape
func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno)
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
index f2998aa98..0a5ea3dc4 100644
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD
index 94725cb11..330ed5e94 100644
--- a/pkg/tcpip/link/sharedmem/pipe/BUILD
+++ b/pkg/tcpip/link/sharedmem/pipe/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD
index 160a8f864..de1ce043d 100644
--- a/pkg/tcpip/link/sharedmem/queue/BUILD
+++ b/pkg/tcpip/link/sharedmem/queue/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 834ea5c40..9e71d4edf 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -94,7 +94,7 @@ type endpoint struct {
// New creates a new shared-memory-based endpoint. Buffers will be broken up
// into buffers of "bufferSize" bytes.
-func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tcpip.LinkEndpointID, error) {
+func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) {
e := &endpoint{
mtu: mtu,
bufferSize: bufferSize,
@@ -102,15 +102,15 @@ func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tc
}
if err := e.tx.init(bufferSize, &tx); err != nil {
- return 0, err
+ return nil, err
}
if err := e.rx.init(bufferSize, &rx); err != nil {
e.tx.cleanup()
- return 0, err
+ return nil, err
}
- return stack.RegisterLinkEndpoint(e), nil
+ return e, nil
}
// Close frees all resources associated with the endpoint.
@@ -132,7 +132,8 @@ func (e *endpoint) Close() {
}
}
-// Wait waits until all workers have stopped after a Close() call.
+// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have
+// stopped after a Close() call.
func (e *endpoint) Wait() {
e.completed.Wait()
}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 98036f367..0e9ba0846 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -119,12 +119,12 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress
initQueue(t, &c.txq, &c.txCfg)
initQueue(t, &c.rxq, &c.rxCfg)
- id, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg)
+ ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg)
if err != nil {
t.Fatalf("New failed: %v", err)
}
- c.ep = stack.FindLinkEndpoint(id).(*endpoint)
+ c.ep = ep.(*endpoint)
c.ep.Attach(c)
return c
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 36c8c46fc..e401dce44 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -58,10 +58,10 @@ type endpoint struct {
// New creates a new sniffer link-layer endpoint. It wraps around another
// endpoint and logs packets and they traverse the endpoint.
-func New(lower tcpip.LinkEndpointID) tcpip.LinkEndpointID {
- return stack.RegisterLinkEndpoint(&endpoint{
- lower: stack.FindLinkEndpoint(lower),
- })
+func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
+ return &endpoint{
+ lower: lower,
+ }
}
func zoneOffset() (int32, error) {
@@ -102,15 +102,15 @@ func writePCAPHeader(w io.Writer, maxLen uint32) error {
// snapLen is the maximum amount of a packet to be saved. Packets with a length
// less than or equal too snapLen will be saved in their entirety. Longer
// packets will be truncated to snapLen.
-func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcpip.LinkEndpointID, error) {
+func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack.LinkEndpoint, error) {
if err := writePCAPHeader(file, snapLen); err != nil {
- return 0, err
+ return nil, err
}
- return stack.RegisterLinkEndpoint(&endpoint{
- lower: stack.FindLinkEndpoint(lower),
+ return &endpoint{
+ lower: lower,
file: file,
maxPCAPLen: snapLen,
- }), nil
+ }, nil
}
// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is
@@ -240,6 +240,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return e.lower.WritePacket(r, gso, hdr, payload, protocol)
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*endpoint) Wait() {}
+
func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
index 2597d4b3e..0746dc8ec 100644
--- a/pkg/tcpip/link/waitable/BUILD
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 3b6ac2ff7..5a1791cb5 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -40,11 +40,10 @@ type Endpoint struct {
// New creates a new waitable link-layer endpoint. It wraps around another
// endpoint and allows the caller to block new write/dispatch calls and wait for
// the inflight ones to finish before returning.
-func New(lower tcpip.LinkEndpointID) (tcpip.LinkEndpointID, *Endpoint) {
- e := &Endpoint{
- lower: stack.FindLinkEndpoint(lower),
+func New(lower stack.LinkEndpoint) *Endpoint {
+ return &Endpoint{
+ lower: lower,
}
- return stack.RegisterLinkEndpoint(e), e
}
// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket.
@@ -121,3 +120,6 @@ func (e *Endpoint) WaitWrite() {
func (e *Endpoint) WaitDispatch() {
e.dispatchGate.Close()
}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (e *Endpoint) Wait() {}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 56e18ecb0..ae23c96b7 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -70,9 +70,12 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.P
return nil
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*countedEndpoint) Wait() {}
+
func TestWaitWrite(t *testing.T) {
ep := &countedEndpoint{}
- _, wep := New(stack.RegisterLinkEndpoint(ep))
+ wep := New(ep)
// Write and check that it goes through.
wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0)
@@ -97,7 +100,7 @@ func TestWaitWrite(t *testing.T) {
func TestWaitDispatch(t *testing.T) {
ep := &countedEndpoint{}
- _, wep := New(stack.RegisterLinkEndpoint(ep))
+ wep := New(ep)
// Check that attach happens.
wep.Attach(ep)
@@ -139,7 +142,7 @@ func TestOtherMethods(t *testing.T) {
hdrLen: hdrLen,
linkAddr: linkAddr,
}
- _, wep := New(stack.RegisterLinkEndpoint(ep))
+ wep := New(ep)
if v := wep.MTU(); v != mtu {
t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu)
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index f36f49453..9d16ff8c9 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_test")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index d95d44f56..df0d3a8c0 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index ea7296e6a..6b1e854dc 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -16,9 +16,9 @@
// IPv4 addresses into link-local MAC addresses, and advertises IPv4
// addresses of its stack with the local network.
//
-// To use it in the networking stack, pass arp.ProtocolName as one of the
-// network protocols when calling stack.New. Then add an "arp" address to
-// every NIC on the stack that should respond to ARP requests. That is:
+// To use it in the networking stack, pass arp.NewProtocol() as one of the
+// network protocols when calling stack.New. Then add an "arp" address to every
+// NIC on the stack that should respond to ARP requests. That is:
//
// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil {
// // handle err
@@ -33,9 +33,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ARP protocol name.
- ProtocolName = "arp"
-
// ProtocolNumber is the ARP protocol number.
ProtocolNumber = header.ARPProtocolNumber
@@ -82,7 +79,7 @@ func (e *endpoint) MaxHeaderLength() uint16 {
func (e *endpoint) Close() {}
-func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.TransportProtocolNumber, uint8, stack.PacketLooping) *tcpip.Error {
+func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, stack.NetworkHeaderParams, stack.PacketLooping) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -204,8 +201,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
-func init() {
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+// NewProtocol returns an ARP network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 4c4b54469..88b57ec03 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -44,14 +44,19 @@ type testContext struct {
}
func newTestContext(t *testing.T) *testContext {
- s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{icmp.ProtocolName4}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
+ })
const defaultMTU = 65536
- id, linkEP := channel.New(256, defaultMTU, stackLinkAddr)
+ ep := channel.New(256, defaultMTU, stackLinkAddr)
+ wep := stack.LinkEndpoint(ep)
+
if testing.Verbose() {
- id = sniffer.New(id)
+ wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, wep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -73,7 +78,7 @@ func newTestContext(t *testing.T) *testContext {
return &testContext{
t: t,
s: s,
- linkEP: linkEP,
+ linkEP: ep,
}
}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index 118bfc763..2cad0a0b6 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -1,7 +1,8 @@
-package(licenses = ["notice"])
-
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
go_template_instance(
name = "reassembler_list",
@@ -27,6 +28,7 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/log",
+ "//pkg/tcpip",
"//pkg/tcpip/buffer",
],
)
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index 1628a82be..6da5238ec 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -17,6 +17,7 @@
package fragmentation
import (
+ "fmt"
"log"
"sync"
"time"
@@ -82,7 +83,7 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t
// Process processes an incoming fragment belonging to an ID
// and returns a complete packet when all the packets belonging to that ID have been received.
-func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool) {
+func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) {
f.mu.Lock()
r, ok := f.reassemblers[id]
if ok && r.tooOld(f.timeout) {
@@ -97,8 +98,15 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf
}
f.mu.Unlock()
- res, done, consumed := r.process(first, last, more, vv)
-
+ res, done, consumed, err := r.process(first, last, more, vv)
+ if err != nil {
+ // We probably got an invalid sequence of fragments. Just
+ // discard the reassembler and move on.
+ f.mu.Lock()
+ f.release(r)
+ f.mu.Unlock()
+ return buffer.VectorisedView{}, false, fmt.Errorf("fragmentation processing error: %v", err)
+ }
f.mu.Lock()
f.size += consumed
if done {
@@ -114,7 +122,7 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf
}
}
f.mu.Unlock()
- return res, done
+ return res, done, nil
}
func (f *Fragmentation) release(r *reassembler) {
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index 799798544..72c0f53be 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -83,7 +83,10 @@ func TestFragmentationProcess(t *testing.T) {
t.Run(c.comment, func(t *testing.T) {
f := NewFragmentation(1024, 512, DefaultReassembleTimeout)
for i, in := range c.in {
- vv, done := f.Process(in.id, in.first, in.last, in.more, in.vv)
+ vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv)
+ if err != nil {
+ t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err)
+ }
if !reflect.DeepEqual(vv, c.out[i].vv) {
t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv)
}
@@ -114,7 +117,10 @@ func TestReassemblingTimeout(t *testing.T) {
time.Sleep(2 * timeout)
// Send another fragment that completes a packet.
// However, no packet should be reassembled because the fragment arrived after the timeout.
- _, done := f.Process(0, 1, 1, false, vv(1, "1"))
+ _, done, err := f.Process(0, 1, 1, false, vv(1, "1"))
+ if err != nil {
+ t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err)
+ }
if done {
t.Errorf("Fragmentation does not respect the reassembling timeout.")
}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 8037f734b..9e002e396 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -78,7 +78,7 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
return used
}
-func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int) {
+func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int, error) {
r.mu.Lock()
defer r.mu.Unlock()
consumed := 0
@@ -86,7 +86,7 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise
// A concurrent goroutine might have already reassembled
// the packet and emptied the heap while this goroutine
// was waiting on the mutex. We don't have to do anything in this case.
- return buffer.VectorisedView{}, false, consumed
+ return buffer.VectorisedView{}, false, consumed, nil
}
if r.updateHoles(first, last, more) {
// We store the incoming packet only if it filled some holes.
@@ -96,13 +96,13 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise
}
// Check if all the holes have been deleted and we are ready to reassamble.
if r.deleted < len(r.holes) {
- return buffer.VectorisedView{}, false, consumed
+ return buffer.VectorisedView{}, false, consumed, nil
}
res, err := r.heap.reassemble()
if err != nil {
- panic(fmt.Sprintf("reassemble failed with: %v. There is probably a bug in the code handling the holes.", err))
+ return buffer.VectorisedView{}, false, consumed, fmt.Errorf("fragment reassembly failed: %v", err)
}
- return res, true, consumed
+ return res, true, consumed, nil
}
func (r *reassembler) tooOld(timeout time.Duration) bool {
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 6bbfcd97f..f644a8b08 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -144,6 +144,9 @@ func (*testObject) LinkAddress() tcpip.LinkAddress {
return ""
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*testObject) Wait() {}
+
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
@@ -169,7 +172,10 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prepen
}
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv4.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
@@ -182,7 +188,10 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
}
func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
- s := stack.New([]string{ipv6.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv6.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
@@ -221,7 +230,7 @@ func TestIPv4Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil {
+ if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
@@ -319,7 +328,8 @@ func TestIPv4ReceiveControl(t *testing.T) {
icmp := header.ICMPv4(view[header.IPv4MinimumSize:])
icmp.SetType(header.ICMPv4DstUnreachable)
icmp.SetCode(c.code)
- copy(view[header.IPv4MinimumSize+header.ICMPv4PayloadOffset:], []byte{0xde, 0xad, 0xbe, 0xef})
+ icmp.SetIdent(0xdead)
+ icmp.SetSequence(0xbeef)
// Create the inner IPv4 header.
ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:])
@@ -450,7 +460,7 @@ func TestIPv6Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil {
+ if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
@@ -539,7 +549,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
defer ep.Close()
- dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize + 4
+ dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize
if c.fragmentOffset != nil {
dataOffset += header.IPv6FragmentHeaderSize
}
@@ -559,10 +569,11 @@ func TestIPv6ReceiveControl(t *testing.T) {
icmp := header.ICMPv6(view[header.IPv6MinimumSize:])
icmp.SetType(c.typ)
icmp.SetCode(c.code)
- copy(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef})
+ icmp.SetIdent(0xdead)
+ icmp.SetSequence(0xbeef)
// Create the inner IPv6 header.
- ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:])
+ ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
ip.Encode(&header.IPv6Fields{
PayloadLength: 100,
NextHeader: 10,
@@ -574,7 +585,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Build the fragmentation header if needed.
if c.fragmentOffset != nil {
ip.SetNextHeader(header.IPv6FragmentHeader)
- frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:])
+ frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:])
frag.Encode(&header.IPv6FragmentFields{
NextHeader: 10,
FragmentOffset: *c.fragmentOffset,
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index be84fa63d..58e537aad 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 497164cbb..50b363dc4 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -15,8 +15,6 @@
package ipv4
import (
- "encoding/binary"
-
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -97,7 +95,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
pkt.SetChecksum(0)
pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0)))
sent := stats.ICMP.V4PacketsSent
- if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()); err != nil {
+ if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil {
sent.Dropped.Increment()
return
}
@@ -117,7 +115,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
e.handleControl(stack.ControlPortUnreachable, 0, vv)
case header.ICMPv4FragmentationNeeded:
- mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset+2:]))
+ mtu := uint32(h.MTU())
e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index b7a06f525..5cd895ff0 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -14,9 +14,9 @@
// Package ipv4 contains the implementation of the ipv4 network protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the
-// network protocols when calling stack.New(). Then endpoints can be created
-// by passing ipv4.ProtocolNumber as the network protocol number when calling
+// activated on the stack by passing ipv4.NewProtocol() as one of the network
+// protocols when calling stack.New(). Then endpoints can be created by passing
+// ipv4.ProtocolNumber as the network protocol number when calling
// Stack.NewEndpoint().
package ipv4
@@ -32,9 +32,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ipv4 protocol name.
- ProtocolName = "ipv4"
-
// ProtocolNumber is the ipv4 protocol number.
ProtocolNumber = header.IPv4ProtocolNumber
@@ -42,6 +39,9 @@ const (
// TotalLength field of the ipv4 header.
MaxTotalSize = 0xffff
+ // DefaultTTL is the default time-to-live value for this endpoint.
+ DefaultTTL = 64
+
// buckets is the number of identifier buckets.
buckets = 2048
)
@@ -53,6 +53,7 @@ type endpoint struct {
linkEP stack.LinkEndpoint
dispatcher stack.TransportDispatcher
fragmentation *fragmentation.Fragmentation
+ protocol *protocol
}
// NewEndpoint creates a new ipv4 endpoint.
@@ -64,6 +65,7 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi
linkEP: linkEP,
dispatcher: dispatcher,
fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ protocol: p,
}
return e, nil
@@ -71,7 +73,7 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi
// DefaultTTL is the default time-to-live value for this endpoint.
func (e *endpoint) DefaultTTL() uint8 {
- return 255
+ return e.protocol.DefaultTTL()
}
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
@@ -197,21 +199,22 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buff
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
length := uint16(hdr.UsedLength() + payload.Size())
id := uint32(0)
if length > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
// fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1)
+ id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
}
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
TotalLength: length,
ID: uint16(id),
- TTL: ttl,
- Protocol: uint8(protocol),
+ TTL: params.TTL,
+ TOS: params.TOS,
+ Protocol: uint8(params.Protocol),
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
@@ -267,7 +270,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect
if payload.Size() > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
// fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&ids[hashRoute(r, 0 /* protocol */)%buckets], 1)
+ id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)
}
ip.SetID(uint16(id))
}
@@ -294,6 +297,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
headerView := vv.First()
h := header.IPv4(headerView)
if !h.IsValid(vv.Size()) {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
@@ -304,10 +308,31 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
more := (h.Flags() & header.IPv4FlagMoreFragments) != 0
if more || h.FragmentOffset() != 0 {
+ if vv.Size() == 0 {
+ // Drop the packet as it's marked as a fragment but has
+ // no payload.
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
// The packet is a fragment, let's try to reassemble it.
last := h.FragmentOffset() + uint16(vv.Size()) - 1
+ // Drop the packet if the fragmentOffset is incorrect. i.e the
+ // combination of fragmentOffset and vv.size() causes a wrap
+ // around resulting in last being less than the offset.
+ if last < h.FragmentOffset() {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
var ready bool
- vv, ready = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv)
+ var err error
+ vv, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv)
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
if !ready {
return
}
@@ -325,14 +350,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
// Close cleans up resources associated with the endpoint.
func (e *endpoint) Close() {}
-type protocol struct{}
+type protocol struct {
+ ids []uint32
+ hashIV uint32
-// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported
-// only for tests that short-circuit the stack. Regular use of the protocol is
-// done via the stack, which gets a protocol descriptor from the init() function
-// below.
-func NewProtocol() stack.NetworkProtocol {
- return &protocol{}
+ // defaultTTL is the current default TTL for the protocol. Only the
+ // uint8 portion of it is meaningful and it must be accessed
+ // atomically.
+ defaultTTL uint32
}
// Number returns the ipv4 protocol number.
@@ -358,12 +383,34 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
// SetOption implements NetworkProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch v := option.(type) {
+ case tcpip.DefaultTTLOption:
+ p.SetDefaultTTL(uint8(v))
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// Option implements NetworkProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch v := option.(type) {
+ case *tcpip.DefaultTTLOption:
+ *v = tcpip.DefaultTTLOption(p.DefaultTTL())
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// SetDefaultTTL sets the default TTL for endpoints created with this protocol.
+func (p *protocol) SetDefaultTTL(ttl uint8) {
+ atomic.StoreUint32(&p.defaultTTL, uint32(ttl))
+}
+
+// DefaultTTL returns the default TTL for endpoints created with this protocol.
+func (p *protocol) DefaultTTL() uint8 {
+ return uint8(atomic.LoadUint32(&p.defaultTTL))
}
// calculateMTU calculates the network-layer payload MTU based on the link-layer
@@ -378,7 +425,7 @@ func calculateMTU(mtu uint32) uint32 {
// hashRoute calculates a hash value for the given route. It uses the source &
// destination address, the transport protocol number, and a random initial
// value (generated once on initialization) to generate the hash.
-func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 {
+func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
t := r.LocalAddress
a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
t = r.RemoteAddress
@@ -386,22 +433,16 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 {
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
-var (
- ids []uint32
- hashIV uint32
-)
-
-func init() {
- ids = make([]uint32, buckets)
+// NewProtocol returns an IPv4 network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ ids := make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
r := hash.RandN32(1 + buckets)
for i := range ids {
ids[i] = r[i]
}
- hashIV = r[buckets]
+ hashIV := r[buckets]
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+ return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL}
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 1b5a55bea..85ab0e3bc 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -33,20 +33,20 @@ import (
)
func TestExcludeBroadcast(t *testing.T) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
const defaultMTU = 65536
- id, _ := channel.New(256, defaultMTU, "")
+ ep := stack.LinkEndpoint(channel.New(256, defaultMTU, ""))
if testing.Verbose() {
- id = sniffer.New(id)
+ ep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Broadcast); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
@@ -184,15 +184,12 @@ type errorChannel struct {
// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket
// will return successive errors from packetCollectorErrors until the list is
// empty and then return nil each time.
-func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) (tcpip.LinkEndpointID, *errorChannel) {
- _, e := channel.New(size, mtu, linkAddr)
- ec := errorChannel{
- Endpoint: e,
+func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
+ return &errorChannel{
+ Endpoint: channel.New(size, mtu, linkAddr),
Ch: make(chan packetInfo, size),
packetCollectorErrors: packetCollectorErrors,
}
-
- return stack.RegisterLinkEndpoint(e), &ec
}
// packetInfo holds all the information about an outbound packet.
@@ -241,10 +238,11 @@ type context struct {
func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
// Make the packet and write it.
- s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{})
- _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
- linkEPId := stack.RegisterLinkEndpoint(linkEP)
- s.CreateNIC(1, linkEPId)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ })
+ ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
+ s.CreateNIC(1, ep)
const (
src = "\x10\x00\x00\x01"
dst = "\x10\x00\x00\x02"
@@ -266,7 +264,7 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32
}
return context{
Route: r,
- linkEP: linkEP,
+ linkEP: ep,
}
}
@@ -304,7 +302,7 @@ func TestFragmentation(t *testing.T) {
Payload: payload.Clone([]buffer.View{}),
}
c := buildContext(t, nil, ft.mtu)
- err := c.Route.WritePacket(ft.gso, hdr, payload, tcp.ProtocolNumber, 42)
+ err := c.Route.WritePacket(ft.gso, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS})
if err != nil {
t.Errorf("err got %v, want %v", err, nil)
}
@@ -351,7 +349,7 @@ func TestFragmentationErrors(t *testing.T) {
t.Run(ft.description, func(t *testing.T) {
hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
- err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42)
+ err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS})
for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)
@@ -368,3 +366,117 @@ func TestFragmentationErrors(t *testing.T) {
})
}
}
+
+func TestInvalidFragments(t *testing.T) {
+ // The first two test cases use packets with both IHL and TotalLength set to 0.
+ testCases := []struct {
+ name string
+ packets [][]byte
+ wantMalformedIPPackets uint64
+ wantMalformedFragments uint64
+ }{
+ {
+ "ihl_totallen_zero_valid_frag_offset",
+ [][]byte{
+ {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 1,
+ 0,
+ },
+ {
+ "ihl_totallen_zero_invalid_frag_offset",
+ [][]byte{
+ {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 1,
+ 0,
+ },
+ {
+ // Total Length of 37 (20-byte IP header + 17 bytes of
+ // payload).
+ // Fragment offset field of 0x1ffe = 8190, i.e. 8190*8 = 65520
+ // bytes, which puts the end of the fragment past 65535.
+ "ihl_totallen_valid_invalid_frag_offset_1",
+ [][]byte{
+ {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 1,
+ 1,
+ },
+ // The following 3 tests were found by running a fuzzer and were
+ // triggering a panic in the IPv4 reassembler code.
+ {
+ "ihl_less_than_ipv4_minimum_size_1",
+ [][]byte{
+ {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 2,
+ 0,
+ },
+ {
+ "ihl_less_than_ipv4_minimum_size_2",
+ [][]byte{
+ {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 2,
+ 0,
+ },
+ {
+ "ihl_less_than_ipv4_minimum_size_3",
+ [][]byte{
+ {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 2,
+ 0,
+ },
+ {
+ "fragment_with_short_total_len_extra_payload",
+ [][]byte{
+ {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 1,
+ 1,
+ },
+ {
+ "multiple_fragments_with_more_fragments_set_to_false",
+ [][]byte{
+ {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ },
+ 1,
+ 1,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ const nicid tcpip.NICID = 42
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{
+ ipv4.NewProtocol(),
+ },
+ })
+
+ var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30})
+ var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31})
+ ep := channel.New(10, 1500, linkAddr)
+ s.CreateNIC(nicid, sniffer.New(ep))
+
+ for _, pkt := range tc.packets {
+ ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}))
+ }
+
+ if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want {
+ t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want)
+ }
+ if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want {
+ t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index fae7f4507..f06622a8b 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -23,7 +24,11 @@ go_library(
go_test(
name = "ipv6_test",
size = "small",
- srcs = ["icmp_test.go"],
+ srcs = [
+ "icmp_test.go",
+ "ipv6_test.go",
+ "ndp_test.go",
+ ],
embed = [":ipv6"],
deps = [
"//pkg/tcpip",
@@ -33,6 +38,7 @@ go_test(
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/udp",
"//pkg/waiter",
],
)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 5e6a59e91..f543ceb92 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -15,14 +15,21 @@
package ipv6
import (
- "encoding/binary"
-
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+const (
+ // ndpHopLimit is the expected IP hop limit value of 255 for received
+ // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1,
+ // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently
+ // drop the NDP packet. All outgoing NDP packets must use this value for
+ // their IP hop limit field.
+ ndpHopLimit = 255
+)
+
// handleControl handles the case when an ICMP packet contains the headers of
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
@@ -73,6 +80,21 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
}
h := header.ICMPv6(v)
+ // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and
+ // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field
+ // in the IPv6 header is not set to 255.
+ switch h.Type() {
+ case header.ICMPv6NeighborSolicit,
+ header.ICMPv6NeighborAdvert,
+ header.ICMPv6RouterSolicit,
+ header.ICMPv6RouterAdvert,
+ header.ICMPv6RedirectMsg:
+ if header.IPv6(netHeader).HopLimit() != ndpHopLimit {
+ received.Invalid.Increment()
+ return
+ }
+ }
+
// TODO(b/112892170): Meaningfully handle all ICMP types.
switch h.Type() {
case header.ICMPv6PacketTooBig:
@@ -82,7 +104,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
return
}
vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
- mtu := binary.BigEndian.Uint32(v[header.ICMPv6MinimumSize:])
+ mtu := h.MTU()
e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
case header.ICMPv6DstUnreachable:
@@ -99,19 +121,15 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
case header.ICMPv6NeighborSolicit:
received.NeighborSolicit.Increment()
-
- e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
-
if len(v) < header.ICMPv6NeighborSolicitMinimumSize {
received.Invalid.Increment()
return
}
- targetAddr := tcpip.Address(v[8:][:16])
+ targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize])
if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 {
// We don't have a useful answer; the best we can do is ignore the request.
return
}
-
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
@@ -132,9 +150,24 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
r := r.Clone()
defer r.Release()
r.LocalAddress = targetAddr
- pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ // TODO(tamird/ghanan): there exists an explicit NDP option that is
+ // used to update the neighbor table with link addresses for a
+ // neighbor from an NS (see the Source Link Layer option RFC
+ // 4861 section 4.6.1 and section 7.2.3).
+ //
+ // Furthermore, the entirety of NDP handling here seems to be
+ // contradicted by RFC 4861.
+ e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
- if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil {
+ // RFC 4861 Neighbor Discovery for IP version 6 (IPv6)
+ //
+ // 7.1.2. Validation of Neighbor Advertisements
+ //
+ // The IP Hop Limit field has a value of 255, i.e., the packet
+ // could not possibly have been forwarded by a router.
+ if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ndpHopLimit, TOS: stack.DefaultTOS}); err != nil {
sent.Dropped.Increment()
return
}
@@ -146,7 +179,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
- targetAddr := tcpip.Address(v[8:][:16])
+ targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize])
e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress)
if targetAddr != r.RemoteAddress {
e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
@@ -158,14 +191,13 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
-
vv.TrimFront(header.ICMPv6EchoMinimumSize)
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
copy(pkt, h)
pkt.SetType(header.ICMPv6EchoReply)
- pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, vv))
- if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil {
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv))
+ if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil {
sent.Dropped.Increment()
return
}
@@ -235,14 +267,14 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr
pkt[icmpV6LengthOffset] = 1
copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress())
- pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
length := uint16(hdr.UsedLength())
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: defaultIPv6HopLimit,
+ HopLimit: ndpHopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
@@ -274,24 +306,3 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
}
return "", false
}
-
-func icmpChecksum(h header.ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 {
- // Calculate the IPv6 pseudo-header upper-layer checksum.
- xsum := header.Checksum([]byte(src), 0)
- xsum = header.Checksum([]byte(dst), xsum)
- var upperLayerLength [4]byte
- binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size()))
- xsum = header.Checksum(upperLayerLength[:], xsum)
- xsum = header.Checksum([]byte{0, 0, 0, uint8(header.ICMPv6ProtocolNumber)}, xsum)
- for _, v := range vv.Views() {
- xsum = header.Checksum(v, xsum)
- }
-
- // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
- h2, h3 := h[2], h[3]
- h[2], h[3] = 0, 0
- xsum = ^header.Checksum(h, xsum)
- h[2], h[3] = h2, h3
-
- return xsum
-}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index d0dc72506..dd3c4d7c4 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -15,7 +15,6 @@
package ipv6
import (
- "fmt"
"reflect"
"strings"
"testing"
@@ -81,10 +80,12 @@ func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.Li
}
func TestICMPCounts(t *testing.T) {
- s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ })
{
- id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{})
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
}
if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
@@ -142,7 +143,7 @@ func TestICMPCounts(t *testing.T) {
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: r.DefaultTTL(),
+ HopLimit: ndpHopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
@@ -153,7 +154,7 @@ func TestICMPCounts(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size)
pkt := header.ICMPv6(hdr.Prepend(typ.size))
pkt.SetType(typ.typ)
- pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
handleIPv6Payload(hdr)
}
@@ -177,13 +178,10 @@ func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) {
t := v.Type()
for i := 0; i < v.NumField(); i++ {
v := v.Field(i)
- switch v.Kind() {
- case reflect.Ptr:
- f(t.Field(i).Name, v.Interface().(*tcpip.StatCounter))
- case reflect.Struct:
+ if s, ok := v.Interface().(*tcpip.StatCounter); ok {
+ f(t.Field(i).Name, s)
+ } else {
visitStats(v, f)
- default:
- panic(fmt.Sprintf("unexpected type %s", v.Type()))
}
}
}
@@ -206,41 +204,38 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab
func newTestContext(t *testing.T) *testContext {
c := &testContext{
- s0: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}),
- s1: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}),
+ s0: stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ }),
+ s1: stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ }),
}
const defaultMTU = 65536
- _, linkEP0 := channel.New(256, defaultMTU, linkAddr0)
- c.linkEP0 = linkEP0
- wrappedEP0 := endpointWithResolutionCapability{LinkEndpoint: linkEP0}
- id0 := stack.RegisterLinkEndpoint(wrappedEP0)
+ c.linkEP0 = channel.New(256, defaultMTU, linkAddr0)
+
+ wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0})
if testing.Verbose() {
- id0 = sniffer.New(id0)
+ wrappedEP0 = sniffer.New(wrappedEP0)
}
- if err := c.s0.CreateNIC(1, id0); err != nil {
+ if err := c.s0.CreateNIC(1, wrappedEP0); err != nil {
t.Fatalf("CreateNIC s0: %v", err)
}
if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress lladdr0: %v", err)
}
- if err := c.s0.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr0)); err != nil {
- t.Fatalf("AddAddress sn lladdr0: %v", err)
- }
- _, linkEP1 := channel.New(256, defaultMTU, linkAddr1)
- c.linkEP1 = linkEP1
- wrappedEP1 := endpointWithResolutionCapability{LinkEndpoint: linkEP1}
- id1 := stack.RegisterLinkEndpoint(wrappedEP1)
- if err := c.s1.CreateNIC(1, id1); err != nil {
+ c.linkEP1 = channel.New(256, defaultMTU, linkAddr1)
+ wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
+ if err := c.s1.CreateNIC(1, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil {
t.Fatalf("AddAddress lladdr1: %v", err)
}
- if err := c.s1.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr1)); err != nil {
- t.Fatalf("AddAddress sn lladdr1: %v", err)
- }
subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
if err != nil {
@@ -321,7 +316,7 @@ func TestLinkResolution(t *testing.T) {
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
pkt.SetType(header.ICMPv6EchoRequest)
- pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
payload := tcpip.SlicePayload(hdr.View())
// We can't send our payload directly over the route because that
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 331a8bdaa..cd1e34085 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -14,13 +14,15 @@
// Package ipv6 contains the implementation of the ipv6 network protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the
-// network protocols when calling stack.New(). Then endpoints can be created
-// by passing ipv6.ProtocolNumber as the network protocol number when calling
+// activated on the stack by passing ipv6.NewProtocol() as one of the network
+// protocols when calling stack.New(). Then endpoints can be created by passing
+// ipv6.ProtocolNumber as the network protocol number when calling
// Stack.NewEndpoint().
package ipv6
import (
+ "sync/atomic"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -28,9 +30,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ipv6 protocol name.
- ProtocolName = "ipv6"
-
// ProtocolNumber is the ipv6 protocol number.
ProtocolNumber = header.IPv6ProtocolNumber
@@ -38,9 +37,9 @@ const (
// PayloadLength field of the ipv6 header.
maxPayloadSize = 0xffff
- // defaultIPv6HopLimit is the default hop limit for IPv6 Packets
- // egressed by Netstack.
- defaultIPv6HopLimit = 255
+ // DefaultTTL is the default hop limit for IPv6 Packets egressed by
+ // Netstack.
+ DefaultTTL = 64
)
type endpoint struct {
@@ -50,11 +49,12 @@ type endpoint struct {
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
dispatcher stack.TransportDispatcher
+ protocol *protocol
}
// DefaultTTL is the default hop limit for this endpoint.
func (e *endpoint) DefaultTTL() uint8 {
- return 255
+ return e.protocol.DefaultTTL()
}
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
@@ -98,13 +98,14 @@ func (e *endpoint) GSOMaxSize() uint32 {
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
length := uint16(hdr.UsedLength() + payload.Size())
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
- NextHeader: uint8(protocol),
- HopLimit: ttl,
+ NextHeader: uint8(params.Protocol),
+ HopLimit: params.TTL,
+ TrafficClass: params.TOS,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
@@ -158,14 +159,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
// Close cleans up resources associated with the endpoint.
func (*endpoint) Close() {}
-type protocol struct{}
-
-// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported
-// only for tests that short-circuit the stack. Regular use of the protocol is
-// done via the stack, which gets a protocol descriptor from the init() function
-// below.
-func NewProtocol() stack.NetworkProtocol {
- return &protocol{}
+type protocol struct {
+ // defaultTTL is the current default TTL for the protocol. Only the
+ // uint8 portion of it is meaningful and it must be accessed
+ // atomically.
+ defaultTTL uint32
}
// Number returns the ipv6 protocol number.
@@ -198,17 +196,40 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi
linkEP: linkEP,
linkAddrCache: linkAddrCache,
dispatcher: dispatcher,
+ protocol: p,
}, nil
}
// SetOption implements NetworkProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch v := option.(type) {
+ case tcpip.DefaultTTLOption:
+ p.SetDefaultTTL(uint8(v))
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// Option implements NetworkProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch v := option.(type) {
+ case *tcpip.DefaultTTLOption:
+ *v = tcpip.DefaultTTLOption(p.DefaultTTL())
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// SetDefaultTTL sets the default TTL for endpoints created with this protocol.
+func (p *protocol) SetDefaultTTL(ttl uint8) {
+ atomic.StoreUint32(&p.defaultTTL, uint32(ttl))
+}
+
+// DefaultTTL returns the default TTL for endpoints created with this protocol.
+func (p *protocol) DefaultTTL() uint8 {
+ return uint8(atomic.LoadUint32(&p.defaultTTL))
}
// calculateMTU calculates the network-layer payload MTU based on the link-layer
@@ -221,8 +242,7 @@ func calculateMTU(mtu uint32) uint32 {
return maxPayloadSize
}
-func init() {
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+// NewProtocol returns an IPv6 network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{defaultTTL: DefaultTTL}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
new file mode 100644
index 000000000..deaa9b7f3
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -0,0 +1,266 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ // The least significant 3 bytes are the same as addr2 so both addr2 and
+ // addr3 will have the same solicited-node address.
+ addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02"
+)
+
+// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the
+// expected Neighbor Advertisement received count after receiving the packet.
+func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
+ t.Helper()
+
+ // Receive ICMP packet.
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.Inject(ProtocolNumber, hdr.View().ToVectorisedView())
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+
+ if got := stats.NeighborAdvert.Value(); got != want {
+ t.Fatalf("got NeighborAdvert = %d, want = %d", got, want)
+ }
+}
+
+// testReceiveUDP tests receiving a UDP packet from src to dst. want is the
+// expected UDP received count after receiving the packet.
+func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
+ t.Helper()
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep.Close()
+
+ if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil {
+ t.Fatalf("ep.Bind(...) failed: %v", err)
+ }
+
+ // Receive UDP Packet.
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
+ u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: header.UDPMinimumSize,
+ })
+
+ // UDP pseudo-header checksum.
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize)
+
+ // UDP checksum
+ sum = header.Checksum(header.UDP([]byte{}), sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.Inject(ProtocolNumber, hdr.View().ToVectorisedView())
+
+ stat := s.Stats().UDP.PacketsReceived
+
+ if got := stat.Value(); got != want {
+ t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want)
+ }
+}
+
+// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and
+// UDP packets destined to the IPv6 link-local all-nodes multicast address.
+func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ protocolFactory stack.TransportProtocol
+ rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
+ }{
+ {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
+ {"UDP", udp.NewProtocol(), testReceiveUDP},
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ })
+ e := channel.New(10, 1280, linkAddr1)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ // Should receive a packet destined to the all-nodes
+ // multicast address.
+ test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1)
+ })
+ }
+}
+
+// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP
+// packets destined to the IPv6 solicited-node address of an assigned IPv6
+// address.
+func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ protocolFactory stack.TransportProtocol
+ rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
+ }{
+ {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
+ {"UDP", udp.NewProtocol(), testReceiveUDP},
+ }
+
+ snmc := header.SolicitedNodeAddr(addr2)
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ })
+ e := channel.New(10, 1280, linkAddr1)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ // Should not receive a packet destined to the solicited
+ // node address of addr2/addr3 yet as we haven't added
+ // those addresses.
+ test.rxf(t, s, e, addr1, snmc, 0)
+
+ if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err)
+ }
+
+ // Should receive a packet destined to the solicited
+ // node address of addr2/addr3 now that we have added
+ // addr2.
+ test.rxf(t, s, e, addr1, snmc, 1)
+
+ if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err)
+ }
+
+ // Should still receive a packet destined to the
+ // solicited node address of addr2/addr3 now that we
+ // have added addr3.
+ test.rxf(t, s, e, addr1, snmc, 2)
+
+ if err := s.RemoveAddress(1, addr2); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err)
+ }
+
+ // Should still receive a packet destined to the
+ // solicited node address of addr2/addr3 now that we
+ // have removed addr2.
+ test.rxf(t, s, e, addr1, snmc, 3)
+
+ if err := s.RemoveAddress(1, addr3); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err)
+ }
+
+ // Should no longer receive a packet destined to the
+ // solicited node address of addr2/addr3 now that both
+ // of them have been removed.
+ test.rxf(t, s, e, addr1, snmc, 3)
+ })
+ }
+}
+
+// TestAddIpv6Address tests that IPv6 addresses can be added to a NIC and that
+// the added address is then reported back by GetMainNICAddress.
+func TestAddIpv6Address(t *testing.T) {
+	tests := []struct {
+		name string
+		addr tcpip.Address
+	}{
+		// This test is in response to b/140943433.
+		{
+			"Nil",
+			tcpip.Address([]byte(nil)),
+		},
+		{
+			"ValidUnicast",
+			addr1,
+		},
+		{
+			"ValidLinkLocalUnicast",
+			lladdr0,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			s := stack.New(stack.Options{
+				NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+			})
+			if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
+				t.Fatalf("CreateNIC(_) = %s", err)
+			}
+
+			// Report the address actually being added; the table covers more
+			// than the nil address.
+			if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil {
+				t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, test.addr, err)
+			}
+
+			addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+			if err != nil {
+				t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+			}
+			if addr.Address != test.addr {
+				t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr)
+			}
+		})
+	}
+}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
new file mode 100644
index 000000000..e30791fe3
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -0,0 +1,181 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+)
+
+// setupStackAndEndpoint creates a stack with a single NIC (ID 1) that has the
+// link-local address llladdr assigned, installs a single route covering
+// exactly rlladdr via that NIC, and creates an IPv6 endpoint for the remote
+// link-local address rlladdr.
+//
+// It returns the stack and the endpoint. Any setup failure is fatal to t.
+func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) {
+	t.Helper()
+
+	s := stack.New(stack.Options{
+		NetworkProtocols:   []stack.NetworkProtocol{NewProtocol()},
+		TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+	})
+
+	if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
+		t.Fatalf("CreateNIC(_) = %s", err)
+	}
+	if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil {
+		t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err)
+	}
+
+	{
+		// An all-ones mask of the address length makes the subnet match
+		// rlladdr only.
+		subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr))))
+		if err != nil {
+			t.Fatal(err)
+		}
+		s.SetRouteTable(
+			[]tcpip.Route{{
+				Destination: subnet,
+				NIC:         1,
+			}},
+		)
+	}
+
+	netProto := s.NetworkProtocolInstance(ProtocolNumber)
+	if netProto == nil {
+		t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
+	}
+
+	// Keyed fields keep the literal valid if AddressWithPrefix gains or
+	// reorders fields.
+	addrWithPrefix := tcpip.AddressWithPrefix{
+		Address:   rlladdr,
+		PrefixLen: netProto.DefaultPrefixLen(),
+	}
+	ep, err := netProto.NewEndpoint(0, addrWithPrefix, &stubLinkAddressCache{}, &stubDispatcher{}, nil)
+	if err != nil {
+		t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
+	}
+
+	return s, ep
+}
+
+// TestHopLimitValidation is a test that makes sure that NDP packets are only
+// received if their IP header's hop limit is set to 255.
+func TestHopLimitValidation(t *testing.T) {
+ setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
+ t.Helper()
+
+ // Create a stack with the assigned link-local address lladdr0
+ // and an endpoint to lladdr1.
+ s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1)
+
+ r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ }
+
+ return s, ep, r
+ }
+
+ handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, ep stack.NetworkEndpoint, r *stack.Route) {
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: hopLimit,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ ep.HandlePacket(r, hdr.View().ToVectorisedView())
+ }
+
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ }{
+ {"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterSolicit
+ }},
+ {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterAdvert
+ }},
+ {"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborSolicit
+ }},
+ {"NeighborAdvert", header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborAdvert
+ }},
+ {"RedirectMsg", header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RedirectMsg
+ }},
+ }
+
+ for _, typ := range types {
+ t.Run(typ.name, func(t *testing.T) {
+ s, ep, r := setup(t)
+ defer r.Release()
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
+
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size)
+ pkt := header.ICMPv6(hdr.Prepend(typ.size))
+ pkt.SetType(typ.typ)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ // Should not have received any ICMPv6 packets with
+ // type = typ.typ.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Receive the NDP packet with an invalid hop limit
+ // value.
+ handleIPv6Payload(hdr, ndpHopLimit-1, ep, &r)
+
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+
+ // Rx count of NDP packet of type typ.typ should not
+ // have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Receive the NDP packet with a valid hop limit value.
+ handleIPv6Payload(hdr, ndpHopLimit, ep, &r)
+
+ // Rx count of NDP packet of type typ.typ should have
+ // increased.
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index 989058413..11efb4e44 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 315780c0c..30cea8996 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -19,6 +19,7 @@ import (
"math"
"math/rand"
"sync"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -27,6 +28,10 @@ const (
// FirstEphemeral is the first ephemeral port.
FirstEphemeral = 16000
+ // numEphemeralPorts is the number of ephemeral ports available to
+ // Netstack.
+ numEphemeralPorts = math.MaxUint16 - FirstEphemeral + 1
+
anyIPAddress tcpip.Address = ""
)
@@ -40,6 +45,13 @@ type portDescriptor struct {
type PortManager struct {
mu sync.RWMutex
allocatedPorts map[portDescriptor]bindAddresses
+
+ // hint is used to pick ephemeral ports in a stable order for
+ // a given port offset.
+ //
+ // hint must be accessed using the portHint/incPortHint helpers.
+ // TODO(gvisor.dev/issue/940): S/R this field.
+ hint uint32
}
type portNode struct {
@@ -47,43 +59,76 @@ type portNode struct {
refs int
}
-// bindAddresses is a set of IP addresses.
-type bindAddresses map[tcpip.Address]portNode
+// deviceNode is never empty. When it has no elements, it is removed from the
+// map that references it.
+type deviceNode map[tcpip.NICID]portNode
-// isAvailable checks whether an IP address is available to bind to.
-func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool) bool {
- if addr == anyIPAddress {
- if len(b) == 0 {
- return true
- }
+// isAvailable checks whether binding is possible by device. If not binding to a
+// device, check against all portNodes. If binding to a specific device, check
+// against the unspecified device and the provided device.
+func (d deviceNode) isAvailable(reuse bool, bindToDevice tcpip.NICID) bool {
+ if bindToDevice == 0 {
+ // Trying to bind to all devices.
if !reuse {
+ // Can't bind because the (addr,port) is already bound.
return false
}
- for _, n := range b {
- if !n.reuse {
+ for _, p := range d {
+ if !p.reuse {
+ // Can't bind because the (addr,port) was previously bound without reuse.
return false
}
}
return true
}
- // If all addresses for this portDescriptor are already bound, no
- // address is available.
- if n, ok := b[anyIPAddress]; ok {
- if !reuse {
+ if p, ok := d[0]; ok {
+ if !reuse || !p.reuse {
return false
}
- if !n.reuse {
+ }
+
+ if p, ok := d[bindToDevice]; ok {
+ if !reuse || !p.reuse {
return false
}
}
- if n, ok := b[addr]; ok {
- if !reuse {
+ return true
+}
+
+// bindAddresses is a set of IP addresses.
+type bindAddresses map[tcpip.Address]deviceNode
+
+// isAvailable checks whether an IP address is available to bind to. If the
+// address is the "any" address, check all other addresses. Otherwise, just
+// check against the "any" address and the provided address.
+func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice tcpip.NICID) bool {
+ if addr == anyIPAddress {
+ // If binding to the "any" address then check that there are no conflicts
+ // with all addresses.
+ for _, d := range b {
+ if !d.isAvailable(reuse, bindToDevice) {
+ return false
+ }
+ }
+ return true
+ }
+
+ // Check that there is no conflict with the "any" address.
+ if d, ok := b[anyIPAddress]; ok {
+ if !d.isAvailable(reuse, bindToDevice) {
return false
}
- return n.reuse
}
+
+ // Check that there is no conflict with the provided address.
+ if d, ok := b[addr]; ok {
+ if !d.isAvailable(reuse, bindToDevice) {
+ return false
+ }
+ }
+
return true
}
@@ -97,11 +142,40 @@ func NewPortManager() *PortManager {
// is suitable for its needs, and stopping when a port is found or an error
// occurs.
func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
- count := uint16(math.MaxUint16 - FirstEphemeral + 1)
- offset := uint16(rand.Int31n(int32(count)))
+ offset := uint32(rand.Int31n(numEphemeralPorts))
+ return s.pickEphemeralPort(offset, numEphemeralPorts, testPort)
+}
+
+// portHint atomically reads and returns the s.hint value.
+func (s *PortManager) portHint() uint32 {
+ return atomic.LoadUint32(&s.hint)
+}
+
+// incPortHint atomically increments s.hint by 1.
+func (s *PortManager) incPortHint() {
+ atomic.AddUint32(&s.hint, 1)
+}
- for i := uint16(0); i < count; i++ {
- port = FirstEphemeral + (offset+i)%count
+// PickEphemeralPortStable walks the ephemeral port range beginning at
+// s.portHint() + offset, handing each candidate port to testPort until
+// testPort accepts one or reports an error. The port hint is advanced by one
+// only when the walk finishes without an error.
+func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+	start := s.portHint() + offset
+	port, err = s.pickEphemeralPort(start, numEphemeralPorts, testPort)
+	if err != nil {
+		return port, err
+	}
+	s.incPortHint()
+	return port, nil
+}
+
+// pickEphemeralPort starts at the offset specified from the FirstEphemeral port
+// and iterates over the number of ports specified by count and allows the
+// caller to decide whether a given port is suitable for its needs, and stopping
+// when a port is found or an error occurs.
+func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+ for i := uint32(0); i < count; i++ {
+ port = uint16(FirstEphemeral + (offset+i)%count)
ok, err := testPort(port)
if err != nil {
return 0, err
@@ -116,17 +190,17 @@ func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Er
}
// IsPortAvailable tests if the given port is available on all given protocols.
-func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
s.mu.Lock()
defer s.mu.Unlock()
- return s.isPortAvailableLocked(networks, transport, addr, port, reuse)
+ return s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice)
}
-func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if addrs, ok := s.allocatedPorts[desc]; ok {
- if !addrs.isAvailable(addr, reuse) {
+ if !addrs.isAvailable(addr, reuse, bindToDevice) {
return false
}
}
@@ -138,14 +212,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
// "port" return value.
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) (reservedPort uint16, err *tcpip.Error) {
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
// If a port is specified, just try to reserve it for all network
// protocols.
if port != 0 {
- if !s.reserveSpecificPort(networks, transport, addr, port, reuse) {
+ if !s.reserveSpecificPort(networks, transport, addr, port, reuse, bindToDevice) {
return 0, tcpip.ErrPortInUse
}
return port, nil
@@ -153,13 +227,13 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp
// A port wasn't specified, so try to find one.
return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
- return s.reserveSpecificPort(networks, transport, addr, p, reuse), nil
+ return s.reserveSpecificPort(networks, transport, addr, p, reuse, bindToDevice), nil
})
}
// reserveSpecificPort tries to reserve the given port on all given protocols.
-func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
- if !s.isPortAvailableLocked(networks, transport, addr, port, reuse) {
+func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
+ if !s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) {
return false
}
@@ -171,11 +245,16 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
m = make(bindAddresses)
s.allocatedPorts[desc] = m
}
- if n, ok := m[addr]; ok {
+ d, ok := m[addr]
+ if !ok {
+ d = make(deviceNode)
+ m[addr] = d
+ }
+ if n, ok := d[bindToDevice]; ok {
n.refs++
- m[addr] = n
+ d[bindToDevice] = n
} else {
- m[addr] = portNode{reuse: reuse, refs: 1}
+ d[bindToDevice] = portNode{reuse: reuse, refs: 1}
}
}
@@ -184,22 +263,28 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
// ReleasePort releases the reservation on a port/IP combination so that it can
// be reserved by other endpoints.
-func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) {
+func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, bindToDevice tcpip.NICID) {
s.mu.Lock()
defer s.mu.Unlock()
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if m, ok := s.allocatedPorts[desc]; ok {
- n, ok := m[addr]
+ d, ok := m[addr]
+ if !ok {
+ continue
+ }
+ n, ok := d[bindToDevice]
if !ok {
continue
}
n.refs--
+ d[bindToDevice] = n
if n.refs == 0 {
+ delete(d, bindToDevice)
+ }
+ if len(d) == 0 {
delete(m, addr)
- } else {
- m[addr] = n
}
if len(m) == 0 {
delete(s.allocatedPorts, desc)
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 689401661..19f4833fc 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -15,6 +15,7 @@
package ports
import (
+ "math/rand"
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -34,6 +35,7 @@ type portReserveTestAction struct {
want *tcpip.Error
reuse bool
release bool
+ device tcpip.NICID
}
func TestPortReservation(t *testing.T) {
@@ -100,6 +102,112 @@ func TestPortReservation(t *testing.T) {
{port: 24, ip: anyIPAddress, release: true},
{port: 24, ip: anyIPAddress, reuse: false, want: nil},
},
+ }, {
+ tname: "bind twice with device fails",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 3, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind to device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 1, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 2, want: nil},
+ },
+ }, {
+ tname: "bind to device and then without device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind without device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, want: nil},
+ {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuse",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ },
+ }, {
+ tname: "binding with reuse and device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "mixing reuse and not reuse by binding to device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 999, want: nil},
+ },
+ }, {
+ tname: "can't bind to 0 after mixing reuse and not reuse",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind and release",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+
+ // Release the bind to device 0 and try again.
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil},
+ },
+ }, {
+ tname: "bind twice with reuse once",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "release an unreserved device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil},
+ // The below don't exist.
+ {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil, release: true},
+ {port: 9999, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true},
+ // Release all.
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil, release: true},
+ },
},
} {
t.Run(test.tname, func(t *testing.T) {
@@ -108,12 +216,12 @@ func TestPortReservation(t *testing.T) {
for _, test := range test.actions {
if test.release {
- pm.ReleasePort(net, fakeTransNumber, test.ip, test.port)
+ pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.device)
continue
}
- gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse)
+ gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse, test.device)
if err != test.want {
- t.Fatalf("ReservePort(.., .., %s, %d, %t) = %v, want %v", test.ip, test.port, test.release, err, test.want)
+ t.Fatalf("ReservePort(.., .., %s, %d, %t, %d) = %v, want %v", test.ip, test.port, test.reuse, test.device, err, test.want)
}
if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
@@ -125,7 +233,6 @@ func TestPortReservation(t *testing.T) {
}
func TestPickEphemeralPort(t *testing.T) {
- pm := NewPortManager()
customErr := &tcpip.Error{}
for _, test := range []struct {
name string
@@ -169,9 +276,63 @@ func TestPickEphemeralPort(t *testing.T) {
},
} {
t.Run(test.name, func(t *testing.T) {
+ pm := NewPortManager()
if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr {
t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr)
}
})
}
}
+
+// TestPickEphemeralPortStable exercises PortManager.PickEphemeralPortStable
+// with port testers that accept no port, return an error, accept exactly one
+// ephemeral port, or accept only ports below FirstEphemeral. Each case passes
+// a random offset in [0, numEphemeralPorts).
+func TestPickEphemeralPortStable(t *testing.T) {
+ customErr := &tcpip.Error{}
+ for _, test := range []struct {
+ name string
+ f func(port uint16) (bool, *tcpip.Error)
+ wantErr *tcpip.Error
+ wantPort uint16
+ }{
+ {
+ name: "no-port-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ return false, nil
+ },
+ wantErr: tcpip.ErrNoPortAvailable,
+ },
+ {
+ name: "port-tester-error",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ return false, customErr
+ },
+ wantErr: customErr,
+ },
+ {
+ name: "only-port-16042-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ if port == FirstEphemeral+42 {
+ return true, nil
+ }
+ return false, nil
+ },
+ wantPort: FirstEphemeral + 42,
+ },
+ {
+ name: "only-port-under-16000-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ if port < FirstEphemeral {
+ return true, nil
+ }
+ return false, nil
+ },
+ wantErr: tcpip.ErrNoPortAvailable,
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ pm := NewPortManager()
+ // The offset is random, so it is reported on failure to make a
+ // failing run reproducible.
+ portOffset := uint32(rand.Int31n(int32(numEphemeralPorts)))
+ if port, err := pm.PickEphemeralPortStable(portOffset, test.f); port != test.wantPort || err != test.wantErr {
+ t.Errorf("PickEphemeralPortStable(%d, ..) = (port %d, err %v); want (port %d, err %v)", portOffset, port, err, test.wantPort, test.wantErr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index e2021cd15..2239c1e66 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -126,7 +126,10 @@ func main() {
// Create the stack with ipv4 and tcp protocols, then add a tun-based
// NIC and ipv4 address.
- s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
mtu, err := rawfile.GetMTU(tunName)
if err != nil {
@@ -138,11 +141,11 @@ func main() {
log.Fatal(err)
}
- linkID, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu})
+ linkEP, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu})
if err != nil {
log.Fatal(err)
}
- if err := s.CreateNIC(1, sniffer.New(linkID)); err != nil {
+ if err := s.CreateNIC(1, sniffer.New(linkEP)); err != nil {
log.Fatal(err)
}
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index 1716be285..bca73cbb1 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -111,7 +111,10 @@ func main() {
// Create the stack with ip and tcp protocols, then add a tun-based
// NIC and address.
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName, arp.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
mtu, err := rawfile.GetMTU(tunName)
if err != nil {
@@ -128,7 +131,7 @@ func main() {
log.Fatal(err)
}
- linkID, err := fdbased.New(&fdbased.Options{
+ linkEP, err := fdbased.New(&fdbased.Options{
FDs: []int{fd},
MTU: mtu,
EthernetHeader: *tap,
@@ -137,7 +140,7 @@ func main() {
if err != nil {
log.Fatal(err)
}
- if err := s.CreateNIC(1, linkID); err != nil {
+ if err := s.CreateNIC(1, linkEP); err != nil {
log.Fatal(err)
}
diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD
index 76b5f4ffa..29b7d761c 100644
--- a/pkg/tcpip/seqnum/BUILD
+++ b/pkg/tcpip/seqnum/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "seqnum",
srcs = ["seqnum.go"],
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 9986b4be3..6a78432c9 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -1,11 +1,27 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+load("//tools/go_stateify:defs.bzl", "go_library")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+# Instantiates the generic list template from //pkg/ilist for *linkAddrEntry,
+# generating linkaddrentry_list.go in package stack with the "linkAddrEntry"
+# name prefix.
+go_template_instance(
+ name = "linkaddrentry_list",
+ out = "linkaddrentry_list.go",
+ package = "stack",
+ prefix = "linkAddrEntry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*linkAddrEntry",
+ "Linker": "*linkAddrEntry",
+ },
+)
go_library(
name = "stack",
srcs = [
+ "icmp_rate_limit.go",
"linkaddrcache.go",
+ "linkaddrentry_list.go",
"nic.go",
"registration.go",
"route.go",
@@ -19,6 +35,7 @@ go_library(
],
deps = [
"//pkg/ilist",
+ "//pkg/rand",
"//pkg/sleep",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
@@ -28,6 +45,7 @@ go_library(
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/waiter",
+ "@org_golang_x_time//rate:go_default_library",
],
)
@@ -36,6 +54,7 @@ go_test(
size = "small",
srcs = [
"stack_test.go",
+ "transport_demuxer_test.go",
"transport_test.go",
],
deps = [
@@ -46,6 +65,9 @@ go_test(
"//pkg/tcpip/iptables",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/transport/udp",
"//pkg/waiter",
],
)
@@ -60,3 +82,11 @@ go_test(
"//pkg/tcpip",
],
)
+
+# Groups the generated source linkaddrentry_list.go under the name "autogen",
+# with visibility restricted to //:sandbox.
+filegroup(
+ name = "autogen",
+ srcs = [
+ "linkaddrentry_list.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go
new file mode 100644
index 000000000..3a20839da
--- /dev/null
+++ b/pkg/tcpip/stack/icmp_rate_limit.go
@@ -0,0 +1,41 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "golang.org/x/time/rate"
+)
+
+const (
+ // icmpLimit is the default maximum number of ICMP messages permitted per
+ // second by this rate limiter. It is passed to rate.NewLimiter as the
+ // rate.Limit, which is expressed in events per second.
+ icmpLimit = 1000
+
+ // icmpBurst is the default number of ICMP messages that can be sent in a single
+ // burst. It is passed to rate.NewLimiter as the burst size.
+ icmpBurst = 50
+)
+
+// ICMPRateLimiter is a global rate limiter that controls the generation of
+// ICMP messages by the stack. It embeds *rate.Limiter, so the limiter's
+// methods are available directly on an ICMPRateLimiter.
+type ICMPRateLimiter struct {
+ *rate.Limiter
+}
+
+// NewICMPRateLimiter returns a global rate limiter for controlling the rate
+// at which ICMP messages are generated by the stack. The limiter is created
+// with the package defaults icmpLimit and icmpBurst.
+func NewICMPRateLimiter() *ICMPRateLimiter {
+ return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)}
+}
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 77bb0ccb9..267df60d1 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -42,10 +42,11 @@ type linkAddrCache struct {
// resolved before failing.
resolutionAttempts int
- mu sync.Mutex
- cache map[tcpip.FullAddress]*linkAddrEntry
- next int // array index of next available entry
- entries [linkAddrCacheSize]linkAddrEntry
+ cache struct {
+ sync.Mutex
+ table map[tcpip.FullAddress]*linkAddrEntry
+ lru linkAddrEntryList
+ }
}
// entryState controls the state of a single entry in the cache.
@@ -60,9 +61,6 @@ const (
// failed means that address resolution timed out and the address
// could not be resolved.
failed
- // expired means that the cache entry has expired and the address must be
- // resolved again.
- expired
)
// String implements Stringer.
@@ -74,8 +72,6 @@ func (s entryState) String() string {
return "ready"
case failed:
return "failed"
- case expired:
- return "expired"
default:
return fmt.Sprintf("unknown(%d)", s)
}
@@ -84,64 +80,46 @@ func (s entryState) String() string {
// A linkAddrEntry is an entry in the linkAddrCache.
// This struct is thread-compatible.
type linkAddrEntry struct {
+ linkAddrEntryEntry
+
addr tcpip.FullAddress
linkAddr tcpip.LinkAddress
expiration time.Time
s entryState
// wakers is a set of waiters for address resolution result. Anytime
- // state transitions out of 'incomplete' these waiters are notified.
+ // state transitions out of incomplete these waiters are notified.
wakers map[*sleep.Waker]struct{}
+ // done is used to allow callers to wait on address resolution. It is nil iff
+ // s is incomplete and resolution is not yet in progress.
done chan struct{}
}
-func (e *linkAddrEntry) state() entryState {
- if e.s != expired && time.Now().After(e.expiration) {
- // Force the transition to ensure waiters are notified.
- e.changeState(expired)
- }
- return e.s
-}
-
-func (e *linkAddrEntry) changeState(ns entryState) {
- if e.s == ns {
- return
- }
-
- // Validate state transition.
- switch e.s {
- case incomplete:
- // All transitions are valid.
- case ready, failed:
- if ns != expired {
- panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns))
- }
- case expired:
- // Terminal state.
- panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns))
- default:
- panic(fmt.Sprintf("invalid state: %s", e.s))
- }
-
+// changeState sets the entry's state to ns, notifying any waiters.
+//
+// The entry's expiration is bumped up to the greater of itself and the passed
+// expiration; the zero value indicates immediate expiration, and is set
+// unconditionally - this is an implementation detail that allows for entries
+// to be reused.
+func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) {
// Notify whoever is waiting on address resolution when transitioning
- // out of 'incomplete'.
- if e.s == incomplete {
+ // out of incomplete.
+ if e.s == incomplete && ns != incomplete {
for w := range e.wakers {
w.Assert()
}
e.wakers = nil
- if e.done != nil {
- close(e.done)
+ if ch := e.done; ch != nil {
+ close(ch)
}
+ e.done = nil
}
- e.s = ns
-}
-func (e *linkAddrEntry) maybeAddWaker(w *sleep.Waker) {
- if w != nil {
- e.wakers[w] = struct{}{}
+ if expiration.IsZero() || expiration.After(e.expiration) {
+ e.expiration = expiration
}
+ e.s = ns
}
func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
@@ -150,53 +128,54 @@ func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
// add adds a k -> v mapping to the cache.
func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
- c.mu.Lock()
- defer c.mu.Unlock()
-
- entry, ok := c.cache[k]
- if ok {
- s := entry.state()
- if s != expired && entry.linkAddr == v {
- // Disregard repeated calls.
- return
- }
- // Check if entry is waiting for address resolution.
- if s == incomplete {
- entry.linkAddr = v
- } else {
- // Otherwise create a new entry to replace it.
- entry = c.makeAndAddEntry(k, v)
- }
- } else {
- entry = c.makeAndAddEntry(k, v)
- }
+ // Calculate expiration time before acquiring the lock, since expiration is
+ // relative to the time when information was learned, rather than when it
+ // happened to be inserted into the cache.
+ expiration := time.Now().Add(c.ageLimit)
- entry.changeState(ready)
+ c.cache.Lock()
+ entry := c.getOrCreateEntryLocked(k)
+ entry.linkAddr = v
+
+ entry.changeState(ready, expiration)
+ c.cache.Unlock()
}
-// makeAndAddEntry is a helper function to create and add a new
-// entry to the cache map and evict older entry as needed.
-func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry {
- // Take over the next entry.
- entry := &c.entries[c.next]
- if c.cache[entry.addr] == entry {
- delete(c.cache, entry.addr)
+// getOrCreateEntryLocked retrieves a cache entry associated with k. The
+// returned entry is always refreshed in the cache (it is reachable via the
+// map, and its place is bumped in LRU).
+//
+// If a matching entry exists in the cache, it is returned. If no matching
+// entry exists and the cache is full, an existing entry is evicted via LRU,
+// reset to state incomplete, and returned. If no matching entry exists and the
+// cache is not full, a new entry with state incomplete is allocated and
+// returned.
+func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry {
+ if entry, ok := c.cache.table[k]; ok {
+ c.cache.lru.Remove(entry)
+ c.cache.lru.PushFront(entry)
+ return entry
}
+ var entry *linkAddrEntry
+ if len(c.cache.table) == linkAddrCacheSize {
+ entry = c.cache.lru.Back()
- // Mark the soon-to-be-replaced entry as expired, just in case there is
- // someone waiting for address resolution on it.
- entry.changeState(expired)
+ delete(c.cache.table, entry.addr)
+ c.cache.lru.Remove(entry)
- *entry = linkAddrEntry{
- addr: k,
- linkAddr: v,
- expiration: time.Now().Add(c.ageLimit),
- wakers: make(map[*sleep.Waker]struct{}),
- done: make(chan struct{}),
+ // Wake waiters and mark the soon-to-be-reused entry as expired. Note
+ // that the state passed doesn't matter when the zero time is passed.
+ entry.changeState(failed, time.Time{})
+ } else {
+ entry = new(linkAddrEntry)
}
- c.cache[k] = entry
- c.next = (c.next + 1) % len(c.entries)
+ *entry = linkAddrEntry{
+ addr: k,
+ s: incomplete,
+ }
+ c.cache.table[k] = entry
+ c.cache.lru.PushFront(entry)
return entry
}
@@ -208,43 +187,55 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo
}
}
- c.mu.Lock()
- defer c.mu.Unlock()
- if entry, ok := c.cache[k]; ok {
- switch s := entry.state(); s {
- case expired:
- case ready:
- return entry.linkAddr, nil, nil
- case failed:
- return "", nil, tcpip.ErrNoLinkAddress
- case incomplete:
- // Address resolution is still in progress.
- entry.maybeAddWaker(waker)
- return "", entry.done, tcpip.ErrWouldBlock
- default:
- panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ c.cache.Lock()
+ defer c.cache.Unlock()
+ entry := c.getOrCreateEntryLocked(k)
+ switch s := entry.s; s {
+ case ready, failed:
+ if !time.Now().After(entry.expiration) {
+ // Not expired.
+ switch s {
+ case ready:
+ return entry.linkAddr, nil, nil
+ case failed:
+ return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ }
}
- }
- if linkRes == nil {
- return "", nil, tcpip.ErrNoLinkAddress
- }
+ entry.changeState(incomplete, time.Time{})
+ fallthrough
+ case incomplete:
+ if waker != nil {
+ if entry.wakers == nil {
+ entry.wakers = make(map[*sleep.Waker]struct{})
+ }
+ entry.wakers[waker] = struct{}{}
+ }
- // Add 'incomplete' entry in the cache to mark that resolution is in progress.
- e := c.makeAndAddEntry(k, "")
- e.maybeAddWaker(waker)
+ if entry.done == nil {
+ // Address resolution needs to be initiated.
+ if linkRes == nil {
+ return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
+ }
- go c.startAddressResolution(k, linkRes, localAddr, linkEP, e.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+ entry.done = make(chan struct{})
+ go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+ }
- return "", e.done, tcpip.ErrWouldBlock
+ return entry.linkAddr, entry.done, tcpip.ErrWouldBlock
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ }
}
// removeWaker removes a waker previously added through get().
func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
- c.mu.Lock()
- defer c.mu.Unlock()
+ c.cache.Lock()
+ defer c.cache.Unlock()
- if entry, ok := c.cache[k]; ok {
+ if entry, ok := c.cache.table[k]; ok {
entry.removeWaker(waker)
}
}
@@ -256,8 +247,8 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
select {
- case <-time.After(c.resolutionTimeout):
- if stop := c.checkLinkRequest(k, i); stop {
+ case now := <-time.After(c.resolutionTimeout):
+ if stop := c.checkLinkRequest(now, k, i); stop {
return
}
case <-done:
@@ -269,38 +260,36 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
// checkLinkRequest checks whether the previous attempt to resolve the address has
// succeeded and marks the entry accordingly, e.g. ready, failed, etc. It returns true
// if the request can stop, false if another request should be sent.
-func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool {
- c.mu.Lock()
- defer c.mu.Unlock()
-
- entry, ok := c.cache[k]
+func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool {
+ c.cache.Lock()
+ defer c.cache.Unlock()
+ entry, ok := c.cache.table[k]
if !ok {
// Entry was evicted from the cache.
return true
}
-
- switch s := entry.state(); s {
- case ready, failed, expired:
+ switch s := entry.s; s {
+ case ready, failed:
// Entry was made ready by resolver or failed. Either way we're done.
- return true
case incomplete:
- if attempt+1 >= c.resolutionAttempts {
- // Max number of retries reached, mark entry as failed.
- entry.changeState(failed)
- return true
+ if attempt+1 < c.resolutionAttempts {
+ // No response yet, need to send another ARP request.
+ return false
}
- // No response yet, need to send another ARP request.
- return false
+ // Max number of retries reached, mark entry as failed.
+ entry.changeState(failed, now.Add(c.ageLimit))
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
+ return true
}
func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
- return &linkAddrCache{
+ c := &linkAddrCache{
ageLimit: ageLimit,
resolutionTimeout: resolutionTimeout,
resolutionAttempts: resolutionAttempts,
- cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize),
}
+ c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize)
+ return c
}
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 924f4d240..9946b8fe8 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -17,6 +17,7 @@ package stack
import (
"fmt"
"sync"
+ "sync/atomic"
"testing"
"time"
@@ -29,25 +30,34 @@ type testaddr struct {
linkAddr tcpip.LinkAddress
}
-var testaddrs []testaddr
+// testAddrs is the fixed set of 4*linkAddrCacheSize address/link-address pairs
+// (all on NIC 1) used by the tests in this file. It is built once, when the
+// package-level variable is initialized.
+var testAddrs = func() []testaddr {
+ var addrs []testaddr
+ for i := 0; i < 4*linkAddrCacheSize; i++ {
+ addr := fmt.Sprintf("Addr%06d", i)
+ addrs = append(addrs, testaddr{
+ addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
+ linkAddr: tcpip.LinkAddress("Link" + addr),
+ })
+ }
+ return addrs
+}()
type testLinkAddressResolver struct {
- cache *linkAddrCache
- delay time.Duration
+ cache *linkAddrCache
+ delay time.Duration
+ onLinkAddressRequest func()
}
func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
- go func() {
- if r.delay > 0 {
- time.Sleep(r.delay)
- }
- r.fakeRequest(addr)
- }()
+ time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
+ if f := r.onLinkAddressRequest; f != nil {
+ f()
+ }
return nil
}
func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) {
- for _, ta := range testaddrs {
+ for _, ta := range testAddrs {
if ta.addr.Addr == addr {
r.cache.add(ta.addr, ta.linkAddr)
break
@@ -80,20 +90,10 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe
}
}
-func init() {
- for i := 0; i < 4*linkAddrCacheSize; i++ {
- addr := fmt.Sprintf("Addr%06d", i)
- testaddrs = append(testaddrs, testaddr{
- addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
- linkAddr: tcpip.LinkAddress("Link" + addr),
- })
- }
-}
-
func TestCacheOverflow(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
- for i := len(testaddrs) - 1; i >= 0; i-- {
- e := testaddrs[i]
+ for i := len(testAddrs) - 1; i >= 0; i-- {
+ e := testAddrs[i]
c.add(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
@@ -105,7 +105,7 @@ func TestCacheOverflow(t *testing.T) {
}
// Expect to find at least half of the most recent entries.
for i := 0; i < linkAddrCacheSize/2; i++ {
- e := testaddrs[i]
+ e := testAddrs[i]
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
@@ -115,8 +115,8 @@ func TestCacheOverflow(t *testing.T) {
}
}
// The earliest entries should no longer be in the cache.
- for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- {
- e := testaddrs[i]
+ for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- {
+ e := testAddrs[i]
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
}
@@ -130,7 +130,7 @@ func TestCacheConcurrent(t *testing.T) {
for r := 0; r < 16; r++ {
wg.Add(1)
go func() {
- for _, e := range testaddrs {
+ for _, e := range testAddrs {
c.add(e.addr, e.linkAddr)
c.get(e.addr, nil, "", nil, nil) // make work for gotsan
}
@@ -142,7 +142,7 @@ func TestCacheConcurrent(t *testing.T) {
// All goroutines add in the same order and add more values than
// can fit in the cache, so our eviction strategy requires that
// the last entry be present and the first be missing.
- e := testaddrs[len(testaddrs)-1]
+ e := testAddrs[len(testAddrs)-1]
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
@@ -151,7 +151,7 @@ func TestCacheConcurrent(t *testing.T) {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
}
- e = testaddrs[0]
+ e = testAddrs[0]
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
@@ -159,7 +159,7 @@ func TestCacheConcurrent(t *testing.T) {
func TestCacheAgeLimit(t *testing.T) {
c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
- e := testaddrs[0]
+ e := testAddrs[0]
c.add(e.addr, e.linkAddr)
time.Sleep(50 * time.Millisecond)
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
@@ -169,7 +169,7 @@ func TestCacheAgeLimit(t *testing.T) {
func TestCacheReplace(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
- e := testaddrs[0]
+ e := testAddrs[0]
l2 := e.linkAddr + "2"
c.add(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
@@ -193,7 +193,7 @@ func TestCacheReplace(t *testing.T) {
func TestCacheResolution(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1)
linkRes := &testLinkAddressResolver{cache: c}
- for i, ta := range testaddrs {
+ for i, ta := range testAddrs {
got, err := getBlocking(c, ta.addr, linkRes)
if err != nil {
t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err)
@@ -205,7 +205,7 @@ func TestCacheResolution(t *testing.T) {
// Check that after resolved, address stays in the cache and never returns WouldBlock.
for i := 0; i < 10; i++ {
- e := testaddrs[len(testaddrs)-1]
+ e := testAddrs[len(testAddrs)-1]
got, _, err := c.get(e.addr, linkRes, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
@@ -220,8 +220,13 @@ func TestCacheResolutionFailed(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5)
linkRes := &testLinkAddressResolver{cache: c}
+ var requestCount uint32
+ linkRes.onLinkAddressRequest = func() {
+ atomic.AddUint32(&requestCount, 1)
+ }
+
// First, sanity check that resolution is working...
- e := testaddrs[0]
+ e := testAddrs[0]
got, err := getBlocking(c, e.addr, linkRes)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
@@ -230,10 +235,16 @@ func TestCacheResolutionFailed(t *testing.T) {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
}
+ before := atomic.LoadUint32(&requestCount)
+
e.addr.Addr += "2"
if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
+
+ if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want {
+ t.Errorf("got link address request count = %d, want = %d", got, want)
+ }
}
func TestCacheResolutionTimeout(t *testing.T) {
@@ -242,7 +253,7 @@ func TestCacheResolutionTimeout(t *testing.T) {
c := newLinkAddrCache(expiration, 1*time.Millisecond, 3)
linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
- e := testaddrs[0]
+ e := testAddrs[0]
if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 4ef85bdfb..f64bbf6eb 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -34,15 +34,13 @@ type NIC struct {
linkEP LinkEndpoint
loopback bool
- demux *transportDemuxer
-
- mu sync.RWMutex
- spoofing bool
- promiscuous bool
- primary map[tcpip.NetworkProtocolNumber]*ilist.List
- endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
- subnets []tcpip.Subnet
- mcastJoins map[NetworkEndpointID]int32
+ mu sync.RWMutex
+ spoofing bool
+ promiscuous bool
+ primary map[tcpip.NetworkProtocolNumber]*ilist.List
+ endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
+ addressRanges []tcpip.Subnet
+ mcastJoins map[NetworkEndpointID]int32
stats NICStats
}
@@ -85,7 +83,6 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
name: name,
linkEP: ep,
loopback: loopback,
- demux: newTransportDemuxer(stack),
primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
mcastJoins: make(map[NetworkEndpointID]int32),
@@ -102,6 +99,35 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
}
}
+// enable enables the NIC. It attaches the NIC to its LinkEndpoint, adds the
+// IPv4 broadcast address as a never-primary endpoint if the stack is
+// configured with IPv4, and joins the IPv6 All-Nodes Multicast group (ff02::1)
+// if the stack is configured with IPv6.
+//
+// It returns the error from AddAddress or joinGroup, if either fails.
+func (n *NIC) enable() *tcpip.Error {
+ n.attachLinkEndpoint()
+
+ // Create an endpoint to receive broadcast packets on this interface.
+ if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
+ if err := n.AddAddress(tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ // Keyed fields: an unkeyed literal of an imported struct is flagged by
+ // go vet and silently breaks if the struct's fields are reordered.
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: header.IPv4Broadcast,
+ PrefixLen: 8 * header.IPv4AddressSize,
+ },
+ }, NeverPrimaryEndpoint); err != nil {
+ return err
+ }
+ }
+
+ // Join the IPv6 All-Nodes Multicast group if the stack is configured to
+ // use IPv6. This is required to ensure that this node properly receives
+ // and responds to the various NDP messages that are destined to the
+ // all-nodes multicast address. An example is the Neighbor Advertisement
+ // when we perform Duplicate Address Detection, or Router Advertisement
+ // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861
+ // section 4.2 for more information.
+ if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok {
+ return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress)
+ }
+
+ return nil
+}
+
// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
// to start delivering packets.
func (n *NIC) attachLinkEndpoint() {
@@ -129,37 +155,6 @@ func (n *NIC) setSpoofing(enable bool) {
n.mu.Unlock()
}
-func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
- n.mu.RLock()
- defer n.mu.RUnlock()
-
- var r *referencedNetworkEndpoint
-
- // Check for a primary endpoint.
- if list, ok := n.primary[protocol]; ok {
- for e := list.Front(); e != nil; e = e.Next() {
- ref := e.(*referencedNetworkEndpoint)
- if ref.holdsInsertRef && ref.tryIncRef() {
- r = ref
- break
- }
- }
-
- }
-
- if r == nil {
- return tcpip.AddressWithPrefix{}, tcpip.ErrNoLinkAddress
- }
-
- addressWithPrefix := tcpip.AddressWithPrefix{
- Address: r.ep.ID().LocalAddress,
- PrefixLen: r.ep.PrefixLen(),
- }
- r.decRef()
-
- return addressWithPrefix, nil
-}
-
// primaryEndpoint returns the primary endpoint of n for the given network
// protocol.
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint {
@@ -178,7 +173,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
case header.IPv4Broadcast, header.IPv4Any:
continue
}
- if r.tryIncRef() {
+ if r.isValidForOutgoing() && r.tryIncRef() {
return r
}
}
@@ -197,22 +192,44 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A
// getRefOrCreateTemp returns the referenced network endpoint for the given
// protocol and address. If none exists a temporary one may be created if
-// requested.
-func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, allowTemp bool) *referencedNetworkEndpoint {
+// we are in promiscuous mode or spoofing.
+func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint {
id := NetworkEndpointID{address}
n.mu.RLock()
- if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
- n.mu.RUnlock()
- return ref
+ if ref, ok := n.endpoints[id]; ok {
+ // An endpoint with this id exists, check if it can be used and return it.
+ switch ref.getKind() {
+ case permanentExpired:
+ if !spoofingOrPromiscuous {
+ n.mu.RUnlock()
+ return nil
+ }
+ fallthrough
+ case temporary, permanent:
+ if ref.tryIncRef() {
+ n.mu.RUnlock()
+ return ref
+ }
+ }
}
- // The address was not found, create a temporary one if requested by the
- // caller or if the address is found in the NIC's subnets.
- createTempEP := allowTemp
+ // A usable reference was not found, create a temporary one if requested by
+ // the caller or if the address is found in the NIC's subnets.
+ createTempEP := spoofingOrPromiscuous
if !createTempEP {
- for _, sn := range n.subnets {
+ for _, sn := range n.addressRanges {
+ // Skip the subnet address.
+ if address == sn.ID() {
+ continue
+ }
+ // For now just skip the broadcast address, until we support it.
+ // FIXME(b/137608825): Add support for sending/receiving directed
+ // (subnet) broadcast.
+ if address == sn.Broadcast() {
+ continue
+ }
if sn.Contains(address) {
createTempEP = true
break
@@ -230,34 +247,70 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
// endpoint, create a new "temporary" endpoint. It will only exist while
// there's a route through it.
n.mu.Lock()
- if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
- n.mu.Unlock()
- return ref
+ if ref, ok := n.endpoints[id]; ok {
+ // No need to check the type as we are ok with expired endpoints at this
+ // point.
+ if ref.tryIncRef() {
+ n.mu.Unlock()
+ return ref
+ }
+ // tryIncRef failing means the endpoint is scheduled to be removed once the
+ // lock is released. Remove it here so we can create a new (temporary) one.
+ // The removal logic waiting for the lock handles this case.
+ n.removeEndpointLocked(ref)
}
+ // Add a new temporary endpoint.
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
n.mu.Unlock()
return nil
}
-
ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: address,
PrefixLen: netProto.DefaultPrefixLen(),
},
- }, peb, true)
-
- if ref != nil {
- ref.holdsInsertRef = false
- }
+ }, peb, temporary)
n.mu.Unlock()
return ref
}
-func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) {
+func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) {
+ id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
+ if ref, ok := n.endpoints[id]; ok {
+ switch ref.getKind() {
+ case permanent:
+ // The NIC already has a permanent endpoint with that address.
+ return nil, tcpip.ErrDuplicateAddress
+ case permanentExpired, temporary:
+ // Promote the endpoint to become permanent.
+ if ref.tryIncRef() {
+ ref.setKind(permanent)
+ return ref, nil
+ }
+ // tryIncRef failing means the endpoint is scheduled to be removed once
+ // the lock is released. Remove it here so we can create a new
+ // (permanent) one. The removal logic waiting for the lock handles this
+ // case.
+ n.removeEndpointLocked(ref)
+ }
+ }
+ return n.addAddressLocked(protocolAddress, peb, permanent)
+}
+
+func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) {
+ // TODO(b/141022673): Validate IP address before adding them.
+
+ // Sanity check.
+ id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
+ if _, ok := n.endpoints[id]; ok {
+ // Endpoint already exists.
+ return nil, tcpip.ErrDuplicateAddress
+ }
+
netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol]
if !ok {
return nil, tcpip.ErrUnknownProtocol
@@ -268,22 +321,12 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
if err != nil {
return nil, err
}
-
- id := *ep.ID()
- if ref, ok := n.endpoints[id]; ok {
- if !replace {
- return nil, tcpip.ErrDuplicateAddress
- }
-
- n.removeEndpointLocked(ref)
- }
-
ref := &referencedNetworkEndpoint{
- refs: 1,
- ep: ep,
- nic: n,
- protocol: protocolAddress.Protocol,
- holdsInsertRef: true,
+ refs: 1,
+ ep: ep,
+ nic: n,
+ protocol: protocolAddress.Protocol,
+ kind: kind,
}
// Set up cache if link address resolution exists for this protocol.
@@ -293,6 +336,15 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
}
}
+ // If we are adding an IPv6 unicast address, join the solicited-node
+ // multicast address.
+ if protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) {
+ snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address)
+ if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil {
+ return nil, err
+ }
+ }
+
n.endpoints[id] = ref
l, ok := n.primary[protocolAddress.Protocol]
@@ -316,18 +368,26 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
// Add the endpoint.
n.mu.Lock()
- _, err := n.addAddressLocked(protocolAddress, peb, false)
+ _, err := n.addPermanentAddressLocked(protocolAddress, peb)
n.mu.Unlock()
return err
}
-// Addresses returns the addresses associated with this NIC.
-func (n *NIC) Addresses() []tcpip.ProtocolAddress {
+// AllAddresses returns all addresses (primary and non-primary) associated with
+// this NIC.
+func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
n.mu.RLock()
defer n.mu.RUnlock()
+
addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
for nid, ref := range n.endpoints {
+ // Don't include expired or temporary endpoints to avoid confusion and
+ // prevent the caller from using those.
+ switch ref.getKind() {
+ case permanentExpired, temporary:
+ continue
+ }
addrs = append(addrs, tcpip.ProtocolAddress{
Protocol: ref.protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
@@ -339,45 +399,66 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress {
return addrs
}
-// AddSubnet adds a new subnet to n, so that it starts accepting packets
-// targeted at the given address and network protocol.
-func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) {
+// PrimaryAddresses returns the primary addresses associated with this NIC.
+func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ var addrs []tcpip.ProtocolAddress
+ for proto, list := range n.primary {
+ for e := list.Front(); e != nil; e = e.Next() {
+ ref := e.(*referencedNetworkEndpoint)
+ // Don't include expired or temporary endpoints to avoid confusion and
+ // prevent the caller from using those.
+ switch ref.getKind() {
+ case permanentExpired, temporary:
+ continue
+ }
+
+ addrs = append(addrs, tcpip.ProtocolAddress{
+ Protocol: proto,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: ref.ep.ID().LocalAddress,
+ PrefixLen: ref.ep.PrefixLen(),
+ },
+ })
+ }
+ }
+ return addrs
+}
+
+// AddAddressRange adds a range of addresses to n, so that it starts accepting
+// packets targeted at the given addresses and network protocol. The range is
+// given by a subnet address, and all addresses contained in the subnet are
+// used except for the subnet address itself and the subnet's broadcast
+// address.
+func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) {
n.mu.Lock()
- n.subnets = append(n.subnets, subnet)
+ n.addressRanges = append(n.addressRanges, subnet)
n.mu.Unlock()
}
-// RemoveSubnet removes the given subnet from n.
-func (n *NIC) RemoveSubnet(subnet tcpip.Subnet) {
+// RemoveAddressRange removes the given address range from n.
+func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) {
n.mu.Lock()
// Use the same underlying array.
- tmp := n.subnets[:0]
- for _, sub := range n.subnets {
+ tmp := n.addressRanges[:0]
+ for _, sub := range n.addressRanges {
if sub != subnet {
tmp = append(tmp, sub)
}
}
- n.subnets = tmp
+ n.addressRanges = tmp
n.mu.Unlock()
}
-// ContainsSubnet reports whether this NIC contains the given subnet.
-func (n *NIC) ContainsSubnet(subnet tcpip.Subnet) bool {
- for _, s := range n.Subnets() {
- if s == subnet {
- return true
- }
- }
- return false
-}
-
// Subnets returns the Subnets associated with this NIC.
-func (n *NIC) Subnets() []tcpip.Subnet {
+func (n *NIC) AddressRanges() []tcpip.Subnet {
n.mu.RLock()
defer n.mu.RUnlock()
- sns := make([]tcpip.Subnet, 0, len(n.subnets)+len(n.endpoints))
+ sns := make([]tcpip.Subnet, 0, len(n.addressRanges)+len(n.endpoints))
for nid := range n.endpoints {
sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress))))
if err != nil {
@@ -387,19 +468,22 @@ func (n *NIC) Subnets() []tcpip.Subnet {
}
sns = append(sns, sn)
}
- return append(sns, n.subnets...)
+ return append(sns, n.addressRanges...)
}
func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
id := *r.ep.ID()
- // Nothing to do if the reference has already been replaced with a
- // different one.
+ // Nothing to do if the reference has already been replaced with a different
+ // one. This happens in the case where 1) this endpoint's ref count hit zero
+ // and was waiting (on the lock) to be removed and 2) the same address was
+ // re-added in the meantime by removing this endpoint from the list and
+ // adding a new one.
if n.endpoints[id] != r {
return
}
- if r.holdsInsertRef {
+ if r.getKind() == permanent {
panic("Reference count dropped to zero before being removed")
}
@@ -418,15 +502,28 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
n.mu.Unlock()
}
-func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error {
- r := n.endpoints[NetworkEndpointID{addr}]
- if r == nil || !r.holdsInsertRef {
+func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
+ r, ok := n.endpoints[NetworkEndpointID{addr}]
+ if !ok || r.getKind() != permanent {
return tcpip.ErrBadLocalAddress
}
- r.holdsInsertRef = false
+ r.setKind(permanentExpired)
+ if !r.decRefLocked() {
+ // The endpoint still has references to it.
+ return nil
+ }
- r.decRefLocked()
+ // At this point the endpoint is deleted.
+
+ // If we are removing an IPv6 unicast address, leave the solicited-node
+ // multicast address.
+ if r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) {
+ snmc := header.SolicitedNodeAddr(addr)
+ if err := n.leaveGroupLocked(snmc); err != nil {
+ return err
+ }
+ }
return nil
}
@@ -435,7 +532,7 @@ func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error {
func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
- return n.removeAddressLocked(addr)
+ return n.removePermanentAddressLocked(addr)
}
// joinGroup adds a new endpoint for the given multicast address, if none
@@ -444,6 +541,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
n.mu.Lock()
defer n.mu.Unlock()
+ return n.joinGroupLocked(protocol, addr)
+}
+
+// joinGroupLocked adds a new endpoint for the given multicast address, if none
+// exists yet. Otherwise it just increments its count. n MUST be locked before
+// joinGroupLocked is called.
+func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
id := NetworkEndpointID{addr}
joins := n.mcastJoins[id]
if joins == 0 {
@@ -451,13 +555,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
if !ok {
return tcpip.ErrUnknownProtocol
}
- if _, err := n.addAddressLocked(tcpip.ProtocolAddress{
+ if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: addr,
PrefixLen: netProto.DefaultPrefixLen(),
},
- }, NeverPrimaryEndpoint, false); err != nil {
+ }, NeverPrimaryEndpoint); err != nil {
return err
}
}
@@ -471,6 +575,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
+ return n.leaveGroupLocked(addr)
+}
+
+// leaveGroupLocked decrements the count for the given multicast address, and
+// when it reaches zero removes the endpoint for this address. n MUST be locked
+// before leaveGroupLocked is called.
+func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
id := NetworkEndpointID{addr}
joins := n.mcastJoins[id]
switch joins {
@@ -479,7 +590,7 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
return tcpip.ErrBadLocalAddress
case 1:
// This is the last one, clean up.
- if err := n.removeAddressLocked(addr); err != nil {
+ if err := n.removePermanentAddressLocked(addr); err != nil {
return err
}
}
@@ -487,6 +598,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
return nil
}
+func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) {
+ r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
+ r.RemoteLinkAddress = remotelinkAddr
+ ref.ep.HandlePacket(&r, vv)
+ ref.decRef()
+}
+
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
// hands the packet over for further processing. This function is called when
// the NIC receives a packet from the physical interface.
@@ -514,29 +632,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
src, dst := netProto.ParseAddresses(vv.First())
- // If the packet is destined to the IPv4 Broadcast address, then make a
- // route to each IPv4 network endpoint and let each endpoint handle the
- // packet.
- if dst == header.IPv4Broadcast {
- // n.endpoints is mutex protected so acquire lock.
- n.mu.RLock()
- for _, ref := range n.endpoints {
- if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() {
- r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */)
- r.RemoteLinkAddress = remote
- ref.ep.HandlePacket(&r, vv)
- ref.decRef()
- }
- }
- n.mu.RUnlock()
- return
- }
-
if ref := n.getRef(protocol, dst); ref != nil {
- r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */)
- r.RemoteLinkAddress = remote
- ref.ep.HandlePacket(&r, vv)
- ref.decRef()
+ handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
return
}
@@ -559,8 +656,9 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
n := r.ref.nic
n.mu.RLock()
ref, ok := n.endpoints[NetworkEndpointID{dst}]
+ ok = ok && ref.isValidForOutgoing() && ref.tryIncRef()
n.mu.RUnlock()
- if ok && ref.tryIncRef() {
+ if ok {
r.RemoteAddress = src
// TODO(b/123449044): Update the source NIC as well.
ref.ep.HandlePacket(&r, vv)
@@ -599,9 +697,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// Raw socket packets are delivered based solely on the transport
// protocol number. We do not inspect the payload to ensure it's
// validly formed.
- if !n.demux.deliverRawPacket(r, protocol, netHeader, vv) {
- n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv)
- }
+ n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv)
if len(vv.First()) < transProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
@@ -615,9 +711,6 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
}
id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
- if n.demux.deliverPacket(r, protocol, netHeader, vv, id) {
- return
- }
if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) {
return
}
@@ -631,7 +724,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// We could not find an appropriate destination for this packet, so
// deliver it to the global handler.
- if !transProto.HandleUnknownDestinationPacket(r, id, vv) {
+ if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) {
n.stack.stats.MalformedRcvdPackets.Increment()
}
}
@@ -659,10 +752,7 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp
}
id := TransportEndpointID{srcPort, local, dstPort, remote}
- if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
- return
- }
- if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
+ if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, vv, id) {
return
}
}
@@ -672,9 +762,38 @@ func (n *NIC) ID() tcpip.NICID {
return n.id
}
+// Stack returns the instance of the Stack that owns this NIC.
+func (n *NIC) Stack() *Stack {
+ return n.stack
+}
+
+type networkEndpointKind int32
+
+const (
+ // A permanent endpoint is created by adding a permanent address (vs. a
+ // temporary one) to the NIC. Its reference count is biased by 1 to avoid
+ // removal when no route holds a reference to it. It is removed by explicitly
+ // removing the permanent address from the NIC.
+ permanent networkEndpointKind = iota
+
+ // An expired permanent endpoint is a permanent endpoint that had its address
+ // removed from the NIC, and it is waiting to be removed once no more routes
+ // hold a reference to it. This is achieved by decreasing its reference count
+ // by 1. If its address is re-added before the endpoint is removed, its type
+ // changes back to permanent and its reference count increases by 1 again.
+ permanentExpired
+
+ // A temporary endpoint is created for spoofing outgoing packets, or when in
+ // promiscuous mode and accepting incoming packets that don't match any
+ // permanent endpoint. Its reference count is not biased by 1 and the
+ // endpoint is removed immediately when no more route holds a reference to
+ // it. A temporary endpoint can be promoted to permanent if its address
+ // is added permanently.
+ temporary
+)
+
type referencedNetworkEndpoint struct {
ilist.Entry
- refs int32
ep NetworkEndpoint
nic *NIC
protocol tcpip.NetworkProtocolNumber
@@ -683,11 +802,34 @@ type referencedNetworkEndpoint struct {
// protocol. Set to nil otherwise.
linkCache LinkAddressCache
- // holdsInsertRef is protected by the NIC's mutex. It indicates whether
- // the reference count is biased by 1 due to the insertion of the
- // endpoint. It is reset to false when RemoveAddress is called on the
- // NIC.
- holdsInsertRef bool
+ // refs is counting references held for this endpoint. When refs hits zero it
+ // triggers the automatic removal of the endpoint from the NIC.
+ refs int32
+
+ // kind must only be accessed using {get,set}Kind().
+ kind networkEndpointKind
+}
+
+func (r *referencedNetworkEndpoint) getKind() networkEndpointKind {
+ return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind)))
+}
+
+func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) {
+ atomic.StoreInt32((*int32)(&r.kind), int32(kind))
+}
+
+// isValidForOutgoing returns true if the endpoint can be used to send out a
+// packet. It requires the endpoint to not be marked expired (i.e., its address
+// has been removed), or the NIC to be in spoofing mode.
+func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
+ return r.getKind() != permanentExpired || r.nic.spoofing
+}
+
+// isValidForIncoming returns true if the endpoint can accept an incoming
+// packet. It requires the endpoint to not be marked expired (i.e., its address
+// has been removed), or the NIC to be in promiscuous mode.
+func (r *referencedNetworkEndpoint) isValidForIncoming() bool {
+ return r.getKind() != permanentExpired || r.nic.promiscuous
}
// decRef decrements the ref count and cleans up the endpoint once it reaches
@@ -699,11 +841,14 @@ func (r *referencedNetworkEndpoint) decRef() {
}
// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is
-// locked.
-func (r *referencedNetworkEndpoint) decRefLocked() {
+// locked. Returns true if the endpoint was removed.
+func (r *referencedNetworkEndpoint) decRefLocked() bool {
if atomic.AddInt32(&r.refs, -1) == 0 {
r.nic.removeEndpointLocked(r)
+ return true
}
+
+ return false
}
// incRef increments the ref count. It must only be called when the caller is
@@ -728,3 +873,8 @@ func (r *referencedNetworkEndpoint) tryIncRef() bool {
}
}
}
+
+// stack returns the Stack instance that owns the underlying endpoint.
+func (r *referencedNetworkEndpoint) stack() *Stack {
+ return r.nic.stack
+}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 2037eef9f..9d6157f22 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -15,8 +15,6 @@
package stack
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -109,7 +107,7 @@ type TransportProtocol interface {
//
// The return value indicates whether the packet was well-formed (for
// stats purposes only).
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) bool
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -148,6 +146,19 @@ const (
PacketLoop
)
+// NetworkHeaderParams are the header parameters given as input by the
+// transport endpoint to the network.
+type NetworkHeaderParams struct {
+ // Protocol refers to the transport protocol number.
+ Protocol tcpip.TransportProtocolNumber
+
+ // TTL refers to Time To Live field of the IP-header.
+ TTL uint8
+
+ // TOS refers to TypeOfService or TrafficClass field of the IP-header.
+ TOS uint8
+}
+
// NetworkEndpoint is the interface that needs to be implemented by endpoints
// of network layer protocols (e.g., ipv4, ipv6).
type NetworkEndpoint interface {
@@ -172,7 +183,7 @@ type NetworkEndpoint interface {
// WritePacket writes a packet to the given destination address and
// protocol.
- WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) *tcpip.Error
// WriteHeaderIncludedPacket writes a packet that includes a network
// header to the given destination address.
@@ -297,6 +308,15 @@ type LinkEndpoint interface {
// IsAttached returns whether a NetworkDispatcher is attached to the
// endpoint.
IsAttached() bool
+
+ // Wait waits for any worker goroutines owned by the endpoint to stop.
+ //
+ // For now, requesting that an endpoint's worker goroutine(s) stop is
+ // implementation specific.
+ //
+ // Wait will not block if the endpoint hasn't started any goroutines
+ // yet, even if it might later.
+ Wait()
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
@@ -359,14 +379,6 @@ type LinkAddressCache interface {
RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
}
-// TransportProtocolFactory functions are used by the stack to instantiate
-// transport protocols.
-type TransportProtocolFactory func() TransportProtocol
-
-// NetworkProtocolFactory provides methods to be used by the stack to
-// instantiate network protocols.
-type NetworkProtocolFactory func() NetworkProtocol
-
// UnassociatedEndpointFactory produces endpoints for writing packets not
// associated with a particular transport protocol. Such endpoints can be used
// to write arbitrary packets that include the IP header.
@@ -374,60 +386,6 @@ type UnassociatedEndpointFactory interface {
NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
}
-var (
- transportProtocols = make(map[string]TransportProtocolFactory)
- networkProtocols = make(map[string]NetworkProtocolFactory)
-
- unassociatedFactory UnassociatedEndpointFactory
-
- linkEPMu sync.RWMutex
- nextLinkEndpointID tcpip.LinkEndpointID = 1
- linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint)
-)
-
-// RegisterTransportProtocolFactory registers a new transport protocol factory
-// with the stack so that it becomes available to users of the stack. This
-// function is intended to be called by init() functions of the protocols.
-func RegisterTransportProtocolFactory(name string, p TransportProtocolFactory) {
- transportProtocols[name] = p
-}
-
-// RegisterNetworkProtocolFactory registers a new network protocol factory with
-// the stack so that it becomes available to users of the stack. This function
-// is intended to be called by init() functions of the protocols.
-func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) {
- networkProtocols[name] = p
-}
-
-// RegisterUnassociatedFactory registers a factory to produce endpoints not
-// associated with any particular transport protocol. This function is intended
-// to be called by init() functions of the protocols.
-func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) {
- unassociatedFactory = f
-}
-
-// RegisterLinkEndpoint register a link-layer protocol endpoint and returns an
-// ID that can be used to refer to it.
-func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID {
- linkEPMu.Lock()
- defer linkEPMu.Unlock()
-
- v := nextLinkEndpointID
- nextLinkEndpointID++
-
- linkEndpoints[v] = linkEP
-
- return v
-}
-
-// FindLinkEndpoint finds the link endpoint associated with the given ID.
-func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint {
- linkEPMu.RLock()
- defer linkEPMu.RUnlock()
-
- return linkEndpoints[id]
-}
-
// GSOType is the type of GSO segments.
//
// +stateify savable
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 391ab4344..e72373964 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -59,6 +59,8 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
loop = PacketLoop
} else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) {
loop |= PacketLoop
+ } else if remoteAddr == header.IPv4Broadcast {
+ loop |= PacketLoop
}
return Route{
@@ -148,12 +150,16 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
// IsResolutionRequired returns true if Resolve() must be called to resolve
// the link address before the this route can be written to.
func (r *Route) IsResolutionRequired() bool {
- return r.ref.linkCache != nil && r.RemoteLinkAddress == ""
+ return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == ""
}
// WritePacket writes the packet through the given route.
-func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
- err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop)
+func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams) *tcpip.Error {
+ if !r.ref.isValidForOutgoing() {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ err := r.ref.ep.WritePacket(r, gso, hdr, payload, params, r.loop)
if err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
} else {
@@ -166,6 +172,10 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error {
+ if !r.ref.isValidForOutgoing() {
+ return tcpip.ErrInvalidEndpointState
+ }
+
if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
@@ -200,12 +210,24 @@ func (r *Route) Clone() Route {
return *r
}
-// MakeLoopedRoute duplicates the given route and tweaks it in case of multicast.
+// MakeLoopedRoute duplicates the given route with special handling for routes
+// used for sending multicast or broadcast packets. In those cases the
+// multicast/broadcast address is the remote address when sending out, but for
+// incoming (looped) packets it becomes the local address. Similarly, the local
+// interface address that was the local address going out becomes the remote
+// address coming in. This is different to unicast routes where local and
+// remote addresses remain the same as they identify location (local vs remote)
+// not direction (source vs destination).
func (r *Route) MakeLoopedRoute() Route {
l := r.Clone()
- if header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) {
+ if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) {
l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress
l.RemoteLinkAddress = l.LocalLinkAddress
}
return l
}
+
+// Stack returns the instance of the Stack that owns this route.
+func (r *Route) Stack() *Stack {
+ return r.ref.stack()
+}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index d69162ba1..f67975525 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -17,17 +17,15 @@
//
// For consumers, the only function of interest is New(), everything else is
// provided by the tcpip/public package.
-//
-// For protocol implementers, RegisterTransportProtocolFactory() and
-// RegisterNetworkProtocolFactory() are used to register protocol factories with
-// the stack, which will then be used to instantiate protocol objects when
-// consumers interact with the stack.
package stack
import (
+ "encoding/binary"
"sync"
"time"
+ "golang.org/x/time/rate"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -45,6 +43,9 @@ const (
resolutionTimeout = 1 * time.Second
// resolutionAttempts is set to the same ARP retries used in Linux.
resolutionAttempts = 3
+
+ // DefaultTOS is the default type of service value for network endpoints.
+ DefaultTOS = 0
)
type transportProtocolState struct {
@@ -350,6 +351,9 @@ type Stack struct {
networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
+ // unassociatedFactory creates unassociated endpoints. If nil, raw
+ // endpoints are disabled. It is set during Stack creation and is
+ // immutable.
unassociatedFactory UnassociatedEndpointFactory
demux *transportDemuxer
@@ -358,10 +362,6 @@ type Stack struct {
linkAddrCache *linkAddrCache
- // raw indicates whether raw sockets may be created. It is set during
- // Stack creation and is immutable.
- raw bool
-
mu sync.RWMutex
nics map[tcpip.NICID]*NIC
forwarding bool
@@ -389,10 +389,26 @@ type Stack struct {
// resumableEndpoints is a list of endpoints that need to be resumed if the
// stack is being restored.
resumableEndpoints []ResumableEndpoint
+
+ // icmpRateLimiter is a global rate limiter for all ICMP messages generated
+ // by the stack.
+ icmpRateLimiter *ICMPRateLimiter
+
+ // portSeed is a one-time random value initialized at stack startup
+ // and is used to seed the TCP port picking on active connections.
+ //
+ // TODO(gvisor.dev/issue/940): S/R this field.
+ portSeed uint32
}
// Options contains optional Stack configuration.
type Options struct {
+ // NetworkProtocols lists the network protocols to enable.
+ NetworkProtocols []NetworkProtocol
+
+ // TransportProtocols lists the transport protocols to enable.
+ TransportProtocols []TransportProtocol
+
// Clock is an optional clock source used for timestampping packets.
//
// If no Clock is specified, the clock source will be time.Now.
@@ -406,10 +422,39 @@ type Options struct {
// stack (false).
HandleLocal bool
- // Raw indicates whether raw sockets may be created.
- Raw bool
+ // UnassociatedFactory produces unassociated raw endpoints.
+ // Raw endpoints are enabled only if this is non-nil.
+ UnassociatedFactory UnassociatedEndpointFactory
}
+// TransportEndpointInfo holds useful information about a transport endpoint
+// which can be queried by monitoring tools.
+//
+// +stateify savable
+type TransportEndpointInfo struct {
+ // The following fields are initialized at creation time and are
+ // immutable.
+
+ NetProto tcpip.NetworkProtocolNumber
+ TransProto tcpip.TransportProtocolNumber
+
+ // The following fields are protected by endpoint mu.
+
+ ID TransportEndpointID
+ // BindNICID and BindAddr are set via calls to Bind(). They are used to
+ // reject attempts to send data or connect via a different NIC or
+ // address.
+ BindNICID tcpip.NICID
+ BindAddr tcpip.Address
+ // RegisterNICID is the default NICID registered as a side-effect of
+ // connect or datagram write.
+ RegisterNICID tcpip.NICID
+}
+
+// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
+// marker interface.
+func (*TransportEndpointInfo) IsEndpointInfo() {}
+
// New allocates a new networking stack with only the requested networking and
// transport protocols configured with default options.
//
@@ -417,7 +462,7 @@ type Options struct {
// SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the
// stack. Please refer to individual protocol implementations as to what options
// are supported.
-func New(network []string, transport []string, opts Options) *Stack {
+func New(opts Options) *Stack {
clock := opts.Clock
if clock == nil {
clock = &tcpip.StdClock{}
@@ -433,16 +478,12 @@ func New(network []string, transport []string, opts Options) *Stack {
clock: clock,
stats: opts.Stats.FillIn(),
handleLocal: opts.HandleLocal,
- raw: opts.Raw,
+ icmpRateLimiter: NewICMPRateLimiter(),
+ portSeed: generateRandUint32(),
}
// Add specified network protocols.
- for _, name := range network {
- netProtoFactory, ok := networkProtocols[name]
- if !ok {
- continue
- }
- netProto := netProtoFactory()
+ for _, netProto := range opts.NetworkProtocols {
s.networkProtocols[netProto.Number()] = netProto
if r, ok := netProto.(LinkAddressResolver); ok {
s.linkAddrResolvers[r.LinkAddressProtocol()] = r
@@ -450,18 +491,14 @@ func New(network []string, transport []string, opts Options) *Stack {
}
// Add specified transport protocols.
- for _, name := range transport {
- transProtoFactory, ok := transportProtocols[name]
- if !ok {
- continue
- }
- transProto := transProtoFactory()
+ for _, transProto := range opts.TransportProtocols {
s.transportProtocols[transProto.Number()] = &transportProtocolState{
proto: transProto,
}
}
- s.unassociatedFactory = unassociatedFactory
+ // Add the factory for unassociated endpoints, if present.
+ s.unassociatedFactory = opts.UnassociatedFactory
// Create the global transport demuxer.
s.demux = newTransportDemuxer(s)
@@ -596,7 +633,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
// protocol. Raw endpoints receive all traffic for a given protocol regardless
// of address.
func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
- if !s.raw {
+ if s.unassociatedFactory == nil {
return nil, tcpip.ErrNotPermitted
}
@@ -614,12 +651,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network
// createNIC creates a NIC with the provided id and link-layer endpoint, and
// optionally enable it.
-func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled, loopback bool) *tcpip.Error {
- ep := FindLinkEndpoint(linkEP)
- if ep == nil {
- return tcpip.ErrBadLinkEndpoint
- }
-
+func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -632,40 +664,40 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpoint
s.nics[id] = n
if enabled {
- n.attachLinkEndpoint()
+ return n.enable()
}
return nil
}
// CreateNIC creates a NIC with the provided id and link-layer endpoint.
-func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, "", linkEP, true, false)
+func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, "", ep, true, false)
}
// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint,
// and a human-readable name.
-func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, name, linkEP, true, false)
+func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, name, ep, true, false)
}
// CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer
// endpoint, and a human-readable name.
-func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, name, linkEP, true, true)
+func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, name, ep, true, true)
}
// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint,
// but leave it disable. Stack.EnableNIC must be called before the link-layer
// endpoint starts delivering packets to it.
-func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, "", linkEP, false, false)
+func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, "", ep, false, false)
}
// CreateDisabledNamedNIC is a combination of CreateNamedNIC and
// CreateDisabledNIC.
-func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, name, linkEP, false, false)
+func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, name, ep, false, false)
}
// EnableNIC enables the given NIC so that the link-layer endpoint can start
@@ -679,9 +711,7 @@ func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
return tcpip.ErrUnknownNICID
}
- nic.attachLinkEndpoint()
-
- return nil
+ return nic.enable()
}
// CheckNIC checks if a NIC is usable.
@@ -696,14 +726,14 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool {
}
// NICSubnets returns a map of NICIDs to their associated subnets.
-func (s *Stack) NICSubnets() map[tcpip.NICID][]tcpip.Subnet {
+func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet {
s.mu.RLock()
defer s.mu.RUnlock()
nics := map[tcpip.NICID][]tcpip.Subnet{}
for id, nic := range s.nics {
- nics[id] = append(nics[id], nic.Subnets()...)
+ nics[id] = append(nics[id], nic.AddressRanges()...)
}
return nics
}
@@ -739,7 +769,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
nics[id] = NICInfo{
Name: nic.name,
LinkAddress: nic.linkEP.LinkAddress(),
- ProtocolAddresses: nic.Addresses(),
+ ProtocolAddresses: nic.PrimaryAddresses(),
Flags: flags,
MTU: nic.linkEP.MTU(),
Stats: nic.stats,
@@ -804,71 +834,79 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc
return nic.AddAddress(protocolAddress, peb)
}
-// AddSubnet adds a subnet range to the specified NIC.
-func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error {
+// AddAddressRange adds a range of addresses to the specified NIC. The range is
+// given by a subnet address, and all addresses contained in the subnet are
+// used except for the subnet address itself and the subnet's broadcast
+// address.
+func (s *Stack) AddAddressRange(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
- nic.AddSubnet(protocol, subnet)
+ nic.AddAddressRange(protocol, subnet)
return nil
}
return tcpip.ErrUnknownNICID
}
-// RemoveSubnet removes the subnet range from the specified NIC.
-func (s *Stack) RemoveSubnet(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error {
+// RemoveAddressRange removes the range of addresses from the specified NIC.
+func (s *Stack) RemoveAddressRange(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
- nic.RemoveSubnet(subnet)
+ nic.RemoveAddressRange(subnet)
return nil
}
return tcpip.ErrUnknownNICID
}
-// ContainsSubnet reports whether the specified NIC contains the specified
-// subnet.
-func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpip.Error) {
+// RemoveAddress removes an existing network-layer address from the specified
+// NIC.
+func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
- return nic.ContainsSubnet(subnet), nil
+ return nic.RemoveAddress(addr)
}
- return false, tcpip.ErrUnknownNICID
+ return tcpip.ErrUnknownNICID
}
-// RemoveAddress removes an existing network-layer address from the specified
-// NIC.
-func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+// AllAddresses returns a map of NICIDs to their protocol addresses (primary
+// and non-primary).
+func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress {
s.mu.RLock()
defer s.mu.RUnlock()
- if nic, ok := s.nics[id]; ok {
- return nic.RemoveAddress(addr)
+ nics := make(map[tcpip.NICID][]tcpip.ProtocolAddress)
+ for id, nic := range s.nics {
+ nics[id] = nic.AllAddresses()
}
-
- return tcpip.ErrUnknownNICID
+ return nics
}
-// GetMainNICAddress returns the first primary address (and the subnet that
-// contains it) for the given NIC and protocol. Returns an arbitrary endpoint's
-// address if no primary addresses exist. Returns an error if the NIC doesn't
-// exist or has no endpoints.
+// GetMainNICAddress returns the first primary address and prefix for the given
+// NIC and protocol. Returns an error if the NIC doesn't exist and an empty
+// value if the NIC doesn't have a primary address for the given protocol.
func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
- if nic, ok := s.nics[id]; ok {
- return nic.getMainNICAddress(protocol)
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID
}
- return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID
+ for _, a := range nic.PrimaryAddresses() {
+ if a.Protocol == protocol {
+ return a.AddressWithPrefix, nil
+ }
+ }
+ return tcpip.AddressWithPrefix{}, nil
}
func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
@@ -895,7 +933,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
} else {
for _, route := range s.routeTable {
- if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !isBroadcast && !route.Destination.Contains(remoteAddr)) {
+ if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) {
continue
}
if nic, ok := s.nics[route.NIC]; ok {
@@ -1035,73 +1073,27 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
// nic-specific IDs have precedence over global ones.
-func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
- if nicID == 0 {
- return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic == nil {
- return tcpip.ErrUnknownNICID
- }
-
- return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+ return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice)
}
// UnregisterTransportEndpoint removes the endpoint with the given id from the
// stack transport dispatcher.
-func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
- if nicID == 0 {
- s.demux.unregisterEndpoint(netProtos, protocol, id, ep)
- return
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic != nil {
- nic.demux.unregisterEndpoint(netProtos, protocol, id, ep)
- }
+func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
}
// RegisterRawTransportEndpoint registers the given endpoint with the stack
// transport dispatcher. Received packets that match the provided transport
// protocol will be delivered to the given endpoint.
func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
- if nicID == 0 {
- return s.demux.registerRawEndpoint(netProto, transProto, ep)
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic == nil {
- return tcpip.ErrUnknownNICID
- }
-
- return nic.demux.registerRawEndpoint(netProto, transProto, ep)
+ return s.demux.registerRawEndpoint(netProto, transProto, ep)
}
// UnregisterRawTransportEndpoint removes the endpoint for the transport
// protocol from the stack transport dispatcher.
func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
- if nicID == 0 {
- s.demux.unregisterRawEndpoint(netProto, transProto, ep)
- return
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic != nil {
- nic.demux.unregisterRawEndpoint(netProto, transProto, ep)
- }
+ s.demux.unregisterRawEndpoint(netProto, transProto, ep)
}
// RegisterRestoredEndpoint records e as an endpoint that has been restored on
@@ -1215,3 +1207,49 @@ func (s *Stack) IPTables() iptables.IPTables {
func (s *Stack) SetIPTables(ipt iptables.IPTables) {
s.tables = ipt
}
+
+// ICMPLimit returns the maximum number of ICMP messages that can be sent
+// in one second.
+func (s *Stack) ICMPLimit() rate.Limit {
+ return s.icmpRateLimiter.Limit()
+}
+
+// SetICMPLimit sets the maximum number of ICMP messages that can be sent
+// in one second.
+func (s *Stack) SetICMPLimit(newLimit rate.Limit) {
+ s.icmpRateLimiter.SetLimit(newLimit)
+}
+
+// ICMPBurst returns the maximum number of ICMP messages that can be sent
+// in a single burst.
+func (s *Stack) ICMPBurst() int {
+ return s.icmpRateLimiter.Burst()
+}
+
+// SetICMPBurst sets the maximum number of ICMP messages that can be sent
+// in a single burst.
+func (s *Stack) SetICMPBurst(burst int) {
+ s.icmpRateLimiter.SetBurst(burst)
+}
+
+// AllowICMPMessage returns true if the rate limiter allows at least one
+// ICMP message to be sent at this instant.
+func (s *Stack) AllowICMPMessage() bool {
+ return s.icmpRateLimiter.Allow()
+}
+
+// PortSeed returns a 32 bit value that can be used as a seed value for port
+// picking.
+//
+// NOTE: The seed is generated once during stack initialization only.
+func (s *Stack) PortSeed() uint32 {
+ return s.portSeed
+}
+
+func generateRandUint32() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ return binary.LittleEndian.Uint32(b)
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 137c6183e..10fd1065f 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -60,11 +60,11 @@ type fakeNetworkEndpoint struct {
prefixLen int
proto *fakeNetworkProtocol
dispatcher stack.TransportDispatcher
- linkEP stack.LinkEndpoint
+ ep stack.LinkEndpoint
}
func (f *fakeNetworkEndpoint) MTU() uint32 {
- return f.linkEP.MTU() - uint32(f.MaxHeaderLength())
+ return f.ep.MTU() - uint32(f.MaxHeaderLength())
}
func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
@@ -108,7 +108,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedV
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen
+ return f.ep.MaxHeaderLength() + fakeNetHeaderLen
}
func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
@@ -116,10 +116,10 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto
}
func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- return f.linkEP.Capabilities()
+ return f.ep.Capabilities()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
@@ -128,7 +128,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu
b := hdr.Prepend(fakeNetHeaderLen)
b[0] = r.RemoteAddress[0]
b[1] = f.id.LocalAddress[0]
- b[2] = byte(protocol)
+ b[2] = byte(params.Protocol)
if loop&stack.PacketLoop != 0 {
views := make([]buffer.View, 1, 1+len(payload.Views()))
@@ -141,7 +141,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu
return nil
}
- return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber)
+ return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber)
}
func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
@@ -181,18 +181,22 @@ func (f *fakeNetworkProtocol) DefaultPrefixLen() int {
return fakeDefaultPrefixLen
}
+func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
+ return f.packetCount[int(intfAddr)%len(f.packetCount)]
+}
+
func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
}
-func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
return &fakeNetworkEndpoint{
nicid: nicid,
id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
prefixLen: addrWithPrefix.PrefixLen,
proto: f,
dispatcher: dispatcher,
- linkEP: linkEP,
+ ep: ep,
}, nil
}
@@ -218,12 +222,18 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
}
}
+func fakeNetFactory() stack.NetworkProtocol {
+ return &fakeNetworkProtocol{}
+}
+
func TestNetworkReceive(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and two
// addresses attached to it: 1 & 2.
- id, linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -241,7 +251,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet with wrong address is not delivered.
buf[0] = 3
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 0 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
}
@@ -251,7 +261,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to first endpoint.
buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -261,7 +271,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to second endpoint.
buf[0] = 2
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -270,7 +280,7 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is not delivered if protocol number is wrong.
- linkEP.Inject(fakeNetNumber-1, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber-1, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -280,7 +290,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet that is too small is dropped.
buf.CapLength(2)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -289,16 +299,75 @@ func TestNetworkReceive(t *testing.T) {
}
}
-func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View) {
+func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error {
r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- t.Fatal("FindRoute failed:", err)
+ return err
}
defer r.Release()
+ return send(r, payload)
+}
+func send(r stack.Route, payload buffer.View) *tcpip.Error {
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- if err := r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123); err != nil {
- t.Error("WritePacket failed:", err)
+ return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS})
+}
+
+func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) {
+ t.Helper()
+ ep.Drain()
+ if err := sendTo(s, addr, payload); err != nil {
+ t.Error("sendTo failed:", err)
+ }
+ if got, want := ep.Drain(), 1; got != want {
+ t.Errorf("sendTo packet count: got = %d, want %d", got, want)
+ }
+}
+
+func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) {
+ t.Helper()
+ ep.Drain()
+ if err := send(r, payload); err != nil {
+ t.Error("send failed:", err)
+ }
+ if got, want := ep.Drain(), 1; got != want {
+ t.Errorf("send packet count: got = %d, want %d", got, want)
+ }
+}
+
+func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+ t.Helper()
+ if gotErr := send(r, payload); gotErr != wantErr {
+ t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
+ }
+}
+
+func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+ t.Helper()
+ if gotErr := sendTo(s, addr, payload); gotErr != wantErr {
+ t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr)
+ }
+}
+
+func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) {
+ t.Helper()
+ // testRecvInternal injects one packet, and we expect to receive it.
+ want := fakeNet.PacketCount(localAddrByte) + 1
+ testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want)
+}
+
+func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) {
+ t.Helper()
+ // testRecvInternal injects one packet, and we do NOT expect to receive it.
+ want := fakeNet.PacketCount(localAddrByte)
+ testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want)
+}
+
+func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) {
+ t.Helper()
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
+ if got := fakeNet.PacketCount(localAddrByte); got != want {
+ t.Errorf("receive packet count: got = %d, want %d", got, want)
}
}
@@ -306,9 +375,11 @@ func TestNetworkSend(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and one
// address: 1. The route table sends all packets through the only
// existing nic.
- id, linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("NewNIC failed:", err)
}
@@ -325,20 +396,19 @@ func TestNetworkSend(t *testing.T) {
}
// Make sure that the link-layer endpoint received the outbound packet.
- sendTo(t, s, "\x03", nil)
- if c := linkEP.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
+ testSendTo(t, s, "\x03", ep, nil)
}
func TestNetworkSendMultiRoute(t *testing.T) {
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
// even addresses.
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id1, linkEP1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -350,8 +420,8 @@ func TestNetworkSendMultiRoute(t *testing.T) {
t.Fatal("AddAddress failed:", err)
}
- id2, linkEP2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -382,18 +452,10 @@ func TestNetworkSendMultiRoute(t *testing.T) {
}
// Send a packet to an odd destination.
- sendTo(t, s, "\x05", nil)
-
- if c := linkEP1.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
+ testSendTo(t, s, "\x05", ep1, nil)
// Send a packet to an even destination.
- sendTo(t, s, "\x06", nil)
-
- if c := linkEP2.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
+ testSendTo(t, s, "\x06", ep2, nil)
}
func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) {
@@ -424,10 +486,12 @@ func TestRoutes(t *testing.T) {
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
// even addresses.
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id1, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -439,8 +503,8 @@ func TestRoutes(t *testing.T) {
t.Fatal("AddAddress failed:", err)
}
- id2, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -498,58 +562,71 @@ func TestRoutes(t *testing.T) {
}
func TestAddressRemoval(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ const localAddrByte byte = 0x01
+ localAddr := tcpip.Address([]byte{localAddrByte})
+ remoteAddr := tcpip.Address("\x02")
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- // Write a packet, and check that it gets delivered.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ // Send and receive packets, and verify they are received.
+ buf[0] = localAddrByte
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
- // Remove the address, then check that packet doesn't get delivered
- // anymore.
- if err := s.RemoveAddress(1, "\x01"); err != nil {
+ // Remove the address, then check that send/receive doesn't work anymore.
+ if err := s.RemoveAddress(1, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
// Check that removing the same address fails.
- if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
+ if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
}
}
-func TestDelayedRemovalDueToRoute(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+func TestAddressRemovalWithRouteHeld(t *testing.T) {
+ const localAddrByte byte = 0x01
+ localAddr := tcpip.Address([]byte{localAddrByte})
+ remoteAddr := tcpip.Address("\x02")
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatal("CreateNIC failed:", err)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
}
+ fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+ buf := buffer.NewView(30)
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
-
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
if err != nil {
@@ -558,58 +635,239 @@ func TestDelayedRemovalDueToRoute(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- // Write a packet, and check that it gets delivered.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
-
- // Get a route, check that packet is still deliverable.
- r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
+ r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatal("FindRoute failed:", err)
}
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 2 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2)
- }
+ // Send and receive packets, and verify they are received.
+ buf[0] = localAddrByte
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSend(t, r, ep, nil)
+ testSendTo(t, s, remoteAddr, ep, nil)
- // Remove the address, then check that packet is still deliverable
- // because the route is keeping the address alive.
- if err := s.RemoveAddress(1, "\x01"); err != nil {
+ // Remove the address, then check that send/receive doesn't work anymore.
+ if err := s.RemoveAddress(1, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 3 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
- }
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
// Check that removing the same address fails.
- if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
+ if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
}
+}
+
+// verifyAddress inspects the fakeNet protocol addresses that NICInfo reports
+// for nicid.
+//
+// When addr is empty, every fakeNet entry must be the zero AddressWithPrefix.
+// Otherwise addr must appear among the fakeNet entries and any other fakeNet
+// address is reported as an error. A missing NIC aborts the test.
+func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) {
+	t.Helper()
+
+	nicInfo, present := s.NICInfo()[nicid]
+	if !present {
+		t.Fatalf("NICInfo() failed to find nicid=%d", nicid)
+	}
+
+	wantNone := len(addr) == 0
+	matched := false
+	for _, pa := range nicInfo.ProtocolAddresses {
+		// Entries of other network protocols are not part of the check.
+		if pa.Protocol != fakeNetNumber {
+			continue
+		}
+		switch {
+		case wantNone:
+			// No address expected: only the zero value is acceptable.
+			if pa.AddressWithPrefix != (tcpip.AddressWithPrefix{}) {
+				t.Errorf("verify no-address: got = %s, want = %s", pa.AddressWithPrefix, (tcpip.AddressWithPrefix{}))
+			}
+		case pa.AddressWithPrefix.Address == addr:
+			matched = true
+		default:
+			// A fakeNet address other than the expected one is assigned.
+			t.Errorf("verify address: got = %s, want = %s", pa.AddressWithPrefix.Address, addr)
+		}
+	}
+
+	if !wantNone && !matched {
+		t.Errorf("verify address: couldn't find %s on the NIC", addr)
+	}
+}
+
+// TestEndpointExpiration walks a NIC through a sequence of address additions
+// and removals and checks, at every step, whether packets can be received and
+// sent. The whole sequence is repeated as a subtest for each combination of
+// promiscuous mode and spoofing, and it includes steps during which a route
+// obtained from FindRoute is still held while the address is removed.
+func TestEndpointExpiration(t *testing.T) {
+ const (
+ localAddrByte byte = 0x01
+ remoteAddr tcpip.Address = "\x03"
+ noAddr tcpip.Address = ""
+ nicid tcpip.NICID = 1
+ )
+ localAddr := tcpip.Address([]byte{localAddrByte})
+
+ for _, promiscuous := range []bool{true, false} {
+ for _, spoofing := range []bool{true, false} {
+ t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+ buf := buffer.NewView(30)
+ buf[0] = localAddrByte
+
+ if promiscuous {
+ if err := s.SetPromiscuousMode(nicid, true); err != nil {
+ t.Fatal("SetPromiscuousMode failed:", err)
+ }
+ }
+
+ if spoofing {
+ if err := s.SetSpoofing(nicid, true); err != nil {
+ t.Fatal("SetSpoofing failed:", err)
+ }
+ }
+
+ // 1. No Address yet, send should only work for spoofing, receive for
+ // promiscuous mode.
+ //-----------------------
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ }
+ if spoofing {
+ // FIXME(b/139841518): Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, ep, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ }
- // Release the route, then check that packet is not deliverable anymore.
- r.Release()
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 3 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
+ // 2. Add Address, everything should work.
+ //-----------------------
+ if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+
+ // 3. Remove the address, send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ }
+ if spoofing {
+ // FIXME(b/139841518): Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, ep, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ }
+
+ // 4. Add Address back, everything should work again.
+ //-----------------------
+ if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+
+ // 5. Take a reference to the endpoint by getting a route. Verify that
+ // we can still send/receive, including sending using the route.
+ //-----------------------
+ r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+ testSend(t, r, ep, nil)
+
+ // 6. Remove the address. Send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ }
+ if spoofing {
+ testSend(t, r, ep, nil)
+ testSendTo(t, s, remoteAddr, ep, nil)
+ } else {
+ testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ }
+
+ // 7. Add Address back, everything should work again.
+ //-----------------------
+ if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+ testSend(t, r, ep, nil)
+
+ // 8. Remove the route, sendTo/recv should still work.
+ //-----------------------
+ r.Release()
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+
+ // 9. Remove the address. Send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ }
+ if spoofing {
+ // FIXME(b/139841518): Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, ep, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ }
+ })
+ }
}
}
func TestPromiscuousMode(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -627,22 +885,15 @@ func TestPromiscuousMode(t *testing.T) {
// Write a packet, and check that it doesn't get delivered as we don't
// have a matching endpoint.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 0 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
- }
+ const localAddrByte byte = 0x01
+ buf[0] = localAddrByte
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
// Set promiscuous mode, then check that packet is delivered.
if err := s.SetPromiscuousMode(1, true); err != nil {
t.Fatal("SetPromiscuousMode failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
// Check that we can't get a route as there is no local address.
_, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
@@ -655,25 +906,24 @@ func TestPromiscuousMode(t *testing.T) {
if err := s.SetPromiscuousMode(1, false); err != nil {
t.Fatal("SetPromiscuousMode failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
}
-func TestAddressSpoofing(t *testing.T) {
- srcAddr := tcpip.Address("\x01")
- dstAddr := tcpip.Address("\x02")
+// TestSpoofingWithAddress checks FindRoute and sending on a NIC that has
+// localAddr assigned: a source address that was never added is rejected while
+// spoofing is off, is accepted once SetSpoofing is enabled, and the assigned
+// address keeps working as a source afterwards.
+func TestSpoofingWithAddress(t *testing.T) {
+ localAddr := tcpip.Address("\x01")
+ nonExistentLocalAddr := tcpip.Address("\x02")
+ dstAddr := tcpip.Address("\x03")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil {
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
@@ -687,61 +937,217 @@ func TestAddressSpoofing(t *testing.T) {
// With address spoofing disabled, FindRoute does not permit an address
// that was not added to the NIC to be used as the source.
- r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err == nil {
+ t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
+ }
+
+ // With address spoofing enabled, FindRoute permits any address to be used
+ // as the source.
+ if err := s.SetSpoofing(1, true); err != nil {
+ t.Fatal("SetSpoofing failed:", err)
+ }
+ r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ if r.LocalAddress != nonExistentLocalAddr {
+ t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr)
+ }
+ if r.RemoteAddress != dstAddr {
+ t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr)
+ }
+ // Sending a packet works.
+ testSendTo(t, s, dstAddr, ep, nil)
+ testSend(t, r, ep, nil)
+
+ // FindRoute should also work with a local address that exists on the NIC.
+ r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ if r.LocalAddress != localAddr {
+ // Report the address actually compared against (localAddr), not the
+ // spoofed one used earlier in the test.
+ t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, localAddr)
+ }
+ if r.RemoteAddress != dstAddr {
+ t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr)
+ }
+ // Sending a packet using the route works.
+ testSend(t, r, ep, nil)
+}
+
+// TestSpoofingNoAddress checks FindRoute and sending on a NIC that has no
+// address assigned: with spoofing off both FindRoute with an unassigned source
+// and sending must fail; with spoofing on FindRoute must succeed and return
+// the requested local and remote addresses.
+func TestSpoofingNoAddress(t *testing.T) {
+ nonExistentLocalAddr := tcpip.Address("\x01")
+ dstAddr := tcpip.Address("\x02")
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ // With address spoofing disabled, FindRoute does not permit an address
+ // that was not added to the NIC to be used as the source.
+ r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err == nil {
t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
}
+ // Sending a packet fails.
+ testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute)
// With address spoofing enabled, FindRoute permits any address to be used
// as the source.
if err := s.SetSpoofing(1, true); err != nil {
t.Fatal("SetSpoofing failed:", err)
}
- r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatal("FindRoute failed:", err)
}
- if r.LocalAddress != srcAddr {
- t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr)
+ if r.LocalAddress != nonExistentLocalAddr {
+ t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr)
}
if r.RemoteAddress != dstAddr {
- t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
+ t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr)
}
+ // Sending a packet should work, but the check is disabled for now.
+ // FIXME(b/139841518): Spoofing doesn't work if there is no primary address.
+ // NOTE(review): the destination in this function is dstAddr; no remoteAddr
+ // is declared here.
+ // testSendTo(t, s, dstAddr, ep, nil)
}
-func TestBroadcastNeedsNoRoute(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+// verifyRoute compares the address fields of gotRoute against wantRoute.
+//
+// The fields are checked in the order LocalAddress, RemoteAddress,
+// RemoteLinkAddress, NextHop; the first mismatch is returned as an error
+// naming the field and both values. nil is returned when all four agree.
+func verifyRoute(gotRoute, wantRoute stack.Route) error {
+	switch {
+	case gotRoute.LocalAddress != wantRoute.LocalAddress:
+		return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress)
+	case gotRoute.RemoteAddress != wantRoute.RemoteAddress:
+		return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress)
+	case gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress:
+		return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress)
+	case gotRoute.NextHop != wantRoute.NextHop:
+		return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop)
+	default:
+		return nil
+	}
+}
+
+// TestOutgoingBroadcastWithEmptyRouteTable checks FindRoute for the IPv4
+// broadcast destination when the route table is empty: it must fail with
+// ErrNetworkUnreachable before an address is added to the NIC, succeed once
+// IPv4Any is added to NIC 1, and fail again for a NIC ID that was never
+// created.
+func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
s.SetRouteTable([]tcpip.Route{})
// If there is no endpoint, it won't work.
if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
- t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
+ t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
}
- if err := s.AddAddress(1, fakeNetNumber, header.IPv4Any); err != nil {
- t.Fatalf("AddAddress(%v, %v) failed: %s", fakeNetNumber, header.IPv4Any, err)
+ protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}}
+ if err := s.AddProtocolAddress(1, protoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err)
}
r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
+ t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
+ }
+ if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
- if r.LocalAddress != header.IPv4Any {
- t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress, header.IPv4Any)
+ // If the NIC doesn't exist, it won't work.
+ if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
+ t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
}
+}
- if r.RemoteAddress != header.IPv4Broadcast {
- t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress, header.IPv4Broadcast)
+// TestOutgoingBroadcastWithRouteTable checks which NIC FindRoute selects for
+// the IPv4 broadcast destination on a stack with two NICs and a populated
+// route table: the given NIC when one is specified, otherwise the first
+// matching route table entry (the default route, or an explicit broadcast
+// route once one is prepended to the table).
+func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
+ defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0}
+ // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1.
+ nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24}
+ nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1.
+ nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24}
+ nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01")
+
+ // Create a new stack with two NICs.
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC failed: %s", err)
+ }
+ if err := s.CreateNIC(2, ep); err != nil {
+ t.Fatalf("CreateNIC failed: %s", err)
+ }
+ nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr}
+ if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err)
}
- // If the NIC doesn't exist, it won't work.
- if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
- t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
+ nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr}
+ if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil {
+ // Name the API that was actually called so a failure is easy to trace.
+ t.Fatalf("AddProtocolAddress(2, %s) failed: %s", nic2ProtoAddr, err)
+ }
+
+ // Set the initial route table.
+ rt := []tcpip.Route{
+ {Destination: nic1Addr.Subnet(), NIC: 1},
+ {Destination: nic2Addr.Subnet(), NIC: 2},
+ {Destination: defaultAddr.Subnet(), Gateway: nic2Gateway, NIC: 2},
+ {Destination: defaultAddr.Subnet(), Gateway: nic1Gateway, NIC: 1},
+ }
+ s.SetRouteTable(rt)
+
+ // When an interface is given, the route for a broadcast goes through it.
+ r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
+ }
+ if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
+ }
+
+ // When an interface is not given, it consults the route table.
+ // 1. Case: Using the default route.
+ r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
+ }
+ if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s", header.IPv4Broadcast, fakeNetNumber, err)
+ }
+
+ // 2. Case: Having an explicit route for broadcast will select that one.
+ rt = append(
+ []tcpip.Route{
+ {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1},
+ },
+ rt...,
+ )
+ s.SetRouteTable(rt)
+ r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
+ }
+ if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s", header.IPv4Broadcast, fakeNetNumber, err)
}
}
@@ -781,10 +1187,12 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
{"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
} {
t.Run(tc.name, func(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -835,12 +1243,14 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
}
}
-// Set the subnet, then check that packet is delivered.
-func TestSubnetAcceptsMatchingPacket(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+// Add a range of addresses, then check that a packet is delivered.
+func TestAddressRangeAcceptsMatchingPacket(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -856,29 +1266,59 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) {
buf := buffer.NewView(30)
- buf[0] = 1
- fakeNet.packetCount[1] = 0
+ const localAddrByte byte = 0x01
+ buf[0] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
}
- if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddSubnet failed:", err)
+ if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil {
+ t.Fatal("AddAddressRange failed:", err)
}
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+}
+
+// testNicForAddressRange calls CheckLocalAddress for every address of subnet
+// and compares the returned NIC ID with the expected one.
+//
+// When rangeExists is true, nicID is expected for each address except the
+// subnet ID and the subnet broadcast address, for which 0 is expected. When
+// rangeExists is false, 0 is expected for every address. The address that
+// follows the last one in the subnet is always expected to yield 0.
+//
+// NOTE(review): only the first address byte is advanced and the address count
+// is derived from 8-subnet.Prefix(), so this presumably assumes the one-byte
+// addresses used by the fake network protocol; confirm before reusing it
+// with longer addresses.
+func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) {
+ t.Helper()
+
+ // Loop over all addresses and check them.
+ numOfAddresses := 1 << uint(8-subnet.Prefix())
+ if numOfAddresses < 1 || numOfAddresses > 255 {
+ t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet)
+ }
+
+ addrBytes := []byte(subnet.ID())
+ for i := 0; i < numOfAddresses; i++ {
+ addr := tcpip.Address(addrBytes)
+ wantNicID := nicID
+ // The subnet and broadcast addresses are skipped.
+ if !rangeExists || addr == subnet.ID() || addr == subnet.Broadcast() {
+ wantNicID = 0
+ }
+ if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, addr); gotNicID != wantNicID {
+ t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, addr, gotNicID, wantNicID)
+ }
+ // Advance to the next address in the range.
+ addrBytes[0]++
+ }
+
+ // Trying the next address should always fail since it is outside the range.
+ if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 {
+ t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0)
}
}
-// Set the subnet, then check that CheckLocalAddress returns the correct NIC.
+// Set a range of addresses, then remove it again, and check at each step that
+// CheckLocalAddress returns the correct NIC for each address or zero if not
+// existent.
func TestCheckLocalAddressForSubnet(t *testing.T) {
const nicID tcpip.NICID = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -891,39 +1331,34 @@ func TestCheckLocalAddressForSubnet(t *testing.T) {
}
subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0"))
-
if err != nil {
t.Fatal("NewSubnet failed:", err)
}
- if err := s.AddSubnet(nicID, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddSubnet failed:", err)
- }
- // Loop over all subnet addresses and check them.
- numOfAddresses := 1 << uint(8-subnet.Prefix())
- if numOfAddresses < 1 || numOfAddresses > 255 {
- t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet)
- }
- addr := []byte(subnet.ID())
- for i := 0; i < numOfAddresses; i++ {
- if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != nicID {
- t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, nicID)
- }
- addr[0]++
+ testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */)
+
+ if err := s.AddAddressRange(nicID, fakeNetNumber, subnet); err != nil {
+ t.Fatal("AddAddressRange failed:", err)
}
- // Trying the next address should fail since it is outside the subnet range.
- if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != 0 {
- t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, 0)
+ testNicForAddressRange(t, nicID, s, subnet, true /* rangeExists */)
+
+ if err := s.RemoveAddressRange(nicID, subnet); err != nil {
+ t.Fatal("RemoveAddressRange failed:", err)
}
+
+ testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */)
}
-// Set destination outside the subnet, then check it doesn't get delivered.
-func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+// Set a range of addresses, then send a packet to a destination outside the
+// range and then check it doesn't get delivered.
+func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -939,23 +1374,23 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
buf := buffer.NewView(30)
- buf[0] = 1
- fakeNet.packetCount[1] = 0
+ const localAddrByte byte = 0x01
+ buf[0] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
}
- if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddSubnet failed:", err)
- }
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 0 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
+ if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil {
+ t.Fatal("AddAddressRange failed:", err)
}
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
}
func TestNetworkOptions(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{},
+ })
// Try an unsupported network protocol.
if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol {
@@ -994,44 +1429,53 @@ func TestNetworkOptions(t *testing.T) {
}
}
-func TestSubnetAddRemove(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+// stackContainsAddressRange reports whether addrRange is among the address
+// ranges that s.NICAddressRanges() lists for the NIC with the given id. It
+// returns false when the NIC has no entry in that map.
+func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.Subnet) bool {
+	// A missing map entry yields a nil slice, so the loop simply does not run.
+	nicRanges := s.NICAddressRanges()[id]
+	for i := range nicRanges {
+		if nicRanges[i] == addrRange {
+			return true
+		}
+	}
+	return false
+}
+
+// TestAddresRangeAddRemove checks that an address range is reported by
+// stackContainsAddressRange only between the AddAddressRange and
+// RemoveAddressRange calls: absent before it is added, present after, and
+// absent again once removed.
+//
+// NOTE(review): "Addres" in the test name looks like a misspelling of
+// "Address"; it is left unchanged here so the test keeps its identifier.
+func TestAddresRangeAddRemove(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
addr := tcpip.Address("\x01\x01\x01\x01")
mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr)))
- subnet, err := tcpip.NewSubnet(addr, mask)
+ addrRange, err := tcpip.NewSubnet(addr, mask)
if err != nil {
t.Fatal("NewSubnet failed:", err)
}
- if contained, err := s.ContainsSubnet(1, subnet); err != nil {
- t.Fatal("ContainsSubnet failed:", err)
- } else if contained {
- t.Fatal("got s.ContainsSubnet(...) = true, want = false")
+ if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want {
+ t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
}
- if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddSubnet failed:", err)
+ if err := s.AddAddressRange(1, fakeNetNumber, addrRange); err != nil {
+ t.Fatal("AddAddressRange failed:", err)
}
- if contained, err := s.ContainsSubnet(1, subnet); err != nil {
- t.Fatal("ContainsSubnet failed:", err)
- } else if !contained {
- t.Fatal("got s.ContainsSubnet(...) = false, want = true")
+ if got, want := stackContainsAddressRange(s, 1, addrRange), true; got != want {
+ t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
}
- if err := s.RemoveSubnet(1, subnet); err != nil {
- t.Fatal("RemoveSubnet failed:", err)
+ if err := s.RemoveAddressRange(1, addrRange); err != nil {
+ t.Fatal("RemoveAddressRange failed:", err)
}
- if contained, err := s.ContainsSubnet(1, subnet); err != nil {
- t.Fatal("ContainsSubnet failed:", err)
- } else if contained {
- t.Fatal("got s.ContainsSubnet(...) = true, want = false")
+ if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want {
+ t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
}
}
@@ -1042,9 +1486,11 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) {
for never := 0; never < 3; never++ {
t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
// Insert <canBe> primary and <never> never-primary addresses.
@@ -1082,20 +1528,20 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
// Check that GetMainNICAddress returns an address if at least
// one primary address was added. In that case make sure the
// address/prefixLen matches what we added.
+ gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("GetMainNICAddress failed:", err)
+ }
if len(primaryAddrAdded) == 0 {
- // No primary addresses present, expect an error.
- if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got s.GetMainNICAddress(...) = %v, wanted = %s", err, tcpip.ErrNoLinkAddress)
+ // No primary addresses present.
+ if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
+ t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr)
}
} else {
- // At least one primary address was added, expect a valid
- // address and prefixLen.
- gotAddressWithPefix, err := s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("GetMainNICAddress failed:", err)
- }
- if _, ok := primaryAddrAdded[gotAddressWithPefix]; !ok {
- t.Fatalf("GetMainNICAddress: got addressWithPrefix = %v, wanted any in {%v}", gotAddressWithPefix, primaryAddrAdded)
+ // At least one primary address was added, verify the returned
+ // address is in the list of primary addresses we added.
+ if _, ok := primaryAddrAdded[gotAddr]; !ok {
+ t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded)
}
}
})
@@ -1107,9 +1553,11 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
}
func TestGetMainNICAddressAddRemove(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1134,19 +1582,25 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
}
// Check that we get the right initial address and prefix length.
- if gotAddressWithPrefix, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil {
+ gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
t.Fatal("GetMainNICAddress failed:", err)
- } else if gotAddressWithPrefix != protocolAddress.AddressWithPrefix {
- t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, protocolAddress.AddressWithPrefix)
+ }
+ if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr {
+ t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
}
if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
- // Check that we get an error after removal.
- if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %s", err, tcpip.ErrNoLinkAddress)
+ // Check that we get no address after removal.
+ gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("GetMainNICAddress failed:", err)
+ }
+ if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
+ t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
}
})
}
@@ -1161,8 +1615,10 @@ func (g *addressGenerator) next(addrLen int) tcpip.Address {
}
func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) {
+ t.Helper()
+
if len(gotAddresses) != len(expectedAddresses) {
- t.Fatalf("got len(addresses) = %d, wanted = %d", len(gotAddresses), len(expectedAddresses))
+ t.Fatalf("got len(addresses) = %d, want = %d", len(gotAddresses), len(expectedAddresses))
}
sort.Slice(gotAddresses, func(i, j int) bool {
@@ -1182,9 +1638,11 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto
func TestAddAddress(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, id); err != nil {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1201,15 +1659,17 @@ func TestAddAddress(t *testing.T) {
})
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddProtocolAddress(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, id); err != nil {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1233,15 +1693,17 @@ func TestAddProtocolAddress(t *testing.T) {
}
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddAddressWithOptions(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, id); err != nil {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1262,15 +1724,17 @@ func TestAddAddressWithOptions(t *testing.T) {
}
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddProtocolAddressWithOptions(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, id); err != nil {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1297,15 +1761,17 @@ func TestAddProtocolAddressWithOptions(t *testing.T) {
}
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestNICStats(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id1, linkEP1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
- t.Fatal("CreateNIC failed:", err)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
+ t.Fatal("CreateNIC failed: ", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatal("AddAddress failed:", err)
@@ -1321,7 +1787,7 @@ func TestNICStats(t *testing.T) {
// Send a packet to address 1.
buf := buffer.NewView(30)
- linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep1.Inject(fakeNetNumber, buf.ToVectorisedView())
if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want)
}
@@ -1332,10 +1798,12 @@ func TestNICStats(t *testing.T) {
payload := buffer.NewView(10)
// Write a packet out via the address for NIC 1
- sendTo(t, s, "\x01", payload)
- want := uint64(linkEP1.Drain())
+ if err := sendTo(s, "\x01", payload); err != nil {
+ t.Fatal("sendTo failed: ", err)
+ }
+ want := uint64(ep1.Drain())
if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want)
+ t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want)
}
if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want {
@@ -1346,19 +1814,21 @@ func TestNICStats(t *testing.T) {
func TestNICForwarding(t *testing.T) {
// Create a stack with the fake network protocol, two NICs, each with
// an address.
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
s.SetForwarding(true)
- id1, linkEP1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC #1 failed:", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatal("AddAddress #1 failed:", err)
}
- id2, linkEP2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC #2 failed:", err)
}
if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
@@ -1377,10 +1847,10 @@ func TestNICForwarding(t *testing.T) {
// Send a packet to address 3.
buf := buffer.NewView(30)
buf[0] = 3
- linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep1.Inject(fakeNetNumber, buf.ToVectorisedView())
select {
- case <-linkEP2.C:
+ case <-ep2.C:
default:
t.Fatal("Packet not forwarded")
}
@@ -1394,9 +1864,3 @@ func TestNICForwarding(t *testing.T) {
t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
}
}
-
-func init() {
- stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol {
- return &fakeNetworkProtocol{}
- })
-}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index cf8a6d129..92267ce4d 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -35,25 +35,109 @@ type protocolIDs struct {
type transportEndpoints struct {
// mu protects all fields of the transportEndpoints.
mu sync.RWMutex
- endpoints map[TransportEndpointID]TransportEndpoint
+ endpoints map[TransportEndpointID]*endpointsByNic
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
rawEndpoints []RawTransportEndpoint
}
+type endpointsByNic struct {
+ mu sync.RWMutex
+ endpoints map[tcpip.NICID]*multiPortEndpoint
+ // seed is a random secret for a jenkins hash.
+ seed uint32
+}
+
+// handlePacket delivers a packet to the endpoints registered for the route's
+// NIC, falling back to those registered under NIC ID 0.
+func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+ epsByNic.mu.RLock()
+
+ mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
+ if !ok {
+ if mpep, ok = epsByNic.endpoints[0]; !ok {
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ return
+ }
+ }
+
+ // If this is a broadcast or multicast datagram, deliver the datagram to all
+ // endpoints bound to the right device.
+ if isMulticastOrBroadcast(id.LocalAddress) {
+ mpep.handlePacketAll(r, id, vv)
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ return
+ }
+
+ // multiPortEndpoints are guaranteed to have at least one element.
+ selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv)
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+}
+
+// handleControlPacket delivers a control packet to the endpoint selected for NIC n.
+func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
+ epsByNic.mu.RLock()
+ defer epsByNic.mu.RUnlock()
+
+ mpep, ok := epsByNic.endpoints[n.ID()]
+ if !ok {
+ mpep, ok = epsByNic.endpoints[0]
+ }
+ if !ok {
+ return
+ }
+
+ // TODO(eyalsoha): Why don't we look at id to see if this packet needs to
+ // broadcast like we are doing with handlePacket above?
+
+ // multiPortEndpoints are guaranteed to have at least one element.
+ selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv)
+}
+
+// registerEndpoint returns nil if it succeeds. It fails and returns an error
+// if the endpoints already bound to bindToDevice do not permit another bind.
+func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+ epsByNic.mu.Lock()
+ defer epsByNic.mu.Unlock()
+
+ if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok {
+ // There was already a bind.
+ return multiPortEp.singleRegisterEndpoint(t, reusePort)
+ }
+
+ // This is a new binding.
+ multiPortEp := &multiPortEndpoint{}
+ multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
+ multiPortEp.reuse = reusePort
+ epsByNic.endpoints[bindToDevice] = multiPortEp
+ return multiPortEp.singleRegisterEndpoint(t, reusePort)
+}
+
+// unregisterEndpoint returns true if endpointsByNic has to be unregistered.
+func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool {
+ epsByNic.mu.Lock()
+ defer epsByNic.mu.Unlock()
+ multiPortEp, ok := epsByNic.endpoints[bindToDevice]
+ if !ok {
+ return false
+ }
+ if multiPortEp.unregisterEndpoint(t) {
+ delete(epsByNic.endpoints, bindToDevice)
+ }
+ return len(epsByNic.endpoints) == 0
+}
+
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) {
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
eps.mu.Lock()
defer eps.mu.Unlock()
- e, ok := eps.endpoints[id]
+ epsByNic, ok := eps.endpoints[id]
if !ok {
return
}
- if multiPortEp, ok := e.(*multiPortEndpoint); ok {
- if !multiPortEp.unregisterEndpoint(ep) {
- return
- }
+ if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
+ return
}
delete(eps.endpoints, id)
}
@@ -75,7 +159,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
for netProto := range stack.networkProtocols {
for proto := range stack.transportProtocols {
d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{
- endpoints: make(map[TransportEndpointID]TransportEndpoint),
+ endpoints: make(map[TransportEndpointID]*endpointsByNic),
}
}
}
@@ -85,10 +169,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
// registerEndpoint registers the given endpoint with the dispatcher such that
// packets that match the endpoint ID are delivered to it.
-func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
for i, n := range netProtos {
- if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil {
- d.unregisterEndpoint(netProtos[:i], protocol, id, ep)
+ if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil {
+ d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice)
return err
}
}
@@ -97,13 +181,14 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
}
// multiPortEndpoint is a container for TransportEndpoints which are bound to
-// the same pair of address and port.
+// the same pair of address and port. endpointsArr always has at least one
+// element.
type multiPortEndpoint struct {
mu sync.RWMutex
endpointsArr []TransportEndpoint
endpointsMap map[TransportEndpoint]int
- // seed is a random secret for a jenkins hash.
- seed uint32
+ // reuse indicates if more than one endpoint is allowed.
+ reuse bool
}
// reciprocalScale scales a value into range [0, n).
@@ -117,9 +202,10 @@ func reciprocalScale(val, n uint32) uint32 {
// selectEndpoint calculates a hash of destination and source addresses and
// ports then uses it to select a socket. In this case, all packets from one
// address will be sent to same endpoint.
-func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint {
- ep.mu.RLock()
- defer ep.mu.RUnlock()
+func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
+ if len(mpep.endpointsArr) == 1 {
+ return mpep.endpointsArr[0]
+ }
payload := []byte{
byte(id.LocalPort),
@@ -128,51 +214,50 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd
byte(id.RemotePort >> 8),
}
- h := jenkins.Sum32(ep.seed)
+ h := jenkins.Sum32(seed)
h.Write(payload)
h.Write([]byte(id.LocalAddress))
h.Write([]byte(id.RemoteAddress))
hash := h.Sum32()
- idx := reciprocalScale(hash, uint32(len(ep.endpointsArr)))
- return ep.endpointsArr[idx]
+ idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr)))
+ return mpep.endpointsArr[idx]
}
-// HandlePacket is called by the stack when new packets arrive to this transport
-// endpoint.
-func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
- // If this is a broadcast or multicast datagram, deliver the datagram to all
- // endpoints managed by ep.
- if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
- for i, endpoint := range ep.endpointsArr {
- // HandlePacket modifies vv, so each endpoint needs its own copy.
- if i == len(ep.endpointsArr)-1 {
- endpoint.HandlePacket(r, id, vv)
- break
- }
- vvCopy := buffer.NewView(vv.Size())
- copy(vvCopy, vv.ToView())
- endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
+func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+ ep.mu.RLock()
+ for i, endpoint := range ep.endpointsArr {
+ // HandlePacket modifies vv, so each endpoint needs its own copy except for
+ // the final one.
+ if i == len(ep.endpointsArr)-1 {
+ endpoint.HandlePacket(r, id, vv)
+ break
}
- } else {
- ep.selectEndpoint(id).HandlePacket(r, id, vv)
+ vvCopy := buffer.NewView(vv.Size())
+ copy(vvCopy, vv.ToView())
+ endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
}
+ ep.mu.RUnlock() // Don't use defer for performance reasons.
}
-// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
- ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv)
-}
-
-func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) {
+// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
+// list. The list may still be empty, in which case this is the first binding.
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- // A new endpoint is added into endpointsArr and its index there is
- // saved in endpointsMap. This will allows to remove endpoint from
- // the array fast.
+ if len(ep.endpointsArr) > 0 {
+ // If it was previously bound, we need to check if we can bind again.
+ if !ep.reuse || !reusePort {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ // A new endpoint is added into endpointsArr and its index there is saved in
+ // endpointsMap. This will allow us to remove endpoint from the array fast.
ep.endpointsMap[t] = len(ep.endpointsArr)
ep.endpointsArr = append(ep.endpointsArr, t)
+ return nil
}
// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
@@ -197,53 +282,41 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
return true
}
-func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
if id.RemotePort != 0 {
+ // TODO(eyalsoha): Why?
reusePort = false
}
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
if !ok {
- return nil
+ return tcpip.ErrUnknownProtocol
}
eps.mu.Lock()
defer eps.mu.Unlock()
- var multiPortEp *multiPortEndpoint
- if _, ok := eps.endpoints[id]; ok {
- if !reusePort {
- return tcpip.ErrPortInUse
- }
- multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint)
- if !ok {
- return tcpip.ErrPortInUse
- }
+ if epsByNic, ok := eps.endpoints[id]; ok {
+ // There was already a binding.
+ return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
}
- if reusePort {
- if multiPortEp == nil {
- multiPortEp = &multiPortEndpoint{}
- multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
- multiPortEp.seed = rand.Uint32()
- eps.endpoints[id] = multiPortEp
- }
-
- multiPortEp.singleRegisterEndpoint(ep)
-
- return nil
+ // This is a new binding.
+ epsByNic := &endpointsByNic{
+ endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
+ seed: rand.Uint32(),
}
- eps.endpoints[id] = ep
+ eps.endpoints[id] = epsByNic
- return nil
+ return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
for _, n := range netProtos {
if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
- eps.unregisterEndpoint(id, ep)
+ eps.unregisterEndpoint(id, ep, bindToDevice)
}
}
}
@@ -265,23 +338,14 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
return false
}
- // If a sender bound to the Loopback interface sends a broadcast,
- // that broadcast must not be delivered to the sender.
- if loopbackSubnet.Contains(r.RemoteAddress) && r.LocalAddress == header.IPv4Broadcast && id.LocalPort == id.RemotePort {
- return false
- }
-
- // If the packet is a broadcast, then find all matching transport endpoints.
- // Otherwise, try to find a single matching transport endpoint.
- destEps := make([]TransportEndpoint, 0, 1)
eps.mu.RLock()
- if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast {
- for epID, endpoint := range eps.endpoints {
- if epID.LocalPort == id.LocalPort {
- destEps = append(destEps, endpoint)
- }
- }
+ // Determine which transport endpoint or endpoints to deliver this packet to.
+ // If the packet is a broadcast or multicast, then find all matching
+ // transport endpoints.
+ var destEps []*endpointsByNic
+ if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
+ destEps = d.findAllEndpointsLocked(eps, vv, id)
} else if ep := d.findEndpointLocked(eps, vv, id); ep != nil {
destEps = append(destEps, ep)
}
@@ -299,7 +363,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// Deliver the packet.
for _, ep := range destEps {
- ep.HandlePacket(r, id, vv)
+ ep.handlePacket(r, id, vv)
}
return true
@@ -331,7 +395,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
// deliverControlPacket attempts to deliver the given control packet. Returns
// true if it found an endpoint, false otherwise.
-func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{net, trans}]
if !ok {
return false
@@ -348,15 +412,16 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber,
}
// Deliver the packet.
- ep.HandleControlPacket(id, typ, extra, vv)
+ ep.handleControlPacket(n, id, typ, extra, vv)
return true
}
-func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
+func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) []*endpointsByNic {
+ var matchedEPs []*endpointsByNic
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
- return ep
+ matchedEPs = append(matchedEPs, ep)
}
// Try to find a match with the id minus the local address.
@@ -364,7 +429,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer
nid.LocalAddress = ""
if ep, ok := eps.endpoints[nid]; ok {
- return ep
+ matchedEPs = append(matchedEPs, ep)
}
// Try to find a match with the id minus the remote part.
@@ -372,15 +437,24 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer
nid.RemoteAddress = ""
nid.RemotePort = 0
if ep, ok := eps.endpoints[nid]; ok {
- return ep
+ matchedEPs = append(matchedEPs, ep)
}
// Try to find a match with only the local port.
nid.LocalAddress = ""
if ep, ok := eps.endpoints[nid]; ok {
- return ep
+ matchedEPs = append(matchedEPs, ep)
}
+ return matchedEPs
+}
+
+// findEndpointLocked returns the endpoint that most closely matches the given
+// id.
+func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic {
+ if matchedEPs := d.findAllEndpointsLocked(eps, vv, id); len(matchedEPs) > 0 {
+ return matchedEPs[0]
+ }
return nil
}
@@ -418,3 +492,7 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
}
}
}
+
+func isMulticastOrBroadcast(addr tcpip.Address) bool {
+ return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
+}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
new file mode 100644
index 000000000..210233dc0
--- /dev/null
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -0,0 +1,352 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "math"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
+ stackAddr = "\x0a\x00\x00\x01"
+ stackPort = 1234
+ testPort = 4096
+)
+
+type testContext struct {
+ t *testing.T
+ linkEPs map[string]*channel.Endpoint
+ s *stack.Stack
+
+ ep tcpip.Endpoint
+ wq waiter.Queue
+}
+
+func (c *testContext) cleanup() {
+ if c.ep != nil {
+ c.ep.Close()
+ }
+}
+
+func (c *testContext) createV6Endpoint(v6only bool) {
+ var err *tcpip.Error
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ var v tcpip.V6OnlyOption
+ if v6only {
+ v = 1
+ }
+ if err := c.ep.SetSockOpt(v); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+}
+
+// newDualTestContextMultiNic creates the testing context and one named NIC
+// for each entry in linkEpNames.
+func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+ linkEPs := make(map[string]*channel.Endpoint)
+ for i, linkEpName := range linkEpNames {
+ channelEP := channel.New(256, mtu, "")
+ nicid := tcpip.NICID(i + 1)
+ if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+ linkEPs[linkEpName] = channelEP
+
+ if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil {
+ t.Fatalf("AddAddress IPv4 failed: %v", err)
+ }
+
+ if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil {
+ t.Fatalf("AddAddress IPv6 failed: %v", err)
+ }
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: 1,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: 1,
+ },
+ })
+
+ return &testContext{
+ t: t,
+ s: s,
+ linkEPs: linkEPs,
+ }
+}
+
+type headers struct {
+ srcPort uint16
+ dstPort uint16
+}
+
+func newPayload() []byte {
+ b := make([]byte, 30+rand.Intn(100))
+ for i := range b {
+ b[i] = byte(rand.Intn(256))
+ }
+ return b
+}
+
+func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) {
+ // Allocate a buffer for data and headers.
+ buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
+ copy(buf[len(buf)-len(payload):], payload)
+
+ // Initialize the IP header.
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 65,
+ SrcAddr: testV6Addr,
+ DstAddr: stackV6Addr,
+ })
+
+ // Initialize the UDP header.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.Encode(&header.UDPFields{
+ SrcPort: h.srcPort,
+ DstPort: h.dstPort,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
+ })
+
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u)))
+
+ // Calculate the UDP checksum and set it.
+ xsum = header.Checksum(payload, xsum)
+ u.SetChecksum(^u.CalculateChecksum(xsum))
+
+ // Inject packet.
+ c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
+}
+
+func TestTransportDemuxerRegister(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ want *tcpip.Error
+ }{
+ {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol},
+ {"success", ipv4.ProtocolNumber, nil},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+ if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want {
+ t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want)
+ }
+ })
+ }
+}
+
+// TestDistribution injects varied packets on input devices and checks that
+// the distribution of packets received matches expectations.
+func TestDistribution(t *testing.T) {
+ type endpointSockopts struct {
+ reuse int
+ bindToDevice string
+ }
+ for _, test := range []struct {
+ name string
+ // endpoints will receive the injected packets.
+ endpoints []endpointSockopts
+ // wantedDistributions is the wanted ratio of packets received on each
+ // endpoint for each NIC on which packets are injected.
+ wantedDistributions map[string][]float64
+ }{
+ {
+ "BindPortReuse",
+ // 5 endpoints that all have reuse set.
+ []endpointSockopts{
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 get distributed evenly.
+ "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2},
+ },
+ },
+ {
+ "BindToDevice",
+ // 3 endpoints with various bindings.
+ []endpointSockopts{
+ endpointSockopts{0, "dev0"},
+ endpointSockopts{0, "dev1"},
+ endpointSockopts{0, "dev2"},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 go only to the endpoint bound to dev0.
+ "dev0": []float64{1, 0, 0},
+ // Injected packets on dev1 go only to the endpoint bound to dev1.
+ "dev1": []float64{0, 1, 0},
+ // Injected packets on dev2 go only to the endpoint bound to dev2.
+ "dev2": []float64{0, 0, 1},
+ },
+ },
+ {
+ "ReuseAndBindToDevice",
+ // 6 endpoints with various bindings.
+ []endpointSockopts{
+ endpointSockopts{1, "dev0"},
+ endpointSockopts{1, "dev0"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, ""},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 get distributed among endpoints bound to
+ // dev0.
+ "dev0": []float64{0.5, 0.5, 0, 0, 0, 0},
+ // Injected packets on dev1 get distributed among endpoints bound to
+ // dev1; the unbound endpoint receives none.
+ "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
+ // Injected packets on dev999 go only to the unbound.
+ "dev999": []float64{0, 0, 0, 0, 0, 1},
+ },
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ for device, wantedDistribution := range test.wantedDistributions {
+ t.Run(device, func(t *testing.T) {
+ var devices []string
+ for d := range test.wantedDistributions {
+ devices = append(devices, d)
+ }
+ c := newDualTestContextMultiNic(t, defaultMTU, devices)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ eps := make(map[tcpip.Endpoint]int)
+
+ pollChannel := make(chan tcpip.Endpoint)
+ for i, endpoint := range test.endpoints {
+ // Try to receive the data.
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+
+ var err *tcpip.Error
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ eps[ep] = i
+
+ go func(ep tcpip.Endpoint) {
+ for range ch {
+ pollChannel <- ep
+ }
+ }(ep)
+
+ defer ep.Close()
+ reusePortOption := tcpip.ReusePortOption(endpoint.reuse)
+ if err := ep.SetSockOpt(reusePortOption); err != nil {
+ c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err)
+ }
+ bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
+ if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
+ c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
+ t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err)
+ }
+ }
+
+ npackets := 100000
+ nports := 10000
+ if got, want := len(test.endpoints), len(wantedDistribution); got != want {
+ t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
+ }
+ ports := make(map[uint16]tcpip.Endpoint)
+ stats := make(map[tcpip.Endpoint]int)
+ for i := 0; i < npackets; i++ {
+ // Send a packet.
+ port := uint16(i % nports)
+ payload := newPayload()
+ c.sendV6Packet(payload,
+ &headers{
+ srcPort: testPort + port,
+ dstPort: stackPort},
+ device)
+
+ var addr tcpip.FullAddress
+ ep := <-pollChannel
+ _, _, err := ep.Read(&addr)
+ if err != nil {
+ c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err)
+ }
+ stats[ep]++
+ if i < nports {
+ ports[uint16(i)] = ep
+ } else {
+ // Check that all packets from one client are handled by the same
+ // socket.
+ if want, got := ports[port], ep; want != got {
+ t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got])
+ }
+ }
+ }
+
+ // Check that a packet distribution is as expected.
+ for ep, i := range eps {
+ wantedRatio := wantedDistribution[i]
+ wantedRecv := wantedRatio * float64(npackets)
+ actualRecv := stats[ep]
+ actualRatio := float64(stats[ep]) / float64(npackets)
+ // The actual ratio must be within 0.05 (5 percentage points) of the wanted ratio.
+ if math.Abs(actualRatio-wantedRatio) > 0.05 {
+ t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets)
+ }
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 5335897f5..86c62be25 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -38,9 +38,8 @@ const (
// Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't
// use it.
type fakeTransportEndpoint struct {
- id stack.TransportEndpointID
+ stack.TransportEndpointInfo
stack *stack.Stack
- netProto tcpip.NetworkProtocolNumber
proto *fakeTransportProtocol
peerAddr tcpip.Address
route stack.Route
@@ -49,8 +48,16 @@ type fakeTransportEndpoint struct {
acceptQueue []fakeTransportEndpoint
}
-func newFakeTransportEndpoint(stack *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint {
- return &fakeTransportEndpoint{stack: stack, netProto: netProto, proto: proto}
+func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo {
+ return &f.TransportEndpointInfo
+}
+
+func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats {
+ return nil
+}
+
+func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint {
+ return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto}
}
func (f *fakeTransportEndpoint) Close() {
@@ -65,17 +72,17 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr
return buffer.View{}, tcpip.ControlMessages{}, nil
}
-func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
if len(f.route.RemoteAddress) == 0 {
return 0, nil, tcpip.ErrNoRoute
}
hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()))
- v, err := p.Get(p.Size())
+ v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
- if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), fakeTransNumber, 123); err != nil {
+ if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}); err != nil {
return 0, nil, err
}
@@ -91,6 +98,11 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// SetSockOptInt sets a socket option. Currently not supported.
+func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
@@ -121,8 +133,8 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
defer r.Release()
// Try to register so that we can start receiving packets.
- f.id.RemoteAddress = addr.Addr
- err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false)
+ f.ID.RemoteAddress = addr.Addr
+ err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, false /* reuse */, 0 /* bindToDevice */)
if err != nil {
return err
}
@@ -163,7 +175,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
fakeTransNumber,
stack.TransportEndpointID{LocalAddress: a.Addr},
f,
- false,
+ false, /* reuse */
+ 0, /* bindtoDevice */
); err != nil {
return err
}
@@ -184,9 +197,11 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE
f.proto.packetCount++
if f.acceptQueue != nil {
f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
- id: id,
- stack: f.stack,
- netProto: f.netProto,
+ stack: f.stack,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ ID: f.ID,
+ NetProto: f.NetProto,
+ },
proto: f.proto,
peerAddr: r.RemoteAddress,
route: r.Clone(),
@@ -251,7 +266,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool {
return true
}
@@ -277,10 +292,17 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
}
}
+func fakeTransFactory() stack.TransportProtocol {
+ return &fakeTransportProtocol{}
+}
+
func TestTransportReceive(t *testing.T) {
- id, linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ linkEP := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
+ if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -340,9 +362,12 @@ func TestTransportReceive(t *testing.T) {
}
func TestTransportControlReceive(t *testing.T) {
- id, linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ linkEP := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
+ if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -408,9 +433,12 @@ func TestTransportControlReceive(t *testing.T) {
}
func TestTransportSend(t *testing.T) {
- id, _ := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ linkEP := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
+ if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -452,7 +480,10 @@ func TestTransportSend(t *testing.T) {
}
func TestTransportOptions(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
// Try an unsupported transport protocol.
if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol {
@@ -493,20 +524,23 @@ func TestTransportOptions(t *testing.T) {
}
func TestTransportForwarding(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
s.SetForwarding(true)
// TODO(b/123449044): Change this to a channel NIC.
- id1 := loopback.New()
- if err := s.CreateNIC(1, id1); err != nil {
+ ep1 := loopback.New()
+ if err := s.CreateNIC(1, ep1); err != nil {
t.Fatalf("CreateNIC #1 failed: %v", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress #1 failed: %v", err)
}
- id2, linkEP2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
t.Fatalf("CreateNIC #2 failed: %v", err)
}
if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
@@ -545,7 +579,7 @@ func TestTransportForwarding(t *testing.T) {
req[0] = 1
req[1] = 3
req[2] = byte(fakeTransNumber)
- linkEP2.Inject(fakeNetNumber, req.ToVectorisedView())
+ ep2.Inject(fakeNetNumber, req.ToVectorisedView())
aep, _, err := ep.Accept()
if err != nil || aep == nil {
@@ -559,7 +593,7 @@ func TestTransportForwarding(t *testing.T) {
var p channel.PacketInfo
select {
- case p = <-linkEP2.C:
+ case p = <-ep2.C:
default:
t.Fatal("Response packet not forwarded")
}
@@ -571,9 +605,3 @@ func TestTransportForwarding(t *testing.T) {
t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
}
}
-
-func init() {
- stack.RegisterTransportProtocolFactory("fakeTrans", func() stack.TransportProtocol {
- return &fakeTransportProtocol{}
- })
-}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 043dd549b..678a94616 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -57,6 +57,9 @@ type Error struct {
// String implements fmt.Stringer.String.
func (e *Error) String() string {
+ if e == nil {
+ return "<nil>"
+ }
return e.msg
}
@@ -219,6 +222,15 @@ func (s *Subnet) Mask() AddressMask {
return s.mask
}
+// Broadcast returns the subnet's broadcast address.
+func (s *Subnet) Broadcast() Address {
+ addr := []byte(s.address)
+ for i := range addr {
+ addr[i] |= ^s.mask[i]
+ }
+ return Address(addr)
+}
+
// NICID is a number that uniquely identifies a NIC.
type NICID int32
@@ -252,31 +264,34 @@ type FullAddress struct {
Port uint16
}
-// Payload provides an interface around data that is being sent to an endpoint.
-// This allows the endpoint to request the amount of data it needs based on
-// internal buffers without exposing them. 'p.Get(p.Size())' reads all the data.
-type Payload interface {
- // Get returns a slice containing exactly 'min(size, p.Size())' bytes.
- Get(size int) ([]byte, *Error)
-
- // Size returns the payload size.
- Size() int
+// Payloader is an interface that provides data.
+//
+// This interface allows the endpoint to request the amount of data it needs
+// based on internal buffers without exposing them.
+type Payloader interface {
+ // FullPayload returns all available bytes.
+ FullPayload() ([]byte, *Error)
+
+ // Payload returns a slice containing at most size bytes.
+ Payload(size int) ([]byte, *Error)
}
-// SlicePayload implements Payload on top of slices for convenience.
+// SlicePayload implements Payloader for slices.
+//
+// This is typically used for tests.
type SlicePayload []byte
-// Get implements Payload.
-func (s SlicePayload) Get(size int) ([]byte, *Error) {
- if size > s.Size() {
- size = s.Size()
- }
- return s[:size], nil
+// FullPayload implements Payloader.FullPayload.
+func (s SlicePayload) FullPayload() ([]byte, *Error) {
+ return s, nil
}
-// Size implements Payload.
-func (s SlicePayload) Size() int {
- return len(s)
+// Payload implements Payloader.Payload.
+func (s SlicePayload) Payload(size int) ([]byte, *Error) {
+ if size > len(s) {
+ size = len(s)
+ }
+ return s[:size], nil
}
// A ControlMessages contains socket control messages for IP sockets.
@@ -329,7 +344,7 @@ type Endpoint interface {
// ErrNoLinkAddress and a notification channel is returned for the caller to
// block. Channel is closed once address resolution is complete (success or
// not). The channel is only non-nil in this case.
- Write(Payload, WriteOptions) (int64, <-chan struct{}, *Error)
+ Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error)
// Peek reads data without consuming it from the endpoint.
//
@@ -389,6 +404,10 @@ type Endpoint interface {
// SetSockOpt sets a socket option. opt should be one of the *Option types.
SetSockOpt(opt interface{}) *Error
+ // SetSockOptInt sets a socket option, for simple cases where a value
+ // has the int type.
+ SetSockOptInt(opt SockOpt, v int) *Error
+
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// *Option types.
GetSockOpt(opt interface{}) *Error
@@ -410,6 +429,26 @@ type Endpoint interface {
// IPTables returns the iptables for this endpoint's stack.
IPTables() (iptables.IPTables, error)
+
+ // Info returns a copy of the transport endpoint info.
+ Info() EndpointInfo
+
+ // Stats returns a reference to the endpoint stats.
+ Stats() EndpointStats
+}
+
+// EndpointInfo is the interface implemented by each endpoint info struct.
+type EndpointInfo interface {
+ // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
+ // marker interface.
+ IsEndpointInfo()
+}
+
+// EndpointStats is the interface implemented by each endpoint stats struct.
+type EndpointStats interface {
+ // IsEndpointStats is an empty method to implement the tcpip.EndpointStats
+ // marker interface.
+ IsEndpointStats()
}
// WriteOptions contains options for Endpoint.Write.
@@ -423,16 +462,33 @@ type WriteOptions struct {
// EndOfRecord has the same semantics as Linux's MSG_EOR.
EndOfRecord bool
+
+ // Atomic means that all data fetched from Payloader must be written to the
+ // endpoint. If Atomic is false, then data fetched from the Payloader may be
+ // discarded if available endpoint buffer space is insufficient.
+ Atomic bool
}
// SockOpt represents socket options which values have the int type.
type SockOpt int
const (
- // ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of
- // unread bytes in the input buffer should be returned.
+ // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the
+ // number of unread bytes in the input buffer should be returned.
ReceiveQueueSizeOption SockOpt = iota
+ // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
+ // specify the send buffer size option.
+ SendBufferSizeOption
+
+ // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
+ // specify the receive buffer size option.
+ ReceiveBufferSizeOption
+
+ // SendQueueSizeOption is used in GetSockOptInt to specify that the
+ // number of unread bytes in the output buffer should be returned.
+ SendQueueSizeOption
+
// TODO(b/137664753): convert all int socket options to be handled via
// GetSockOptInt.
)
@@ -441,18 +497,6 @@ const (
// the endpoint should be cleared and returned.
type ErrorOption struct{}
-// SendBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the send
-// buffer size option.
-type SendBufferSizeOption int
-
-// ReceiveBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the
-// receive buffer size option.
-type ReceiveBufferSizeOption int
-
-// SendQueueSizeOption is used in GetSockOpt to specify that the number of
-// unread bytes in the output buffer should be returned.
-type SendQueueSizeOption int
-
// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6
// socket is to be restricted to sending and receiving IPv6 packets only.
type V6OnlyOption int
@@ -474,6 +518,10 @@ type ReuseAddressOption int
// to be bound to an identical socket address.
type ReusePortOption int
+// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
+// should bind only on a specific NIC.
+type BindToDeviceOption string
+
// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
type QuickAckOption int
@@ -525,6 +573,12 @@ type ModerateReceiveBufferOption bool
// Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option.
type MaxSegOption int
+// TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop
+// limit value for unicast messages. The default is protocol specific.
+//
+// A zero value indicates the default.
+type TTLOption uint8
+
// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
// TTL value for multicast messages. The default is 1.
type MulticastTTLOption uint8
@@ -566,6 +620,18 @@ type OutOfBandInlineOption int
// datagram sockets are allowed to send packets to a broadcast address.
type BroadcastOption int
+// DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify
+// a default TTL.
+type DefaultTTLOption uint8
+
+// IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS
+// for all subsequent outgoing IPv4 packets from the endpoint.
+type IPv4TOSOption uint8
+
+// IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS
+// for all subsequent outgoing IPv6 packets from the endpoint.
+type IPv6TrafficClassOption uint8
+
// Route is a row in the routing table. It specifies through which NIC (and
// gateway) sets of packets should be routed. A row is considered viable if the
// masked target address matches the destination address in the row.
@@ -581,7 +647,7 @@ type Route struct {
}
// String implements the fmt.Stringer interface.
-func (r *Route) String() string {
+func (r Route) String() string {
var out strings.Builder
fmt.Fprintf(&out, "%s", r.Destination)
if len(r.Gateway) > 0 {
@@ -591,9 +657,6 @@ func (r *Route) String() string {
return out.String()
}
-// LinkEndpointID represents a data link layer endpoint.
-type LinkEndpointID uint64
-
// TransportProtocolNumber is the number of a transport protocol.
type TransportProtocolNumber uint32
@@ -720,6 +783,10 @@ type ICMPv4SentPacketStats struct {
// Dropped is the total number of ICMPv4 packets dropped due to link
// layer errors.
Dropped *StatCounter
+
+ // RateLimited is the total number of ICMPv4 packets dropped due to
+ // rate limit being exceeded.
+ RateLimited *StatCounter
}
// ICMPv4ReceivedPacketStats collects inbound ICMPv4-specific stats.
@@ -738,6 +805,10 @@ type ICMPv6SentPacketStats struct {
// Dropped is the total number of ICMPv6 packets dropped due to link
// layer errors.
Dropped *StatCounter
+
+ // RateLimited is the total number of ICMPv6 packets dropped due to
+ // rate limit being exceeded.
+ RateLimited *StatCounter
}
// ICMPv6ReceivedPacketStats collects inbound ICMPv6-specific stats.
@@ -790,6 +861,14 @@ type IPStats struct {
// OutgoingPacketErrors is the total number of IP packets which failed
// to write to a link-layer endpoint.
OutgoingPacketErrors *StatCounter
+
+ // MalformedPacketsReceived is the total number of IP Packets that were
+ // dropped due to the IP packet header failing validation checks.
+ MalformedPacketsReceived *StatCounter
+
+ // MalformedFragmentsReceived is the total number of IP Fragments that were
+ // dropped due to the fragment failing validation checks.
+ MalformedFragmentsReceived *StatCounter
}
// TCPStats collects TCP-specific stats.
@@ -836,6 +915,9 @@ type TCPStats struct {
// SegmentsSent is the number of TCP segments sent.
SegmentsSent *StatCounter
+ // SegmentSendErrors is the number of TCP segments that failed to be sent.
+ SegmentSendErrors *StatCounter
+
// ResetsSent is the number of TCP resets sent.
ResetsSent *StatCounter
@@ -888,6 +970,9 @@ type UDPStats struct {
// PacketsSent is the number of UDP datagrams sent via sendUDP.
PacketsSent *StatCounter
+
+ // PacketSendErrors is the number of datagrams that failed to be sent.
+ PacketSendErrors *StatCounter
}
// Stats holds statistics about the networking stack.
@@ -898,7 +983,7 @@ type Stats struct {
// stack that were for an unknown or unsupported protocol.
UnknownProtocolRcvdPackets *StatCounter
- // MalformedRcvPackets is the number of packets received by the stack
+ // MalformedRcvdPackets is the number of packets received by the stack
// that were deemed malformed.
MalformedRcvdPackets *StatCounter
@@ -918,18 +1003,95 @@ type Stats struct {
UDP UDPStats
}
+// ReceiveErrors collects packet receive errors within transport endpoint.
+type ReceiveErrors struct {
+ // ReceiveBufferOverflow is the number of received packets dropped
+ // due to the receive buffer being full.
+ ReceiveBufferOverflow StatCounter
+
+ // MalformedPacketsReceived is the number of incoming packets
+ // dropped due to the packet header being in a malformed state.
+ MalformedPacketsReceived StatCounter
+
+ // ClosedReceiver is the number of received packets dropped because
+ // of receiving endpoint state being closed.
+ ClosedReceiver StatCounter
+}
+
+// SendErrors collects packet send errors within the transport layer for
+// an endpoint.
+type SendErrors struct {
+ // SendToNetworkFailed is the number of packets failed to be written to
+ // the network endpoint.
+ SendToNetworkFailed StatCounter
+
+ // NoRoute is the number of times we failed to resolve IP route.
+ NoRoute StatCounter
+
+ // NoLinkAddr is the number of times we failed to resolve ARP.
+ NoLinkAddr StatCounter
+}
+
+// ReadErrors collects segment read errors from an endpoint read call.
+type ReadErrors struct {
+ // ReadClosed is the number of received packet drops because the endpoint
+ // was shutdown for read.
+ ReadClosed StatCounter
+
+ // InvalidEndpointState is the number of times we found the endpoint state
+ // to be unexpected.
+ InvalidEndpointState StatCounter
+}
+
+// WriteErrors collects packet write errors from an endpoint write call.
+type WriteErrors struct {
+ // WriteClosed is the number of packet drops because the endpoint
+ // was shutdown for write.
+ WriteClosed StatCounter
+
+ // InvalidEndpointState is the number of times we found the endpoint state
+ // to be unexpected.
+ InvalidEndpointState StatCounter
+
+ // InvalidArgs is the number of times invalid input arguments were
+ // provided for endpoint Write call.
+ InvalidArgs StatCounter
+}
+
+// TransportEndpointStats collects statistics about the endpoint.
+type TransportEndpointStats struct {
+ // PacketsReceived is the number of successful packet receives.
+ PacketsReceived StatCounter
+
+ // PacketsSent is the number of successful packet sends.
+ PacketsSent StatCounter
+
+ // ReceiveErrors collects packet receive errors within transport layer.
+ ReceiveErrors ReceiveErrors
+
+ // ReadErrors collects packet read errors from an endpoint read call.
+ ReadErrors ReadErrors
+
+ // SendErrors collects packet send errors within the transport layer.
+ SendErrors SendErrors
+
+ // WriteErrors collects packet write errors from an endpoint write call.
+ WriteErrors WriteErrors
+}
+
+// IsEndpointStats is an empty method to implement the tcpip.EndpointStats
+// marker interface.
+func (*TransportEndpointStats) IsEndpointStats() {}
+
func fillIn(v reflect.Value) {
for i := 0; i < v.NumField(); i++ {
v := v.Field(i)
- switch v.Kind() {
- case reflect.Ptr:
- if s := v.Addr().Interface().(**StatCounter); *s == nil {
- *s = &StatCounter{}
+ if s, ok := v.Addr().Interface().(**StatCounter); ok {
+ if *s == nil {
+ *s = new(StatCounter)
}
- case reflect.Struct:
+ } else {
fillIn(v)
- default:
- panic(fmt.Sprintf("unexpected type %s", v.Type()))
}
}
}
@@ -940,6 +1102,26 @@ func (s Stats) FillIn() Stats {
return s
}
+// Clone returns a copy of the TransportEndpointStats by atomically reading
+// each field.
+func (src *TransportEndpointStats) Clone() TransportEndpointStats {
+ var dst TransportEndpointStats
+ clone(reflect.ValueOf(&dst).Elem(), reflect.ValueOf(src).Elem())
+ return dst
+}
+
+func clone(dst reflect.Value, src reflect.Value) {
+ for i := 0; i < dst.NumField(); i++ {
+ d := dst.Field(i)
+ s := src.Field(i)
+ if c, ok := s.Addr().Interface().(*StatCounter); ok {
+ d.Addr().Interface().(*StatCounter).IncrementBy(c.Value())
+ } else {
+ clone(d, s)
+ }
+ }
+}
+
// String implements the fmt.Stringer interface.
func (a Address) String() string {
switch len(a) {
@@ -1065,6 +1247,47 @@ func (a AddressWithPrefix) String() string {
return fmt.Sprintf("%s/%d", a.Address, a.PrefixLen)
}
+// Subnet converts the address and prefix into a Subnet value and returns it.
+func (a AddressWithPrefix) Subnet() Subnet {
+ addrLen := len(a.Address)
+ if a.PrefixLen <= 0 {
+ return Subnet{
+ address: Address(strings.Repeat("\x00", addrLen)),
+ mask: AddressMask(strings.Repeat("\x00", addrLen)),
+ }
+ }
+ if a.PrefixLen >= addrLen*8 {
+ return Subnet{
+ address: a.Address,
+ mask: AddressMask(strings.Repeat("\xff", addrLen)),
+ }
+ }
+
+ sa := make([]byte, addrLen)
+ sm := make([]byte, addrLen)
+ n := uint(a.PrefixLen)
+ for i := 0; i < addrLen; i++ {
+ if n >= 8 {
+ sa[i] = a.Address[i]
+ sm[i] = 0xff
+ n -= 8
+ continue
+ }
+ sm[i] = ^byte(0xff >> n)
+ sa[i] = a.Address[i] & sm[i]
+ n = 0
+ }
+
+ // For extra caution, call NewSubnet rather than directly creating the Subnet
+ // value. If that fails it indicates a serious bug in this code, so panic is
+ // in order.
+ s, err := NewSubnet(Address(sa), AddressMask(sm))
+ if err != nil {
+ panic("invalid subnet: " + err.Error())
+ }
+ return s
+}
+
// ProtocolAddress is an address and the network protocol it is associated
// with.
type ProtocolAddress struct {
diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go
index fb3a0a5ee..8c0aacffa 100644
--- a/pkg/tcpip/tcpip_test.go
+++ b/pkg/tcpip/tcpip_test.go
@@ -195,3 +195,34 @@ func TestStatsString(t *testing.T) {
t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got)
}
}
+
+ // TestAddressWithPrefixSubnet checks that AddressWithPrefix.Subnet returns the
+ // expected subnet address and mask for a range of prefix lengths, including
+ // out-of-range values (negative and larger than the address width).
+ func TestAddressWithPrefixSubnet(t *testing.T) {
+ tests := []struct {
+ addr Address
+ prefixLen int
+ subnetAddr Address
+ subnetMask AddressMask
+ }{
+ {"\xaa\x55\x33\x42", -1, "\x00\x00\x00\x00", "\x00\x00\x00\x00"},
+ {"\xaa\x55\x33\x42", 0, "\x00\x00\x00\x00", "\x00\x00\x00\x00"},
+ {"\xaa\x55\x33\x42", 1, "\x80\x00\x00\x00", "\x80\x00\x00\x00"},
+ {"\xaa\x55\x33\x42", 7, "\xaa\x00\x00\x00", "\xfe\x00\x00\x00"},
+ {"\xaa\x55\x33\x42", 8, "\xaa\x00\x00\x00", "\xff\x00\x00\x00"},
+ {"\xaa\x55\x33\x42", 24, "\xaa\x55\x33\x00", "\xff\xff\xff\x00"},
+ {"\xaa\x55\x33\x42", 31, "\xaa\x55\x33\x42", "\xff\xff\xff\xfe"},
+ {"\xaa\x55\x33\x42", 32, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"},
+ {"\xaa\x55\x33\x42", 33, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"},
+ }
+ for _, tt := range tests {
+ ap := AddressWithPrefix{Address: tt.addr, PrefixLen: tt.prefixLen}
+ gotSubnet := ap.Subnet()
+ wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask)
+ if err != nil {
+ // Errorf (not Error) so that the format verbs are actually expanded.
+ t.Errorf("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err)
+ continue
+ }
+ if gotSubnet != wantSubnet {
+ t.Errorf("got subnet = %q, want = %q", gotSubnet, wantSubnet)
+ }
+ }
+ }
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index d78a162b8..9254c3dea 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -1,8 +1,8 @@
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "icmp_packet_list",
out = "icmp_packet_list.go",
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 451d3880e..3187b336b 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -15,7 +15,6 @@
package icmp
import (
- "encoding/binary"
"sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -53,11 +52,11 @@ const (
//
// +stateify savable
type endpoint struct {
+ stack.TransportEndpointInfo
+
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
- netProto tcpip.NetworkProtocolNumber
- transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
// The following fields are used to manage the receive queue, and are
@@ -74,27 +73,23 @@ type endpoint struct {
sndBufSize int
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
- id stack.TransportEndpointID
state endpointState
- // bindNICID and bindAddr are set via calls to Bind(). They are used to
- // reject attempts to send data or connect via a different NIC or
- // address
- bindNICID tcpip.NICID
- bindAddr tcpip.Address
- // regNICID is the default NIC to be used when callers don't specify a
- // NIC.
- regNICID tcpip.NICID
- route stack.Route `state:"manual"`
-}
-
-func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ route stack.Route `state:"manual"`
+ ttl uint8
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+}
+
+func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return &endpoint{
- stack: stack,
- netProto: netProto,
- transProto: transProto,
+ stack: s,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: transProto,
+ },
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
+ state: stateInitial,
}, nil
}
@@ -105,7 +100,7 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, 0 /* bindToDevice */)
}
// Close the receive list and drain it.
@@ -144,6 +139,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
if e.rcvList.Empty() {
err := tcpip.ErrWouldBlock
if e.rcvClosed {
+ e.stats.ReadErrors.ReadClosed.Increment()
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
@@ -205,7 +201,30 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ n, ch, err := e.write(p, opts)
+ switch err {
+ case nil:
+ e.stats.PacketsSent.Increment()
+ case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ e.stats.WriteErrors.InvalidArgs.Increment()
+ case tcpip.ErrClosedForSend:
+ e.stats.WriteErrors.WriteClosed.Increment()
+ case tcpip.ErrInvalidEndpointState:
+ e.stats.WriteErrors.InvalidEndpointState.Increment()
+ case tcpip.ErrNoLinkAddress:
+ e.stats.SendErrors.NoLinkAddr.Increment()
+ case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ // Errors indicating any problem with IP routing of the packet.
+ e.stats.SendErrors.NoRoute.Increment()
+ default:
+ // For all other errors when writing to the network layer.
+ e.stats.SendErrors.SendToNetworkFailed.Increment()
+ }
+ return n, ch, err
+}
+
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -256,12 +275,12 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicid := to.NIC
- if e.bindNICID != 0 {
- if nicid != 0 && nicid != e.bindNICID {
+ if e.BindNICID != 0 {
+ if nicid != 0 && nicid != e.BindNICID {
return 0, nil, tcpip.ErrNoRoute
}
- nicid = e.bindNICID
+ nicid = e.BindNICID
}
toCopy := *to
@@ -272,7 +291,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
}
// Find the enpoint.
- r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto, false /* multicastLoop */)
+ r, err := e.stack.FindRoute(nicid, e.BindAddr, to.Addr, netProto, false /* multicastLoop */)
if err != nil {
return 0, nil, err
}
@@ -290,17 +309,17 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
}
}
- v, err := p.Get(p.Size())
+ v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
- switch e.netProto {
+ switch e.NetProto {
case header.IPv4ProtocolNumber:
- err = send4(route, e.id.LocalPort, v)
+ err = send4(route, e.ID.LocalPort, v, e.ttl)
case header.IPv6ProtocolNumber:
- err = send6(route, e.id.LocalPort, v)
+ err = send6(route, e.ID.LocalPort, v, e.ttl)
}
if err != nil {
@@ -315,8 +334,20 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
-// SetSockOpt sets a socket option. Currently not supported.
+// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ e.ttl = uint8(o)
+ e.mu.Unlock()
+ }
+
+ return nil
+}
+
+// SetSockOptInt sets a socket option. Currently not supported.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
return nil
}
@@ -332,6 +363,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
e.rcvMu.Unlock()
return v, nil
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSize
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
+ return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -342,40 +385,33 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- e.mu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.mu.Unlock()
+ case *tcpip.KeepaliveEnabledOption:
+ *o = 0
return nil
- case *tcpip.ReceiveBufferSizeOption:
+ case *tcpip.TTLOption:
e.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
+ *o = tcpip.TTLOption(e.ttl)
e.rcvMu.Unlock()
return nil
- case *tcpip.KeepaliveEnabledOption:
- *o = 0
- return nil
-
default:
return tcpip.ErrUnknownProtocolOption
}
}
-func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
+func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error {
if len(data) < header.ICMPv4MinimumSize {
return tcpip.ErrInvalidEndpointState
}
- // Set the ident to the user-specified port. Sequence number should
- // already be set by the user.
- binary.BigEndian.PutUint16(data[header.ICMPv4PayloadOffset:], ident)
-
hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength()))
icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
copy(icmpv4, data)
+ // Set the ident to the user-specified port. Sequence number should
+ // already be set by the user.
+ icmpv4.SetIdent(ident)
data = data[header.ICMPv4MinimumSize:]
// Linux performs these basic checks.
@@ -386,22 +422,24 @@ func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
icmpv4.SetChecksum(0)
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
- return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL())
+ if ttl == 0 {
+ ttl = r.DefaultTTL()
+ }
+ return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS})
}
-func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
+func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error {
if len(data) < header.ICMPv6EchoMinimumSize {
return tcpip.ErrInvalidEndpointState
}
- // Set the ident. Sequence number is provided by the user.
- binary.BigEndian.PutUint16(data[header.ICMPv6MinimumSize:], ident)
-
- hdr := buffer.NewPrependable(header.ICMPv6EchoMinimumSize + int(r.MaxHeaderLength()))
+ hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength()))
- icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+ icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
copy(icmpv6, data)
- data = data[header.ICMPv6EchoMinimumSize:]
+ // Set the ident. Sequence number is provided by the user.
+ icmpv6.SetIdent(ident)
+ data = data[header.ICMPv6MinimumSize:]
if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
return tcpip.ErrInvalidEndpointState
@@ -410,18 +448,21 @@ func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
icmpv6.SetChecksum(0)
icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0)))
- return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, r.DefaultTTL())
+ if ttl == 0 {
+ ttl = r.DefaultTTL()
+ }
+ return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS})
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.netProto
+ netProto := e.NetProto
if header.IsV4MappedAddress(addr.Addr) {
return 0, tcpip.ErrNoRoute
}
// Fail if we're bound to an address length different from the one we're
// checking.
- if l := len(e.id.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) {
+ if l := len(e.ID.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) {
return 0, tcpip.ErrInvalidEndpointState
}
@@ -442,16 +483,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
localPort := uint16(0)
switch e.state {
case stateBound, stateConnected:
- localPort = e.id.LocalPort
- if e.bindNICID == 0 {
+ localPort = e.ID.LocalPort
+ if e.BindNICID == 0 {
break
}
- if nicid != 0 && nicid != e.bindNICID {
+ if nicid != 0 && nicid != e.BindNICID {
return tcpip.ErrInvalidEndpointState
}
- nicid = e.bindNICID
+ nicid = e.BindNICID
default:
return tcpip.ErrInvalidEndpointState
}
@@ -462,7 +503,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto, false /* multicastLoop */)
+ r, err := e.stack.FindRoute(nicid, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */)
if err != nil {
return err
}
@@ -484,9 +525,9 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return err
}
- e.id = id
+ e.ID = id
e.route = r.Clone()
- e.regNICID = nicid
+ e.RegisterNICID = nicid
e.state = stateConnected
@@ -541,14 +582,14 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */)
return id, err
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */)
switch err {
case nil:
return true, nil
@@ -595,8 +636,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return err
}
- e.id = id
- e.regNICID = addr.NIC
+ e.ID = id
+ e.RegisterNICID = addr.NIC
// Mark endpoint as bound.
e.state = stateBound
@@ -619,8 +660,8 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
return err
}
- e.bindNICID = addr.NIC
- e.bindAddr = addr.Addr
+ e.BindNICID = addr.NIC
+ e.BindAddr = addr.Addr
return nil
}
@@ -631,9 +672,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
defer e.mu.RUnlock()
return tcpip.FullAddress{
- NIC: e.regNICID,
- Addr: e.id.LocalAddress,
- Port: e.id.LocalPort,
+ NIC: e.RegisterNICID,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
}, nil
}
@@ -647,9 +688,9 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}
return tcpip.FullAddress{
- NIC: e.regNICID,
- Addr: e.id.RemoteAddress,
- Port: e.id.RemotePort,
+ NIC: e.RegisterNICID,
+ Addr: e.ID.RemoteAddress,
+ Port: e.ID.RemotePort,
}, nil
}
@@ -675,17 +716,19 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
// Only accept echo replies.
- switch e.netProto {
+ switch e.NetProto {
case header.IPv4ProtocolNumber:
h := header.ICMPv4(vv.First())
if h.Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
case header.IPv6ProtocolNumber:
h := header.ICMPv6(vv.First())
if h.Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
}
@@ -693,9 +736,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
- if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
+ if !e.rcvReady || e.rcvClosed {
+ e.rcvMu.Unlock()
e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ClosedReceiver.Increment()
+ return
+ }
+
+ if e.rcvBufSize >= e.rcvBufSizeMax {
e.rcvMu.Unlock()
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
return
}
@@ -717,7 +768,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
pkt.timestamp = e.stack.NowNanoseconds()
e.rcvMu.Unlock()
-
+ e.stats.PacketsReceived.Increment()
// Notify any waiters that there's data to be read now.
if wasEmpty {
e.waiterQueue.Notify(waiter.EventIn)
@@ -733,3 +784,17 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C
func (e *endpoint) State() uint32 {
return 0
}
+
+// Info returns a copy of the endpoint info.
+func (e *endpoint) Info() tcpip.EndpointInfo {
+ e.mu.RLock()
+ // Make a copy of the endpoint info.
+ ret := e.TransportEndpointInfo
+ e.mu.RUnlock()
+ return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (e *endpoint) Stats() tcpip.EndpointStats {
+ return &e.stats
+}
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index c587b96b6..9d263c0ec 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -76,19 +76,19 @@ func (e *endpoint) Resume(s *stack.Stack) {
var err *tcpip.Error
if e.state == stateConnected {
- e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */)
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */)
if err != nil {
panic(err)
}
- e.id.LocalAddress = e.route.LocalAddress
- } else if len(e.id.LocalAddress) != 0 { // stateBound
- if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 {
+ e.ID.LocalAddress = e.route.LocalAddress
+ } else if len(e.ID.LocalAddress) != 0 { // stateBound
+ if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 {
panic(tcpip.ErrBadLocalAddress)
}
}
- e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id)
+ e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID)
if err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 7fdba5d56..bfb16f7c3 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -14,16 +14,14 @@
// Package icmp contains the implementation of the ICMP and IPv6-ICMP transport
// protocols for use in ping. To use it in the networking stack, this package
-// must be added to the project, and
-// activated on the stack by passing icmp.ProtocolName (or "icmp") and/or
-// icmp.ProtocolName6 (or "icmp6") as one of the transport protocols when
-// calling stack.New(). Then endpoints can be created by passing
+// must be added to the project, and activated on the stack by passing
+// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport
+// protocols when calling stack.New(). Then endpoints can be created by passing
// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number
// when calling Stack.NewEndpoint().
package icmp
import (
- "encoding/binary"
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -35,15 +33,9 @@ import (
)
const (
- // ProtocolName4 is the string representation of the icmp protocol name.
- ProtocolName4 = "icmp4"
-
// ProtocolNumber4 is the ICMP protocol number.
ProtocolNumber4 = header.ICMPv4ProtocolNumber
- // ProtocolName6 is the string representation of the icmp protocol name.
- ProtocolName6 = "icmp6"
-
// ProtocolNumber6 is the IPv6-ICMP protocol number.
ProtocolNumber6 = header.ICMPv6ProtocolNumber
)
@@ -92,7 +84,7 @@ func (p *protocol) MinimumPacketSize() int {
case ProtocolNumber4:
return header.ICMPv4MinimumSize
case ProtocolNumber6:
- return header.ICMPv6EchoMinimumSize
+ return header.ICMPv6MinimumSize
}
panic(fmt.Sprint("unknown protocol number: ", p.number))
}
@@ -101,16 +93,18 @@ func (p *protocol) MinimumPacketSize() int {
func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
switch p.number {
case ProtocolNumber4:
- return 0, binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset:]), nil
+ hdr := header.ICMPv4(v)
+ return 0, hdr.Ident(), nil
case ProtocolNumber6:
- return 0, binary.BigEndian.Uint16(v[header.ICMPv6MinimumSize:]), nil
+ hdr := header.ICMPv6(v)
+ return 0, hdr.Ident(), nil
}
panic(fmt.Sprint("unknown protocol number: ", p.number))
}
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
+func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool {
return true
}
@@ -124,12 +118,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName4, func() stack.TransportProtocol {
- return &protocol{ProtocolNumber4}
- })
+// NewProtocol4 returns an ICMPv4 transport protocol.
+func NewProtocol4() stack.TransportProtocol {
+ return &protocol{ProtocolNumber4}
+}
- stack.RegisterTransportProtocolFactory(ProtocolName6, func() stack.TransportProtocol {
- return &protocol{ProtocolNumber6}
- })
+// NewProtocol6 returns an ICMPv6 transport protocol.
+func NewProtocol6() stack.TransportProtocol {
+ return &protocol{ProtocolNumber6}
}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index 7241f6c19..fba598d51 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -1,8 +1,8 @@
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "packet_list",
out = "packet_list.go",
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 13e17e2a6..b4c660859 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -62,11 +62,10 @@ type packet struct {
//
// +stateify savable
type endpoint struct {
+ stack.TransportEndpointInfo
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
- netProto tcpip.NetworkProtocolNumber
- transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
associated bool
@@ -84,18 +83,10 @@ type endpoint struct {
closed bool
connected bool
bound bool
- // registeredNIC is the NIC to which th endpoint is explicitly
- // registered. Is set when Connect or Bind are used to specify a NIC.
- registeredNIC tcpip.NICID
- // boundNIC and boundAddr are set on calls to Bind(). When callers
- // attempt actions that would invalidate the binding data (e.g. sending
- // data via a NIC other than boundNIC), the endpoint will return an
- // error.
- boundNIC tcpip.NICID
- boundAddr tcpip.Address
// route is the route to a remote network endpoint. It is set via
// Connect(), and is valid only when conneted is true.
- route stack.Route `state:"manual"`
+ route stack.Route `state:"manual"`
+ stats tcpip.TransportEndpointStats `state:"nosave"`
}
// NewEndpoint returns a raw endpoint for the given protocols.
@@ -104,15 +95,17 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */)
}
-func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
+func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
if netProto != header.IPv4ProtocolNumber {
return nil, tcpip.ErrUnknownProtocol
}
- ep := &endpoint{
- stack: stack,
- netProto: netProto,
- transProto: transProto,
+ e := &endpoint{
+ stack: s,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: transProto,
+ },
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
@@ -123,81 +116,82 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
// headers included. Because they're write-only, We don't need to
// register with the stack.
if !associated {
- ep.rcvBufSizeMax = 0
- ep.waiterQueue = nil
- return ep, nil
+ e.rcvBufSizeMax = 0
+ e.waiterQueue = nil
+ return e, nil
}
- if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil {
return nil, err
}
- return ep, nil
+ return e, nil
}
// Close implements tcpip.Endpoint.Close.
-func (ep *endpoint) Close() {
- ep.mu.Lock()
- defer ep.mu.Unlock()
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
- if ep.closed || !ep.associated {
+ if e.closed || !e.associated {
return
}
- ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
+ e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
- ep.rcvMu.Lock()
- defer ep.rcvMu.Unlock()
+ e.rcvMu.Lock()
+ defer e.rcvMu.Unlock()
// Clear the receive list.
- ep.rcvClosed = true
- ep.rcvBufSize = 0
- for !ep.rcvList.Empty() {
- ep.rcvList.Remove(ep.rcvList.Front())
+ e.rcvClosed = true
+ e.rcvBufSize = 0
+ for !e.rcvList.Empty() {
+ e.rcvList.Remove(e.rcvList.Front())
}
- if ep.connected {
- ep.route.Release()
- ep.connected = false
+ if e.connected {
+ e.route.Release()
+ e.connected = false
}
- ep.closed = true
+ e.closed = true
- ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
-func (ep *endpoint) ModerateRecvBuf(copied int) {}
+func (e *endpoint) ModerateRecvBuf(copied int) {}
// IPTables implements tcpip.Endpoint.IPTables.
-func (ep *endpoint) IPTables() (iptables.IPTables, error) {
- return ep.stack.IPTables(), nil
+func (e *endpoint) IPTables() (iptables.IPTables, error) {
+ return e.stack.IPTables(), nil
}
// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- if !ep.associated {
+func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ if !e.associated {
return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue
}
- ep.rcvMu.Lock()
+ e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
// endpoint is closed.
- if ep.rcvList.Empty() {
+ if e.rcvList.Empty() {
err := tcpip.ErrWouldBlock
- if ep.rcvClosed {
+ if e.rcvClosed {
+ e.stats.ReadErrors.ReadClosed.Increment()
err = tcpip.ErrClosedForReceive
}
- ep.rcvMu.Unlock()
+ e.rcvMu.Unlock()
return buffer.View{}, tcpip.ControlMessages{}, err
}
- packet := ep.rcvList.Front()
- ep.rcvList.Remove(packet)
- ep.rcvBufSize -= packet.data.Size()
+ packet := e.rcvList.Front()
+ e.rcvList.Remove(packet)
+ e.rcvBufSize -= packet.data.Size()
- ep.rcvMu.Unlock()
+ e.rcvMu.Unlock()
if addr != nil {
*addr = packet.senderAddr
@@ -207,31 +201,54 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
}
// Write implements tcpip.Endpoint.Write.
-func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ n, ch, err := e.write(p, opts)
+ switch err {
+ case nil:
+ e.stats.PacketsSent.Increment()
+ case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ e.stats.WriteErrors.InvalidArgs.Increment()
+ case tcpip.ErrClosedForSend:
+ e.stats.WriteErrors.WriteClosed.Increment()
+ case tcpip.ErrInvalidEndpointState:
+ e.stats.WriteErrors.InvalidEndpointState.Increment()
+ case tcpip.ErrNoLinkAddress:
+ e.stats.SendErrors.NoLinkAddr.Increment()
+ case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ // Errors indicating any problem with IP routing of the packet.
+ e.stats.SendErrors.NoRoute.Increment()
+ default:
+ // For all other errors when writing to the network layer.
+ e.stats.SendErrors.SendToNetworkFailed.Increment()
+ }
+ return n, ch, err
+}
+
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
}
- ep.mu.RLock()
+ e.mu.RLock()
- if ep.closed {
- ep.mu.RUnlock()
+ if e.closed {
+ e.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidEndpointState
}
- payloadBytes, err := payload.Get(payload.Size())
+ payloadBytes, err := p.FullPayload()
if err != nil {
- ep.mu.RUnlock()
+ e.mu.RUnlock()
return 0, nil, err
}
// If this is an unassociated socket and callee provided a nonzero
// destination address, route using that address.
- if !ep.associated {
+ if !e.associated {
ip := header.IPv4(payloadBytes)
- if !ip.IsValid(payload.Size()) {
- ep.mu.RUnlock()
+ if !ip.IsValid(len(payloadBytes)) {
+ e.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidOptionValue
}
dstAddr := ip.DestinationAddress()
@@ -252,66 +269,66 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64
if opts.To == nil {
// If the user doesn't specify a destination, they should have
// connected to another address.
- if !ep.connected {
- ep.mu.RUnlock()
+ if !e.connected {
+ e.mu.RUnlock()
return 0, nil, tcpip.ErrDestinationRequired
}
- if ep.route.IsResolutionRequired() {
- savedRoute := &ep.route
+ if e.route.IsResolutionRequired() {
+ savedRoute := &e.route
// Promote lock to exclusive if using a shared route,
// given that it may need to change in finishWrite.
- ep.mu.RUnlock()
- ep.mu.Lock()
+ e.mu.RUnlock()
+ e.mu.Lock()
// Make sure that the route didn't change during the
// time we didn't hold the lock.
- if !ep.connected || savedRoute != &ep.route {
- ep.mu.Unlock()
+ if !e.connected || savedRoute != &e.route {
+ e.mu.Unlock()
return 0, nil, tcpip.ErrInvalidEndpointState
}
- n, ch, err := ep.finishWrite(payloadBytes, savedRoute)
- ep.mu.Unlock()
+ n, ch, err := e.finishWrite(payloadBytes, savedRoute)
+ e.mu.Unlock()
return n, ch, err
}
- n, ch, err := ep.finishWrite(payloadBytes, &ep.route)
- ep.mu.RUnlock()
+ n, ch, err := e.finishWrite(payloadBytes, &e.route)
+ e.mu.RUnlock()
return n, ch, err
}
// The caller provided a destination. Reject destination address if it
// goes through a different NIC than the endpoint was bound to.
nic := opts.To.NIC
- if ep.bound && nic != 0 && nic != ep.boundNIC {
- ep.mu.RUnlock()
+ if e.bound && nic != 0 && nic != e.BindNICID {
+ e.mu.RUnlock()
return 0, nil, tcpip.ErrNoRoute
}
// We don't support IPv6 yet, so this has to be an IPv4 address.
if len(opts.To.Addr) != header.IPv4AddressSize {
- ep.mu.RUnlock()
+ e.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidEndpointState
}
- // Find the route to the destination. If boundAddress is 0,
+ // Find the route to the destination. If BindAddress is 0,
// FindRoute will choose an appropriate source address.
- route, err := ep.stack.FindRoute(nic, ep.boundAddr, opts.To.Addr, ep.netProto, false)
+ route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
if err != nil {
- ep.mu.RUnlock()
+ e.mu.RUnlock()
return 0, nil, err
}
- n, ch, err := ep.finishWrite(payloadBytes, &route)
+ n, ch, err := e.finishWrite(payloadBytes, &route)
route.Release()
- ep.mu.RUnlock()
+ e.mu.RUnlock()
return n, ch, err
}
// finishWrite writes the payload to a route. It resolves the route if
// necessary. It's really just a helper to make defer unnecessary in Write.
-func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) {
// We may need to resolve the route (match a link layer address to the
// network address). If that requires blocking (e.g. to use ARP),
// return a channel on which the caller can wait.
@@ -324,16 +341,16 @@ func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
}
- switch ep.netProto {
+ switch e.NetProto {
case header.IPv4ProtocolNumber:
- if !ep.associated {
+ if !e.associated {
if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil {
return 0, nil, err
}
break
}
hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
- if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), ep.transProto, route.DefaultTTL()); err != nil {
+ if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil {
return 0, nil, err
}
@@ -345,7 +362,7 @@ func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
// Peek implements tcpip.Endpoint.Peek.
-func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
@@ -355,11 +372,11 @@ func (*endpoint) Disconnect() *tcpip.Error {
}
// Connect implements tcpip.Endpoint.Connect.
-func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- ep.mu.Lock()
- defer ep.mu.Unlock()
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
- if ep.closed {
+ if e.closed {
return tcpip.ErrInvalidEndpointState
}
@@ -369,15 +386,15 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
nic := addr.NIC
- if ep.bound {
- if ep.boundNIC == 0 {
+ if e.bound {
+ if e.BindNICID == 0 {
// If we're bound, but not to a specific NIC, the NIC
// in addr will be used. Nothing to do here.
} else if addr.NIC == 0 {
// If we're bound to a specific NIC, but addr doesn't
// specify a NIC, use the bound NIC.
- nic = ep.boundNIC
- } else if addr.NIC != ep.boundNIC {
+ nic = e.BindNICID
+ } else if addr.NIC != e.BindNICID {
// We're bound and addr specifies a NIC. They must be
// the same.
return tcpip.ErrInvalidEndpointState
@@ -385,53 +402,53 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// Find a route to the destination.
- route, err := ep.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, ep.netProto, false)
+ route, err := e.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, e.NetProto, false)
if err != nil {
return err
}
defer route.Release()
- if ep.associated {
+ if e.associated {
// Re-register the endpoint with the appropriate NIC.
- if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil {
return err
}
- ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
- ep.registeredNIC = nic
+ e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
+ e.RegisterNICID = nic
}
// Save the route we've connected via.
- ep.route = route.Clone()
- ep.connected = true
+ e.route = route.Clone()
+ e.connected = true
return nil
}
// Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets.
-func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
- ep.mu.Lock()
- defer ep.mu.Unlock()
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
- if !ep.connected {
+ if !e.connected {
return tcpip.ErrNotConnected
}
return nil
}
// Listen implements tcpip.Endpoint.Listen.
-func (ep *endpoint) Listen(backlog int) *tcpip.Error {
+func (e *endpoint) Listen(backlog int) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Accept implements tcpip.Endpoint.Accept.
-func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
// Bind implements tcpip.Endpoint.Bind.
-func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
- ep.mu.Lock()
- defer ep.mu.Unlock()
+func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
// Callers must provide an IPv4 address or no network address (for
// binding to a NIC, but not an address).
@@ -440,94 +457,100 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
// If a local address was specified, verify that it's valid.
- if len(addr.Addr) == header.IPv4AddressSize && ep.stack.CheckLocalAddress(addr.NIC, ep.netProto, addr.Addr) == 0 {
+ if len(addr.Addr) == header.IPv4AddressSize && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
return tcpip.ErrBadLocalAddress
}
- if ep.associated {
+ if e.associated {
// Re-register the endpoint with the appropriate NIC.
- if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil {
return err
}
- ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
- ep.registeredNIC = addr.NIC
- ep.boundNIC = addr.NIC
+ e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
+ e.RegisterNICID = addr.NIC
+ e.BindNICID = addr.NIC
}
- ep.boundAddr = addr.Addr
- ep.bound = true
+ e.BindAddr = addr.Addr
+ e.bound = true
return nil
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
return tcpip.FullAddress{}, tcpip.ErrNotSupported
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
-func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
// Even a connected socket doesn't return a remote address.
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
// Readiness implements tcpip.Endpoint.Readiness.
-func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// The endpoint is always writable.
result := waiter.EventOut & mask
// Determine whether the endpoint is readable.
if (mask & waiter.EventIn) != 0 {
- ep.rcvMu.Lock()
- if !ep.rcvList.Empty() || ep.rcvClosed {
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() || e.rcvClosed {
result |= waiter.EventIn
}
- ep.rcvMu.Unlock()
+ e.rcvMu.Unlock()
}
return result
}
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
-func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
- ep.rcvMu.Lock()
- if !ep.rcvList.Empty() {
- p := ep.rcvList.Front()
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() {
+ p := e.rcvList.Front()
v = p.data.Size()
}
- ep.rcvMu.Unlock()
+ e.rcvMu.Unlock()
+ return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSize
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- ep.mu.Lock()
- *o = tcpip.SendBufferSizeOption(ep.sndBufSize)
- ep.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- ep.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax)
- ep.rcvMu.Unlock()
- return nil
-
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
@@ -538,37 +561,45 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
-func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
- ep.rcvMu.Lock()
+func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
+ e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
- if ep.rcvClosed || ep.rcvBufSize >= ep.rcvBufSizeMax {
- ep.stack.Stats().DroppedPackets.Increment()
- ep.rcvMu.Unlock()
+ if e.rcvClosed {
+ e.rcvMu.Unlock()
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ClosedReceiver.Increment()
return
}
- if ep.bound {
+ if e.rcvBufSize >= e.rcvBufSizeMax {
+ e.rcvMu.Unlock()
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
+ return
+ }
+
+ if e.bound {
// If bound to a NIC, only accept data for that NIC.
- if ep.boundNIC != 0 && ep.boundNIC != route.NICID() {
- ep.rcvMu.Unlock()
+ if e.BindNICID != 0 && e.BindNICID != route.NICID() {
+ e.rcvMu.Unlock()
return
}
// If bound to an address, only accept data for that address.
- if ep.boundAddr != "" && ep.boundAddr != route.RemoteAddress {
- ep.rcvMu.Unlock()
+ if e.BindAddr != "" && e.BindAddr != route.RemoteAddress {
+ e.rcvMu.Unlock()
return
}
}
// If connected, only accept packets from the remote address we
// connected to.
- if ep.connected && ep.route.RemoteAddress != route.RemoteAddress {
- ep.rcvMu.Unlock()
+ if e.connected && e.route.RemoteAddress != route.RemoteAddress {
+ e.rcvMu.Unlock()
return
}
- wasEmpty := ep.rcvBufSize == 0
+ wasEmpty := e.rcvBufSize == 0
// Push new packet into receive list and increment the buffer size.
packet := &packet{
@@ -581,20 +612,34 @@ func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv b
combinedVV := netHeader.ToVectorisedView()
combinedVV.Append(vv)
packet.data = combinedVV.Clone(packet.views[:])
- packet.timestampNS = ep.stack.NowNanoseconds()
-
- ep.rcvList.PushBack(packet)
- ep.rcvBufSize += packet.data.Size()
+ packet.timestampNS = e.stack.NowNanoseconds()
- ep.rcvMu.Unlock()
+ e.rcvList.PushBack(packet)
+ e.rcvBufSize += packet.data.Size()
+ e.rcvMu.Unlock()
+ e.stats.PacketsReceived.Increment()
// Notify waiters that there's data to be read.
if wasEmpty {
- ep.waiterQueue.Notify(waiter.EventIn)
+ e.waiterQueue.Notify(waiter.EventIn)
}
}
// State implements socket.Socket.State.
-func (ep *endpoint) State() uint32 {
+func (e *endpoint) State() uint32 {
return 0
}
+
+// Info returns a copy of the endpoint info.
+func (e *endpoint) Info() tcpip.EndpointInfo {
+ e.mu.RLock()
+ // Make a copy of the endpoint info.
+ ret := e.TransportEndpointInfo
+ e.mu.RUnlock()
+ return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (e *endpoint) Stats() tcpip.EndpointStats {
+ return &e.stats
+}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 168953dec..a6c7cc43a 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -73,7 +73,7 @@ func (ep *endpoint) Resume(s *stack.Stack) {
// If the endpoint is connected, re-connect.
if ep.connected {
var err *tcpip.Error
- ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false)
+ ep.route, err = ep.stack.FindRoute(ep.RegisterNICID, ep.BindAddr, ep.route.RemoteAddress, ep.NetProto, false)
if err != nil {
panic(err)
}
@@ -81,12 +81,12 @@ func (ep *endpoint) Resume(s *stack.Stack) {
// If the endpoint is bound, re-bind.
if ep.bound {
- if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 {
+ if ep.stack.CheckLocalAddress(ep.RegisterNICID, ep.NetProto, ep.BindAddr) == 0 {
panic(tcpip.ErrBadLocalAddress)
}
}
- if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
+ if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil {
panic(err)
}
}
diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go
index 783c21e6b..a2512d666 100644
--- a/pkg/tcpip/transport/raw/protocol.go
+++ b/pkg/tcpip/transport/raw/protocol.go
@@ -20,13 +20,10 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-type factory struct{}
+// EndpointFactory implements stack.UnassociatedEndpointFactory.
+type EndpointFactory struct{}
// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory.
-func (factory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (EndpointFactory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */)
}
-
-func init() {
- stack.RegisterUnassociatedFactory(factory{})
-}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 1ee1a53f8..aed70e06f 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -1,7 +1,8 @@
-package(licenses = ["notice"])
-
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
go_template_instance(
name = "tcp_segment_list",
@@ -47,6 +48,7 @@ go_library(
"//pkg/sleep",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
"//pkg/tcpip/seqnum",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index e9c5099ea..844959fa0 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -143,6 +143,15 @@ func decSynRcvdCount() {
synRcvdCount.Unlock()
}
+// synCookiesInUse returns true if synRcvdCount is greater than or equal to
+// SynRcvdCountThreshold.
+func synCookiesInUse() bool {
+ synRcvdCount.Lock()
+ v := synRcvdCount.value
+ synRcvdCount.Unlock()
+ return v >= SynRcvdCountThreshold
+}
+
// newListenContext creates a new listen context.
func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
@@ -220,7 +229,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
}
n := newEndpoint(l.stack, netProto, nil)
n.v6only = l.v6only
- n.id = s.id
+ n.ID = s.id
n.boundNICID = s.route.NICID()
n.route = s.route.Clone()
n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
@@ -233,7 +242,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.initGSO()
// Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil {
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.bindToDevice); err != nil {
n.Close()
return nil, err
}
@@ -281,7 +290,6 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
h.resetToSynRcvd(cookie, irs, opts)
if err := h.execute(); err != nil {
- ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
ep.Close()
if l.listenEP != nil {
l.removePendingEndpoint(ep)
@@ -302,14 +310,14 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
func (l *listenContext) addPendingEndpoint(n *endpoint) {
l.pendingMu.Lock()
- l.pendingEndpoints[n.id] = n
+ l.pendingEndpoints[n.ID] = n
l.pending.Add(1)
l.pendingMu.Unlock()
}
func (l *listenContext) removePendingEndpoint(n *endpoint) {
l.pendingMu.Lock()
- delete(l.pendingEndpoints, n.id)
+ delete(l.pendingEndpoints, n.ID)
l.pending.Done()
l.pendingMu.Unlock()
}
@@ -354,6 +362,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
n, err := ctx.createEndpointAndPerformHandshake(s, opts)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
return
}
ctx.removePendingEndpoint(n)
@@ -405,6 +414,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
}
decSynRcvdCount()
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
+ e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
return
} else {
@@ -412,6 +422,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// is full then drop the syn.
if e.acceptQueueIsFull() {
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
+ e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
return
}
@@ -430,7 +441,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
TSEcr: opts.TSVal,
MSS: uint16(mss),
}
- sendSynTCP(&s.route, s.id, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
+ e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
}
@@ -442,10 +453,32 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// complete the connection at the time of retransmit if
// the backlog has space.
e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
+ e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
return
}
+ if !synCookiesInUse() {
+ // Send a reset as this is an ACK for which there are no
+ // half-open connections and we are not using cookies
+ // yet.
+ //
+ // The only time we should reach here is when a connection
+ // was opened and closed really quickly and a delayed
+ // ACK was received from the sender.
+ replyWithReset(s)
+ return
+ }
+
+ // Since SYN cookies are in use this is potentially an ACK to a
+ // SYN-ACK we sent but don't have a half open connection state
+ // as cookies are being used to protect against a potential SYN
+ // flood. In such cases validate the cookie and if valid create
+ // a fully connected endpoint and deliver to the accept queue.
+ //
+ // If not, silently drop the ACK to avoid leaking information
+ // when under a potential syn flood attack.
+ //
// Validate the cookie.
data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1)
if !ok || int(data) >= len(mssTable) {
@@ -475,6 +508,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
return
}
@@ -506,7 +540,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Lock()
v6only := e.v6only
e.mu.Unlock()
- ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto)
+ ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto)
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 00d2ae524..5ea036bea 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -238,6 +238,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
h.state = handshakeSynRcvd
h.ep.mu.Lock()
h.ep.state = StateSynRecv
+ ttl := h.ep.ttl
h.ep.mu.Unlock()
synOpts := header.TCPSynOptions{
WS: int(h.effectiveRcvWndScale()),
@@ -251,8 +252,10 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
SACKPermitted: rcvSynOpts.SACKPermitted,
MSS: h.ep.amss,
}
- sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
-
+ if ttl == 0 {
+ ttl = s.route.DefaultTTL()
+ }
+ h.ep.sendSynTCP(&s.route, h.ep.ID, ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
return nil
}
@@ -296,7 +299,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
- sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
return nil
}
@@ -383,6 +386,11 @@ func (h *handshake) resolveRoute() *tcpip.Error {
switch index {
case wakerForResolution:
if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
+ if err == tcpip.ErrNoLinkAddress {
+ h.ep.stats.SendErrors.NoLinkAddr.Increment()
+ } else if err != nil {
+ h.ep.stats.SendErrors.NoRoute.Increment()
+ }
// Either success (err == nil) or failure.
return err
}
@@ -460,7 +468,8 @@ func (h *handshake) execute() *tcpip.Error {
synOpts.WS = -1
}
}
- sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+
for h.state != handshakeCompleted {
switch index, _ := s.Fetch(true); index {
case wakerForResend:
@@ -469,7 +478,7 @@ func (h *handshake) execute() *tcpip.Error {
return tcpip.ErrTimeout
}
rt.Reset(timeOut)
- sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
case wakerForNotification:
n := h.ep.fetchNotifications()
@@ -579,16 +588,28 @@ func makeSynOptions(opts header.TCPSynOptions) []byte {
return options[:offset]
}
-func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
+func (e *endpoint) sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
options := makeSynOptions(opts)
- err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options, nil)
+ // We ignore SYN send errors and let the callers re-attempt send.
+ if err := e.sendTCP(r, id, buffer.VectorisedView{}, ttl, tos, flags, seq, ack, rcvWnd, options, nil); err != nil {
+ e.stats.SendErrors.SynSendToNetworkFailed.Increment()
+ }
putOptions(options)
- return err
+ return nil
+}
+
+func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+ if err := sendTCP(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso); err != nil {
+ e.stats.SendErrors.SegmentSendToNetworkFailed.Increment()
+ return err
+ }
+ e.stats.SegmentsSent.Increment()
+ return nil
}
// sendTCP sends a TCP segment with the provided options via the provided
// network endpoint and under the provided identity.
-func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
optLen := len(opts)
// Allocate a buffer for the TCP header.
hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen)
@@ -624,12 +645,18 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise
tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
}
+ if ttl == 0 {
+ ttl = r.DefaultTTL()
+ }
+ if err := r.WritePacket(gso, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil {
+ r.Stats().TCP.SegmentSendErrors.Increment()
+ return err
+ }
r.Stats().TCP.SegmentsSent.Increment()
if (flags & header.TCPFlagRst) != 0 {
r.Stats().TCP.ResetsSent.Increment()
}
-
- return r.WritePacket(gso, hdr, data, ProtocolNumber, ttl)
+ return nil
}
// makeOptions makes an options slice.
@@ -678,7 +705,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
- err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options, e.gso)
+ err := e.sendTCP(&e.route, e.ID, data, e.ttl, e.sendTOS, flags, seq, ack, rcvWnd, options, e.gso)
putOptions(options)
return err
}
@@ -720,13 +747,18 @@ func (e *endpoint) handleClose() *tcpip.Error {
return nil
}
-// resetConnectionLocked sends a RST segment and puts the endpoint in an error
-// state with the given error code. This method must only be called from the
-// protocol goroutine.
+// resetConnectionLocked puts the endpoint in an error state with the given
+// error code and sends a RST if and only if the error is not ErrConnectionReset
+// indicating that the connection is being reset due to receiving a RST. This
+// method must only be called from the protocol goroutine.
func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
- e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
+ // Only send a reset if the connection is being aborted for a reason
+ // other than receiving a reset.
e.state = StateError
- e.hardError = err
+ e.HardError = err
+ if err != tcpip.ErrConnectionReset {
+ e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
+ }
}
// completeWorkerLocked is called by the worker goroutine when it's about to
@@ -806,7 +838,7 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
if e.keepalive.unacked >= e.keepalive.count {
e.keepalive.Unlock()
- return tcpip.ErrConnectionReset
+ return tcpip.ErrTimeout
}
// RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with
@@ -893,7 +925,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.mu.Lock()
e.state = StateError
- e.hardError = err
+ e.HardError = err
// Lock released below.
epilogue()
@@ -1068,6 +1100,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.workMu.Lock()
if err := funcs[v].f(); err != nil {
e.mu.Lock()
+ // Ensure we release all endpoint registration and route
+ // references as the connection is now in an error
+ // state.
+ e.workerCleanup = true
e.resetConnectionLocked(err)
// Lock released below.
epilogue()
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index c54610a87..dfaa4a559 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -42,7 +42,7 @@ func TestV4MappedConnectOnV6Only(t *testing.T) {
}
}
-func testV4Connect(t *testing.T, c *context.Context) {
+func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) {
// Start connection attempt.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventOut)
@@ -55,12 +55,11 @@ func testV4Connect(t *testing.T, c *context.Context) {
// Receive SYN packet.
b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
+ synCheckers := append(checkers, checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ))
+ checker.IPv4(t, b, synCheckers...)
tcp := header.TCP(header.IPv4(b).Payload())
c.IRS = seqnum.Value(tcp.SequenceNumber())
@@ -76,14 +75,13 @@ func testV4Connect(t *testing.T, c *context.Context) {
})
// Receive ACK packet.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
- ),
- )
+ ackCheckers := append(checkers, checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(iss)+1),
+ ))
+ checker.IPv4(t, c.GetPacket(), ackCheckers...)
// Wait for connection to be established.
select {
@@ -152,7 +150,7 @@ func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) {
testV4Connect(t, c)
}
-func testV6Connect(t *testing.T, c *context.Context) {
+func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) {
// Start connection attempt to IPv6 address.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventOut)
@@ -165,12 +163,11 @@ func testV6Connect(t *testing.T, c *context.Context) {
// Receive SYN packet.
b := c.GetV6Packet()
- checker.IPv6(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
+ synCheckers := append(checkers, checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ))
+ checker.IPv6(t, b, synCheckers...)
tcp := header.TCP(header.IPv6(b).Payload())
c.IRS = seqnum.Value(tcp.SequenceNumber())
@@ -186,14 +183,13 @@ func testV6Connect(t *testing.T, c *context.Context) {
})
// Receive ACK packet.
- checker.IPv6(t, c.GetV6Packet(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
- ),
- )
+ ackCheckers := append(checkers, checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(iss)+1),
+ ))
+ checker.IPv6(t, c.GetV6Packet(), ackCheckers...)
// Wait for connection to be established.
select {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index ac927569a..a1b784b49 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "encoding/binary"
"fmt"
"math"
"strings"
@@ -26,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
@@ -170,6 +172,101 @@ type rcvBufAutoTuneParams struct {
disabled bool
}
+// ReceiveErrors collects segment receive errors within the transport layer.
+type ReceiveErrors struct {
+ tcpip.ReceiveErrors
+
+ // SegmentQueueDropped is the number of segments dropped due to
+ // a full segment queue.
+ SegmentQueueDropped tcpip.StatCounter
+
+ // ChecksumErrors is the number of segments dropped due to bad checksums.
+ ChecksumErrors tcpip.StatCounter
+
+ // ListenOverflowSynDrop is the number of times the listen queue overflowed
+ // and a SYN was dropped.
+ ListenOverflowSynDrop tcpip.StatCounter
+
+ // ListenOverflowAckDrop is the number of times the final ACK
+ // in the handshake was dropped due to overflow.
+ ListenOverflowAckDrop tcpip.StatCounter
+
+ // ZeroRcvWindowState is the number of times we advertised
+ // a zero receive window when rcvList is full.
+ ZeroRcvWindowState tcpip.StatCounter
+}
+
+// SendErrors collects segment send errors within the transport layer.
+type SendErrors struct {
+ tcpip.SendErrors
+
+ // SegmentSendToNetworkFailed is the number of TCP segments that failed to be
+ // sent to the network endpoint.
+ SegmentSendToNetworkFailed tcpip.StatCounter
+
+ // SynSendToNetworkFailed is the number of TCP SYNs that failed to be sent
+ // to the network endpoint.
+ SynSendToNetworkFailed tcpip.StatCounter
+
+ // Retransmits is the number of TCP segments retransmitted.
+ Retransmits tcpip.StatCounter
+
+ // FastRetransmit is the number of segments retransmitted in fast
+ // recovery.
+ FastRetransmit tcpip.StatCounter
+
+ // Timeouts is the number of times the RTO expired.
+ Timeouts tcpip.StatCounter
+}
+
+// Stats holds statistics about the endpoint.
+type Stats struct {
+ // SegmentsReceived is the number of TCP segments received that
+ // the transport layer successfully parsed.
+ SegmentsReceived tcpip.StatCounter
+
+ // SegmentsSent is the number of TCP segments sent.
+ SegmentsSent tcpip.StatCounter
+
+ // FailedConnectionAttempts is the number of times we saw Connect and
+ // Accept errors.
+ FailedConnectionAttempts tcpip.StatCounter
+
+ // ReceiveErrors collects segment receive errors within the
+ // transport layer.
+ ReceiveErrors ReceiveErrors
+
+ // ReadErrors collects segment read errors from an endpoint read call.
+ ReadErrors tcpip.ReadErrors
+
+ // SendErrors collects segment send errors within the transport layer.
+ SendErrors SendErrors
+
+ // WriteErrors collects segment write errors from an endpoint write call.
+ WriteErrors tcpip.WriteErrors
+}
+
+// IsEndpointStats is an empty method to implement the tcpip.EndpointStats
+// marker interface.
+func (*Stats) IsEndpointStats() {}
+
+// EndpointInfo holds useful information about a transport endpoint which
+// can be queried by monitoring tools.
+//
+// +stateify savable
+type EndpointInfo struct {
+ stack.TransportEndpointInfo
+
+ // HardError is meaningful only when state is StateError. It stores the
+ // error to be returned when read/write syscalls are called and the
+ // endpoint is in this state. HardError is protected by endpoint mu.
+ HardError *tcpip.Error `state:".(string)"`
+}
+
+// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
+// marker interface.
+func (*EndpointInfo) IsEndpointInfo() {}
+
// endpoint represents a TCP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -178,6 +275,8 @@ type rcvBufAutoTuneParams struct {
//
// +stateify savable
type endpoint struct {
+ EndpointInfo
+
// workMu is used to arbitrate which goroutine may perform protocol
// work. Only the main protocol goroutine is expected to call Lock() on
// it, but other goroutines (e.g., send) may call TryLock() to eagerly
@@ -186,8 +285,7 @@ type endpoint struct {
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
- stack *stack.Stack `state:"manual"`
- netProto tcpip.NetworkProtocolNumber
+ stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue `state:"wait"`
// lastError represents the last error that the endpoint reported;
@@ -218,7 +316,6 @@ type endpoint struct {
// The following fields are protected by the mutex.
mu sync.RWMutex `state:"nosave"`
- id stack.TransportEndpointID
state EndpointState `state:".(EndpointState)"`
@@ -226,6 +323,7 @@ type endpoint struct {
isRegistered bool
boundNICID tcpip.NICID `state:"manual"`
route stack.Route `state:"manual"`
+ ttl uint8
v6only bool
isConnectNotified bool
// TCP should never broadcast but Linux nevertheless supports enabling/
@@ -240,11 +338,6 @@ type endpoint struct {
// address).
effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"`
- // hardError is meaningful only when state is stateError, it stores the
- // error to be returned when read/write syscalls are called and the
- // endpoint is in this state. hardError is protected by mu.
- hardError *tcpip.Error `state:".(string)"`
-
// workerRunning specifies if a worker goroutine is running.
workerRunning bool
@@ -280,6 +373,9 @@ type endpoint struct {
// reusePort is set to true if SO_REUSEPORT is enabled.
reusePort bool
+ // bindToDevice is set to the NIC on which to bind or disabled if 0.
+ bindToDevice tcpip.NICID
+
// delay enables Nagle's algorithm.
//
// delay is a boolean (0 is false) and must be accessed atomically.
@@ -393,13 +489,19 @@ type endpoint struct {
probe stack.TCPProbeFunc `state:"nosave"`
// The following are only used to assist the restore run to re-connect.
- bindAddress tcpip.Address
connectingAddress tcpip.Address
// amss is the advertised MSS to the peer by this endpoint.
amss uint16
+ // sendTOS represents IPv4 TOS or IPv6 TrafficClass,
+ // applied while sending packets. Defaults to 0 as on Linux.
+ sendTOS uint8
+
gso *stack.GSO
+
+ // TODO(b/142022063): Add ability to save and restore per endpoint stats.
+ stats Stats `state:"nosave"`
}
// StopWork halts packet processing. Only to be used in tests.
@@ -427,10 +529,15 @@ type keepalive struct {
waker sleep.Waker `state:"nosave"`
}
-func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
e := &endpoint{
- stack: stack,
- netProto: netProto,
+ stack: s,
+ EndpointInfo: EndpointInfo{
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: header.TCPProtocolNumber,
+ },
+ },
waiterQueue: waiterQueue,
state: StateInitial,
rcvBufSize: DefaultReceiveBufferSize,
@@ -446,26 +553,26 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite
}
var ss SendBufferSizeOption
- if err := stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
e.sndBufSize = ss.Default
}
var rs ReceiveBufferSizeOption
- if err := stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
e.rcvBufSize = rs.Default
}
var cs tcpip.CongestionControlOption
- if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil {
+ if err := s.TransportProtocolOption(ProtocolNumber, &cs); err == nil {
e.cc = cs
}
var mrb tcpip.ModerateReceiveBufferOption
- if err := stack.TransportProtocolOption(ProtocolNumber, &mrb); err == nil {
+ if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil {
e.rcvAutoParams.disabled = !bool(mrb)
}
- if p := stack.GetTCPProbe(); p != nil {
+ if p := s.GetTCPProbe(); p != nil {
e.probe = p
}
@@ -564,11 +671,11 @@ func (e *endpoint) Close() {
// in Listen() when trying to register.
if e.state == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
e.isRegistered = false
}
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -625,12 +732,12 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
e.isRegistered = false
}
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -731,11 +838,12 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
bufUsed := e.rcvBufUsed
if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
- he := e.hardError
+ he := e.HardError
e.mu.RUnlock()
if s == StateError {
return buffer.View{}, tcpip.ControlMessages{}, he
}
+ e.stats.ReadErrors.InvalidEndpointState.Increment()
return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
}
@@ -744,6 +852,9 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
e.mu.RUnlock()
+ if err == tcpip.ErrClosedForReceive {
+ e.stats.ReadErrors.ReadClosed.Increment()
+ }
return v, tcpip.ControlMessages{}, err
}
@@ -787,7 +898,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
if !e.state.connected() {
switch e.state {
case StateError:
- return 0, e.hardError
+ return 0, e.HardError
default:
return 0, tcpip.ErrClosedForSend
}
@@ -806,7 +917,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
}
// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// Linux completely ignores any address passed to sendto(2) for TCP sockets
// (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
// and opts.EndOfRecord are also ignored.
@@ -818,50 +929,57 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
if err != nil {
e.sndBufMu.Unlock()
e.mu.RUnlock()
+ e.stats.WriteErrors.WriteClosed.Increment()
return 0, nil, err
}
- e.sndBufMu.Unlock()
- e.mu.RUnlock()
-
- // Nothing to do if the buffer is empty.
- if p.Size() == 0 {
- return 0, nil, nil
+ // We can release locks while copying data.
+ //
+ // This is not possible if atomic is set, because we can't allow the
+ // available buffer space to be consumed by some other caller while we
+ // are copying data in.
+ if !opts.Atomic {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
}
- // Copy in memory without holding sndBufMu so that worker goroutine can
- // make progress independent of this operation.
- v, perr := p.Get(avail)
- if perr != nil {
+ // Fetch data.
+ v, perr := p.Payload(avail)
+ if perr != nil || len(v) == 0 {
+ if opts.Atomic { // See above.
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ }
+ // Note that perr may be nil if len(v) == 0.
return 0, nil, perr
}
- e.mu.RLock()
- e.sndBufMu.Lock()
+ if !opts.Atomic { // See above.
+ e.mu.RLock()
+ e.sndBufMu.Lock()
- // Because we released the lock before copying, check state again
- // to make sure the endpoint is still in a valid state for a
- // write.
- avail, err = e.isEndpointWritableLocked()
- if err != nil {
- e.sndBufMu.Unlock()
- e.mu.RUnlock()
- return 0, nil, err
- }
+ // Because we released the lock before copying, check state again
+ // to make sure the endpoint is still in a valid state for a write.
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return 0, nil, err
+ }
- // Discard any excess data copied in due to avail being reduced due to a
- // simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
}
// Add data to the send queue.
- l := len(v)
- s := newSegmentFromView(&e.route, e.id, v)
- e.sndBufUsed += l
- e.sndBufInQueue += seqnum.Size(l)
+ s := newSegmentFromView(&e.route, e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
e.sndQueue.PushBack(s)
-
e.sndBufMu.Unlock()
// Release the endpoint lock to prevent deadlocks due to lock
// order inversion when acquiring workMu.
@@ -875,7 +993,8 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
// Let the protocol goroutine do the work.
e.sndWaker.Assert()
}
- return int64(l), nil, nil
+
+ return int64(len(v)), nil, nil
}
// Peek reads data without consuming it from the endpoint.
@@ -889,8 +1008,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
// but has some pending unread data.
if s := e.state; !s.connected() && s != StateClose {
if s == StateError {
- return 0, tcpip.ControlMessages{}, e.hardError
+ return 0, tcpip.ControlMessages{}, e.HardError
}
+ e.stats.ReadErrors.InvalidEndpointState.Increment()
return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
}
@@ -899,6 +1019,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
if e.rcvBufUsed == 0 {
if e.rcvClosed || !e.state.connected() {
+ e.stats.ReadErrors.ReadClosed.Increment()
return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
}
return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
@@ -946,62 +1067,9 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool {
return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0
}
-// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
- case tcpip.DelayOption:
- if v == 0 {
- atomic.StoreUint32(&e.delay, 0)
-
- // Handle delayed data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.delay, 1)
- }
- return nil
-
- case tcpip.CorkOption:
- if v == 0 {
- atomic.StoreUint32(&e.cork, 0)
-
- // Handle the corked data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.cork, 1)
- }
- return nil
-
- case tcpip.ReuseAddressOption:
- e.mu.Lock()
- e.reuseAddr = v != 0
- e.mu.Unlock()
- return nil
-
- case tcpip.ReusePortOption:
- e.mu.Lock()
- e.reusePort = v != 0
- e.mu.Unlock()
- return nil
-
- case tcpip.QuickAckOption:
- if v == 0 {
- atomic.StoreUint32(&e.slowAck, 1)
- } else {
- atomic.StoreUint32(&e.slowAck, 0)
- }
- return nil
-
- case tcpip.MaxSegOption:
- userMSS := v
- if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
- return tcpip.ErrInvalidOptionValue
- }
- e.mu.Lock()
- e.userMSS = int(userMSS)
- e.mu.Unlock()
- e.notifyProtocolGoroutine(notifyMSSChanged)
- return nil
-
+// SetSockOptInt sets a socket option whose value is an int.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ switch opt {
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
@@ -1065,9 +1133,87 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.sndBufMu.Unlock()
return nil
+ default:
+ return nil
+ }
+}
+
+// SetSockOpt sets a socket option.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ // The lower 2 bits represent the ECN bits; see RFC 3168, section 23.1.
+ const inetECNMask = 3
+ switch v := opt.(type) {
+ case tcpip.DelayOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.delay, 0)
+
+ // Handle delayed data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.delay, 1)
+ }
+ return nil
+
+ case tcpip.CorkOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.cork, 0)
+
+ // Handle the corked data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.cork, 1)
+ }
+ return nil
+
+ case tcpip.ReuseAddressOption:
+ e.mu.Lock()
+ e.reuseAddr = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.reusePort = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.BindToDeviceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if v == "" {
+ e.bindToDevice = 0
+ return nil
+ }
+ for nicid, nic := range e.stack.NICInfo() {
+ if nic.Name == string(v) {
+ e.bindToDevice = nicid
+ return nil
+ }
+ }
+ return tcpip.ErrUnknownDevice
+
+ case tcpip.QuickAckOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.slowAck, 1)
+ } else {
+ atomic.StoreUint32(&e.slowAck, 0)
+ }
+ return nil
+
+ case tcpip.MaxSegOption:
+ userMSS := v
+ if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
+ return tcpip.ErrInvalidOptionValue
+ }
+ e.mu.Lock()
+ e.userMSS = int(userMSS)
+ e.mu.Unlock()
+ e.notifyProtocolGoroutine(notifyMSSChanged)
+ return nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
- if e.netProto != header.IPv6ProtocolNumber {
+ if e.NetProto != header.IPv6ProtocolNumber {
return tcpip.ErrInvalidEndpointState
}
@@ -1082,6 +1228,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.v6only = v != 0
return nil
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ e.ttl = uint8(v)
+ e.mu.Unlock()
+ return nil
+
case tcpip.KeepaliveEnabledOption:
e.keepalive.Lock()
e.keepalive.enabled = v != 0
@@ -1150,6 +1302,23 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// Linux returns ENOENT when an invalid congestion
// control algorithm is specified.
return tcpip.ErrNoSuchFile
+
+ case tcpip.IPv4TOSOption:
+ e.mu.Lock()
+ // TODO(gvisor.dev/issue/995): ECN is not currently supported,
+ // ignore the bits for now.
+ e.sendTOS = uint8(v) & ^uint8(inetECNMask)
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.Lock()
+ // TODO(gvisor.dev/issue/995): ECN is not currently supported,
+ // ignore the bits for now.
+ e.sendTOS = uint8(v) & ^uint8(inetECNMask)
+ e.mu.Unlock()
+ return nil
+
default:
return nil
}
@@ -1176,6 +1345,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
+ case tcpip.SendBufferSizeOption:
+ e.sndBufMu.Lock()
+ v := e.sndBufSize
+ e.sndBufMu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvListMu.Lock()
+ v := e.rcvBufSize
+ e.rcvListMu.Unlock()
+ return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -1198,18 +1379,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = header.TCPDefaultMSS
return nil
- case *tcpip.SendBufferSizeOption:
- e.sndBufMu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.sndBufMu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvListMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize)
- e.rcvListMu.Unlock()
- return nil
-
case *tcpip.DelayOption:
*o = 0
if v := atomic.LoadUint32(&e.delay); v != 0 {
@@ -1246,6 +1415,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.BindToDeviceOption:
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
+ *o = tcpip.BindToDeviceOption(nic.Name)
+ return nil
+ }
+ *o = ""
+ return nil
+
case *tcpip.QuickAckOption:
*o = 1
if v := atomic.LoadUint32(&e.slowAck); v != 0 {
@@ -1255,7 +1434,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case *tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
- if e.netProto != header.IPv6ProtocolNumber {
+ if e.NetProto != header.IPv6ProtocolNumber {
return tcpip.ErrUnknownProtocolOption
}
@@ -1269,6 +1448,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.TTLOption:
+ e.mu.Lock()
+ *o = tcpip.TTLOption(e.ttl)
+ e.mu.Unlock()
+ return nil
+
case *tcpip.TCPInfoOption:
*o = tcpip.TCPInfoOption{}
e.mu.RLock()
@@ -1333,13 +1518,25 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case *tcpip.IPv4TOSOption:
+ e.mu.RLock()
+ *o = tcpip.IPv4TOSOption(e.sendTOS)
+ e.mu.RUnlock()
+ return nil
+
+ case *tcpip.IPv6TrafficClassOption:
+ e.mu.RLock()
+ *o = tcpip.IPv6TrafficClassOption(e.sendTOS)
+ e.mu.RUnlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.netProto
+ netProto := e.NetProto
if header.IsV4MappedAddress(addr.Addr) {
// Fail if using a v4 mapped address on a v6only endpoint.
if e.v6only {
@@ -1355,7 +1552,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
// Fail if we're bound to an address length different from the one we're
// checking.
- if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) {
+ if l := len(e.ID.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) {
return 0, tcpip.ErrInvalidEndpointState
}
@@ -1369,7 +1566,12 @@ func (*endpoint) Disconnect() *tcpip.Error {
// Connect connects the endpoint to its peer.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- return e.connect(addr, true, true)
+ err := e.connect(addr, true, true)
+ if err != nil && !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ }
+ return err
}
// connect connects the endpoint to its peer. In the normal non-S/R case, the
@@ -1378,14 +1580,9 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// created (so no new handshaking is done); for stack-accepted connections not
// yet accepted by the app, they are restored without running the main goroutine
// here.
-func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (err *tcpip.Error) {
+func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- defer func() {
- if err != nil && !err.IgnoreStats() {
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- }
- }()
connectingAddr := addr.Addr
@@ -1430,29 +1627,29 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
return tcpip.ErrAlreadyConnecting
case StateError:
- return e.hardError
+ return e.HardError
default:
return tcpip.ErrInvalidEndpointState
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
+ r, err := e.stack.FindRoute(nicid, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
if err != nil {
return err
}
defer r.Release()
- origID := e.id
+ origID := e.ID
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- e.id.LocalAddress = r.LocalAddress
- e.id.RemoteAddress = r.RemoteAddress
- e.id.RemotePort = addr.Port
+ e.ID.LocalAddress = r.LocalAddress
+ e.ID.RemoteAddress = r.RemoteAddress
+ e.ID.RemotePort = addr.Port
- if e.id.LocalPort != 0 {
+ if e.ID.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice)
if err != nil {
return err
}
@@ -1461,20 +1658,35 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
// one. Make sure that it isn't one that will result in the same
// address/port for both local and remote (otherwise this
// endpoint would be trying to connect to itself).
- sameAddr := e.id.LocalAddress == e.id.RemoteAddress
- if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
- if sameAddr && p == e.id.RemotePort {
+ sameAddr := e.ID.LocalAddress == e.ID.RemoteAddress
+
+ // Calculate a port offset from the source IP and the destination
+ // IP/port so that a given tuple (srcIP, destIP, destPort) always
+ // starts from the same offset, which lets us cycle through the
+ // port space effectively.
+ h := jenkins.Sum32(e.stack.PortSeed())
+ h.Write([]byte(e.ID.LocalAddress))
+ h.Write([]byte(e.ID.RemoteAddress))
+ portBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort)
+ h.Write(portBuf)
+ portOffset := h.Sum32()
+
+ if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
+ if sameAddr && p == e.ID.RemotePort {
return false, nil
}
- if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) {
+ // reusePort is false below because connect cannot reuse a port even if
+ // reusePort was set.
+ if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, false /* reusePort */, e.bindToDevice) {
return false, nil
}
- id := e.id
+ id := e.ID
id.LocalPort = p
- switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) {
+ switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) {
case nil:
- e.id = id
+ e.ID = id
return true, nil
case tcpip.ErrPortInUse:
return false, nil
@@ -1490,7 +1702,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
// before Connect: in such a case we don't want to hold on to
// reservations anymore.
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -1509,7 +1721,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
e.segmentQueue.mu.Lock()
for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} {
for s := l.Front(); s != nil; s = s.Next() {
- s.id = e.id
+ s.id = e.ID
s.route = r.Clone()
e.sndWaker.Assert()
}
@@ -1569,7 +1781,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
}
// Queue fin segment.
- s := newSegmentFromView(&e.route, e.id, nil)
+ s := newSegmentFromView(&e.route, e.ID, nil)
e.sndQueue.PushBack(s)
e.sndBufInQueue++
@@ -1597,14 +1809,18 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// Listen puts the endpoint in "listen" mode, which allows it to accept
// new connections.
-func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
+func (e *endpoint) Listen(backlog int) *tcpip.Error {
+ err := e.listen(backlog)
+ if err != nil && !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ }
+ return err
+}
+
+func (e *endpoint) listen(backlog int) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- defer func() {
- if err != nil && !err.IgnoreStats() {
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- }
- }()
// Allow the backlog to be adjusted if the endpoint is not shutting down.
// When the endpoint shuts down, it sets workerCleanup to true, and from
@@ -1630,11 +1846,12 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
// Endpoint must be bound before it can transition to listen mode.
if e.state != StateBound {
+ e.stats.ReadErrors.InvalidEndpointState.Increment()
return tcpip.ErrInvalidEndpointState
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice); err != nil {
return err
}
@@ -1698,7 +1915,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
return tcpip.ErrAlreadyBound
}
- e.bindAddress = addr.Addr
+ e.BindAddr = addr.Addr
netProto, err := e.checkV4Mapped(&addr)
if err != nil {
return err
@@ -1715,26 +1932,26 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
}
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice)
if err != nil {
return err
}
e.isPortReserved = true
e.effectiveNetProtos = netProtos
- e.id.LocalPort = port
+ e.ID.LocalPort = port
// Any failures beyond this point must remove the port registration.
- defer func() {
+ defer func(bindToDevice tcpip.NICID) {
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice)
e.isPortReserved = false
e.effectiveNetProtos = nil
- e.id.LocalPort = 0
- e.id.LocalAddress = ""
+ e.ID.LocalPort = 0
+ e.ID.LocalAddress = ""
e.boundNICID = 0
}
- }()
+ }(e.bindToDevice)
// If an address is specified, we must ensure that it's one of our
// local addresses.
@@ -1745,7 +1962,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
}
e.boundNICID = nic
- e.id.LocalAddress = addr.Addr
+ e.ID.LocalAddress = addr.Addr
}
// Mark endpoint as bound.
@@ -1760,8 +1977,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
defer e.mu.RUnlock()
return tcpip.FullAddress{
- Addr: e.id.LocalAddress,
- Port: e.id.LocalPort,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
NIC: e.boundNICID,
}, nil
}
@@ -1776,8 +1993,8 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}
return tcpip.FullAddress{
- Addr: e.id.RemoteAddress,
- Port: e.id.RemotePort,
+ Addr: e.ID.RemoteAddress,
+ Port: e.ID.RemotePort,
NIC: e.boundNICID,
}, nil
}
@@ -1789,6 +2006,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
if !s.parse() {
e.stack.Stats().MalformedRcvdPackets.Increment()
e.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
+ e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
s.decRef()
return
}
@@ -1796,11 +2014,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
if !s.csumValid {
e.stack.Stats().MalformedRcvdPackets.Increment()
e.stack.Stats().TCP.ChecksumErrors.Increment()
+ e.stats.ReceiveErrors.ChecksumErrors.Increment()
s.decRef()
return
}
e.stack.Stats().TCP.ValidSegmentsReceived.Increment()
+ e.stats.SegmentsReceived.Increment()
if (s.flags & header.TCPFlagRst) != 0 {
e.stack.Stats().TCP.ResetsReceived.Increment()
}
@@ -1811,6 +2031,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
} else {
// The queue is full, so we drop the segment.
e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.SegmentQueueDropped.Increment()
s.decRef()
}
}
@@ -1860,6 +2081,7 @@ func (e *endpoint) readyToRead(s *segment) {
// that a subsequent read of the segment will correctly trigger
// a non-zero notification.
if avail := e.receiveBufferAvailableLocked(); avail>>e.rcv.rcvWndScale == 0 {
+ e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
e.zeroWindow = true
}
e.rcvList.PushBack(s)
@@ -2012,7 +2234,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
// Copy EndpointID.
e.mu.Lock()
- s.ID = stack.TCPEndpointID(e.id)
+ s.ID = stack.TCPEndpointID(e.ID)
e.mu.Unlock()
// Copy endpoint rcv state.
@@ -2119,7 +2341,7 @@ func (e *endpoint) initGSO() {
gso.Type = stack.GSOTCPv6
gso.L3HdrLen = header.IPv6MinimumSize
default:
- panic(fmt.Sprintf("Unknown netProto: %v", e.netProto))
+ panic(fmt.Sprintf("Unknown netProto: %v", e.NetProto))
}
gso.NeedsCsum = true
gso.CsumOffset = header.TCPChecksumOffset
@@ -2135,6 +2357,20 @@ func (e *endpoint) State() uint32 {
return uint32(e.state)
}
+// Info returns a copy of the endpoint info.
+func (e *endpoint) Info() tcpip.EndpointInfo {
+ e.mu.RLock()
+ // Make a copy of the endpoint info.
+ ret := e.EndpointInfo
+ e.mu.RUnlock()
+ return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (e *endpoint) Stats() tcpip.EndpointStats {
+ return &e.stats
+}
+
func mssForRoute(r *stack.Route) uint16 {
return uint16(r.MTU() - header.TCPMinimumSize)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 831389ec7..eae17237e 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -55,7 +55,7 @@ func (e *endpoint) beforeSave() {
case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
- panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
+ panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)})
}
e.resetConnectionLocked(tcpip.ErrConnectionAborted)
e.mu.Unlock()
@@ -190,10 +190,10 @@ func (e *endpoint) Resume(s *stack.Stack) {
bind := func() {
e.state = StateInitial
- if len(e.bindAddress) == 0 {
- e.bindAddress = e.id.LocalAddress
+ if len(e.BindAddr) == 0 {
+ e.BindAddr = e.ID.LocalAddress
}
- if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil {
+ if err := e.Bind(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}); err != nil {
panic("endpoint binding failed: " + err.String())
}
}
@@ -202,19 +202,19 @@ func (e *endpoint) Resume(s *stack.Stack) {
case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
bind()
if len(e.connectingAddress) == 0 {
- e.connectingAddress = e.id.RemoteAddress
+ e.connectingAddress = e.ID.RemoteAddress
// This endpoint is accepted by netstack but not yet by
// the app. If the endpoint is IPv6 but the remote
// address is IPv4, we need to connect as IPv6 so that
// dual-stack mode can be properly activated.
- if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize {
- e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress
+ if e.NetProto == header.IPv6ProtocolNumber && len(e.ID.RemoteAddress) != header.IPv6AddressSize {
+ e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.ID.RemoteAddress
}
}
// Reset the scoreboard to reinitialize the sack information as
// we do not restore SACK information.
e.scoreboard.Reset()
- if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
+ if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
panic("endpoint connecting failed: " + err.String())
}
connectedLoading.Done()
@@ -236,7 +236,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
connectedLoading.Wait()
listenLoading.Wait()
bind()
- if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted {
+ if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}); err != tcpip.ErrConnectStarted {
panic("endpoint connecting failed: " + err.String())
}
connectingLoading.Done()
@@ -288,21 +288,21 @@ func (e *endpoint) loadLastError(s string) {
}
// saveHardError is invoked by stateify.
-func (e *endpoint) saveHardError() string {
- if e.hardError == nil {
+func (e *EndpointInfo) saveHardError() string {
+ if e.HardError == nil {
return ""
}
- return e.hardError.String()
+ return e.HardError.String()
}
// loadHardError is invoked by stateify.
-func (e *endpoint) loadHardError(s string) {
+func (e *EndpointInfo) loadHardError(s string) {
if s == "" {
return
}
- e.hardError = loadError(s)
+ e.HardError = loadError(s)
}
var messageToError map[string]*tcpip.Error
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index ee04dcfcc..db40785d3 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -14,7 +14,7 @@
// Package tcp contains the implementation of the TCP transport protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing tcp.ProtocolName (or "tcp") as one of the
+// activated on the stack by passing tcp.NewProtocol() as one of the
// transport protocols when calling stack.New(). Then endpoints can be created
// by passing tcp.ProtocolNumber as the transport protocol number when calling
// Stack.NewEndpoint().
@@ -34,9 +34,6 @@ import (
)
const (
- // ProtocolName is the string representation of the tcp protocol name.
- ProtocolName = "tcp"
-
// ProtocolNumber is the tcp protocol number.
ProtocolNumber = header.TCPProtocolNumber
@@ -129,7 +126,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
s := newSegment(r, id, vv)
defer s.decRef()
@@ -156,7 +153,7 @@ func replyWithReset(s *segment) {
ack := s.sequenceNumber.Add(s.logicalLen())
- sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0, nil /* options */, nil /* gso */)
+ sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */)
}
// SetOption implements TransportProtocol.SetOption.
@@ -254,13 +251,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
}
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
- return &protocol{
- sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
- recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
- congestionControl: ccReno,
- availableCongestionControl: []string{ccReno, ccCubic},
- }
- })
+// NewProtocol returns a TCP transport protocol.
+func NewProtocol() stack.TransportProtocol {
+ return &protocol{
+ sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
+ recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
+ congestionControl: ccReno,
+ availableCongestionControl: []string{ccReno, ccCubic},
+ }
}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 1f9b1e0ef..8332a0179 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -417,6 +417,7 @@ func (s *sender) resendSegment() {
s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
s.sendSegment(seg)
s.ep.stack.Stats().TCP.FastRetransmit.Increment()
+ s.ep.stats.SendErrors.FastRetransmit.Increment()
// Run SetPipe() as per RFC 6675 section 5 Step 4.4
s.SetPipe()
@@ -435,6 +436,7 @@ func (s *sender) retransmitTimerExpired() bool {
}
s.ep.stack.Stats().TCP.Timeouts.Increment()
+ s.ep.stats.SendErrors.Timeouts.Increment()
// Give up if we've waited more than a minute since the last resend.
if s.rto >= 60*time.Second {
@@ -664,7 +666,14 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
segEnd = seg.sequenceNumber.Add(1)
// Transition to FIN-WAIT1 state since we're initiating an active close.
s.ep.mu.Lock()
- s.ep.state = StateFinWait1
+ switch s.ep.state {
+ case StateCloseWait:
+ // We've already received a FIN and are now sending our own. The
+ // sender is now awaiting a final ACK for this FIN.
+ s.ep.state = StateLastAck
+ default:
+ s.ep.state = StateFinWait1
+ }
s.ep.mu.Unlock()
} else {
// We're sending a non-FIN segment.
@@ -1181,6 +1190,7 @@ func (s *sender) handleRcvdSegment(seg *segment) {
func (s *sender) sendSegment(seg *segment) *tcpip.Error {
if !seg.xmitTime.IsZero() {
s.ep.stack.Stats().TCP.Retransmits.Increment()
+ s.ep.stats.SendErrors.Retransmits.Increment()
if s.sndCwnd < s.sndSsthresh {
s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment()
}
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index 272bbcdbd..782d7b42c 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -38,7 +38,7 @@ func TestFastRecovery(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -190,7 +190,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -232,7 +232,7 @@ func TestCongestionAvoidance(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -336,7 +336,7 @@ func TestCubicCongestionAvoidance(t *testing.T) {
enableCUBIC(t, c)
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -445,7 +445,7 @@ func TestRetransmit(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -500,6 +500,14 @@ func TestRetransmit(t *testing.T) {
t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
}
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want {
+ t.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want)
+ }
+
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want {
+ t.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want)
+ }
+
if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 4e7f1a740..afea124ec 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -520,10 +520,18 @@ func TestSACKRecovery(t *testing.T) {
t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
}
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want {
+ t.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want)
+ }
+
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want {
t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
}
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want {
+ t.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want)
+ }
+
c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond)
// Acknowledge all pending data to recover point.
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index f79b8ec5f..6d022a266 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -84,7 +84,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.ActiveConnectionOpenings.Value() + 1
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want)
}
@@ -97,9 +97,12 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.FailedConnectionAttempts.Value()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
+ }
+ if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
+ t.Errorf("got EP stats.FailedConnectionAttempts = %v, want = %v", got, want)
}
}
@@ -122,6 +125,9 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
}
+ if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
+ t.Errorf("got EP stats FailedConnectionAttempts = %v, want = %v", got, want)
+ }
}
func TestTCPSegmentsSentIncrement(t *testing.T) {
@@ -131,11 +137,14 @@ func TestTCPSegmentsSentIncrement(t *testing.T) {
stats := c.Stack().Stats()
// SYN and ACK
want := stats.TCP.SegmentsSent.Value() + 2
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.SegmentsSent.Value(); got != want {
t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want)
}
+ if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want {
+ t.Errorf("got EP stats SegmentsSent.Value() = %v, want = %v", got, want)
+ }
}
func TestTCPResetsSentIncrement(t *testing.T) {
@@ -190,21 +199,122 @@ func TestTCPResetsSentIncrement(t *testing.T) {
}
}
+// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates
+// a RST if an ACK is received on the listening socket for which there is no
+// active handshake in progress and we are not using SYN cookies.
+//
+// The test completes a passive handshake, accepts and closes the resulting
+// connection, finishes the close from the peer side, and then replays the
+// handshake's final ACK, expecting a RST|ACK in response.
+func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	wq := &waiter.Queue{}
+	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %v", err)
+	}
+	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+		t.Fatalf("Bind failed: %v", err)
+	}
+
+	if err := ep.Listen(10); err != nil {
+		t.Fatalf("Listen failed: %v", err)
+	}
+
+	// Send a SYN request.
+	iss := seqnum.Value(789)
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn,
+		SeqNum:  iss,
+	})
+
+	// Receive the SYN-ACK reply and remember the stack's initial sequence
+	// number so that later segments can acknowledge it.
+	b := c.GetPacket()
+	tcpHdr := header.TCP(header.IPv4(b).Payload())
+	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+	// These headers are kept so the exact same ACK can be replayed at the
+	// end of the test.
+	ackHeaders := &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  iss + 1,
+		AckNum:  c.IRS + 1,
+	}
+
+	// Send ACK.
+	c.SendPacket(nil, ackHeaders)
+
+	// Try to accept the connection.
+	we, ch := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&we, waiter.EventIn)
+	defer wq.EventUnregister(&we)
+
+	c.EP, _, err = ep.Accept()
+	if err == tcpip.ErrWouldBlock {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			c.EP, _, err = ep.Accept()
+
+		case <-time.After(1 * time.Second):
+			t.Fatalf("Timed out waiting for accept")
+		}
+	}
+	// Check the result of whichever Accept call ran last. Without this, a
+	// first Accept failing with an error other than ErrWouldBlock would
+	// leave c.EP nil and the Close below would panic instead of reporting
+	// the failure.
+	if err != nil {
+		t.Fatalf("Accept failed: %v", err)
+	}
+
+	c.EP.Close()
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(context.TestPort),
+		checker.SeqNum(uint32(c.IRS+1)),
+		checker.AckNum(uint32(iss)+1),
+		checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+	finHeaders := &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck | header.TCPFlagFin,
+		SeqNum:  iss + 1,
+		AckNum:  c.IRS + 2,
+	}
+
+	c.SendPacket(nil, finHeaders)
+
+	// Get the ACK to the FIN we just sent.
+	c.GetPacket()
+
+	// Now resend the same ACK, this ACK should generate a RST as there
+	// should be no endpoint in SYN-RCVD state and we are not using
+	// syn-cookies yet. The reason we send the same ACK is we need a valid
+	// cookie(IRS) generated by the netstack without which the ACK will be
+	// rejected.
+	c.SendPacket(nil, ackHeaders)
+
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(context.TestPort),
+		checker.SeqNum(uint32(c.IRS+1)),
+		checker.AckNum(uint32(iss)+1),
+		checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
+}
+
func TestTCPResetsReceivedIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
stats := c.Stack().Stats()
want := stats.TCP.ResetsReceived.Value() + 1
- ackNum := seqnum.Value(789)
+ iss := seqnum.Value(789)
rcvWnd := seqnum.Size(30000)
- c.CreateConnected(ackNum, rcvWnd, nil)
+ c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
- SeqNum: c.IRS.Add(2),
- AckNum: ackNum.Add(2),
+ SeqNum: iss.Add(1),
+ AckNum: c.IRS.Add(1),
RcvWnd: rcvWnd,
Flags: header.TCPFlagRst,
})
@@ -214,18 +324,43 @@ func TestTCPResetsReceivedIncrement(t *testing.T) {
}
}
+func TestTCPResetsDoNotGenerateResets(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ stats := c.Stack().Stats()
+ want := stats.TCP.ResetsReceived.Value() + 1
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ SeqNum: iss.Add(1),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ Flags: header.TCPFlagRst,
+ })
+
+ if got := stats.TCP.ResetsReceived.Value(); got != want {
+ t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want)
+ }
+ c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond)
+}
+
func TestActiveHandshake(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
}
func TestNonBlockingClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -241,7 +376,7 @@ func TestConnectResetAfterClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -291,7 +426,7 @@ func TestSimpleReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -339,11 +474,172 @@ func TestSimpleReceive(t *testing.T) {
)
}
+// TestTOSV4 sets the IPv4 TOS socket option on a fresh endpoint, reads it
+// back, and then checks that both the packet inspected by testV4Connect and
+// a subsequent data segment carry that TOS value in the IP header.
+func TestTOSV4(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %s", err)
+	}
+	c.EP = ep
+
+	const tos = 0xC0
+	if err := c.EP.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
+		t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err)
+	}
+
+	// A failed read-back is fatal, matching TestTrafficClassV6: comparing
+	// the zero value of v against tos would only add a misleading second
+	// failure.
+	var v tcpip.IPv4TOSOption
+	if err := c.EP.GetSockOpt(&v); err != nil {
+		t.Fatalf("GetSockopt failed: %s", err)
+	}
+
+	if want := tcpip.IPv4TOSOption(tos); v != want {
+		t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+	}
+
+	// NOTE(review): testV4Connect is defined outside this block; it
+	// presumably performs the handshake and applies the extra checker to
+	// the SYN - confirm.
+	testV4Connect(t, c, checker.TOS(tos, 0))
+
+	data := []byte{1, 2, 3}
+	view := buffer.NewView(len(data))
+	copy(view, data)
+
+	if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+		t.Fatalf("Write failed: %s", err)
+	}
+
+	// Check that data is received.
+	b := c.GetPacket()
+	checker.IPv4(t, b,
+		checker.PayloadLen(len(data)+header.TCPMinimumSize),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(790), // Acknum is initial sequence number + 1
+			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+		),
+		checker.TOS(tos, 0),
+	)
+
+	// The payload follows the minimum-size IPv4 and TCP headers.
+	if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+		t.Errorf("got data = %x, want = %x", p, data)
+	}
+}
+
+// TestTrafficClassV6 sets the IPv6 traffic class socket option on a v6
+// endpoint, reads it back, and then checks that the packet inspected by
+// testV6Connect and a subsequent data segment both carry that value.
+func TestTrafficClassV6(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateV6Endpoint(false)
+
+	const trafficClass = 0xC0
+	setErr := c.EP.SetSockOpt(tcpip.IPv6TrafficClassOption(trafficClass))
+	if setErr != nil {
+		t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(trafficClass), setErr)
+	}
+
+	var got tcpip.IPv6TrafficClassOption
+	getErr := c.EP.GetSockOpt(&got)
+	if getErr != nil {
+		t.Fatalf("GetSockopt failed: %s", getErr)
+	}
+
+	want := tcpip.IPv6TrafficClassOption(trafficClass)
+	if got != want {
+		t.Errorf("got GetSockOpt(...) = %#v, want = %#v", got, want)
+	}
+
+	// Test the connection request.
+	testV6Connect(t, c, checker.TOS(trafficClass, 0))
+
+	payload := []byte{1, 2, 3}
+	view := buffer.NewView(len(payload))
+	copy(view, payload)
+
+	_, _, writeErr := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
+	if writeErr != nil {
+		t.Fatalf("Write failed: %s", writeErr)
+	}
+
+	// Check that data is received.
+	pkt := c.GetV6Packet()
+	tcpChecks := checker.TCP(
+		checker.DstPort(context.TestPort),
+		checker.SeqNum(uint32(c.IRS)+1),
+		checker.AckNum(790),
+		checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+	)
+	checker.IPv6(t, pkt,
+		checker.PayloadLen(len(payload)+header.TCPMinimumSize),
+		tcpChecks,
+		checker.TOS(trafficClass, 0),
+	)
+
+	// The payload follows the minimum-size IPv6 and TCP headers.
+	received := pkt[header.IPv6MinimumSize+header.TCPMinimumSize:]
+	if !bytes.Equal(payload, received) {
+		t.Errorf("got data = %x, want = %x", received, payload)
+	}
+}
+
+// TestConnectBindToDevice checks which state an actively connecting endpoint
+// reaches after the peer's SYN-ACK is injected, depending on the device the
+// endpoint was bound to beforehand: "nic1" and the empty name are expected
+// to reach StateEstablished, while "nic2" is expected to stay in
+// StateSynSent.
+//
+// NOTE(review): this presumably relies on the test context delivering
+// injected packets on nic1 - confirm against the context package.
+func TestConnectBindToDevice(t *testing.T) {
+	for _, test := range []struct {
+		name   string
+		device string
+		want   tcp.EndpointState
+	}{
+		{"RightDevice", "nic1", tcp.StateEstablished},
+		{"WrongDevice", "nic2", tcp.StateSynSent},
+		{"AnyDevice", "", tcp.StateEstablished},
+	} {
+		t.Run(test.name, func(t *testing.T) {
+			c := context.New(t, defaultMTU)
+			defer c.Cleanup()
+
+			c.Create(-1)
+			bindToDevice := tcpip.BindToDeviceOption(test.device)
+			// Fail fast if the bind is rejected; otherwise the failure
+			// would only surface later as a confusing endpoint-state
+			// mismatch.
+			if err := c.EP.SetSockOpt(bindToDevice); err != nil {
+				t.Fatalf("SetSockOpt(%#v) failed: %v", bindToDevice, err)
+			}
+			// Start connection attempt.
+			waitEntry, _ := waiter.NewChannelEntry(nil)
+			c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+			defer c.WQ.EventUnregister(&waitEntry)
+
+			if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+				t.Fatalf("Unexpected return value from Connect: %v", err)
+			}
+
+			// Receive SYN packet.
+			b := c.GetPacket()
+			checker.IPv4(t, b,
+				checker.TCP(
+					checker.DstPort(context.TestPort),
+					checker.TCPFlags(header.TCPFlagSyn),
+				),
+			)
+			if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+				t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+			}
+			tcpHdr := header.TCP(header.IPv4(b).Payload())
+			c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+			// Reply with a SYN-ACK, swapping the ports seen in the SYN.
+			iss := seqnum.Value(789)
+			rcvWnd := seqnum.Size(30000)
+			c.SendPacket(nil, &context.Headers{
+				SrcPort: tcpHdr.DestinationPort(),
+				DstPort: tcpHdr.SourcePort(),
+				Flags:   header.TCPFlagSyn | header.TCPFlagAck,
+				SeqNum:  iss,
+				AckNum:  c.IRS.Add(1),
+				RcvWnd:  rcvWnd,
+				TCPOpts: nil,
+			})
+
+			// Consume the stack's response to the SYN-ACK before
+			// inspecting the endpoint state.
+			c.GetPacket()
+			if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want {
+				t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+			}
+		})
+	}
+}
+
func TestOutOfOrderReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -431,8 +727,7 @@ func TestOutOfOrderFlood(t *testing.T) {
defer c.Cleanup()
// Create a new connection with initial window size of 10.
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
@@ -505,7 +800,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -574,7 +869,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -659,7 +954,7 @@ func TestShutdownRead(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
@@ -672,14 +967,17 @@ func TestShutdownRead(t *testing.T) {
if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive)
}
+ var want uint64 = 1
+ if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
+ t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %v want %v", got, want)
+ }
}
func TestFullWindowReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -729,6 +1027,11 @@ func TestFullWindowReceive(t *testing.T) {
t.Fatalf("got data = %v, want = %v", v, data)
}
+ var want uint64 = 1
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want {
+ t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %v want %v", got, want)
+ }
+
// Check that we get an ACK for the newly non-zero window.
checker.IPv4(t, c.GetPacket(),
checker.TCP(
@@ -746,11 +1049,9 @@ func TestNoWindowShrinking(t *testing.T) {
defer c.Cleanup()
// Start off with a window size of 10, then shrink it to 5.
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
- opt = 5
- if err := c.EP.SetSockOpt(opt); err != nil {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
t.Fatalf("SetSockOpt failed: %v", err)
}
@@ -850,7 +1151,7 @@ func TestSimpleSend(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -891,7 +1192,7 @@ func TestZeroWindowSend(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 0, nil)
+ c.CreateConnected(789, 0, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -949,8 +1250,7 @@ func TestScaledWindowConnect(t *testing.T) {
defer c.Cleanup()
// Set the window size greater than the maximum non-scaled window.
- opt := tcpip.ReceiveBufferSizeOption(65535 * 3)
- c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -984,8 +1284,7 @@ func TestNonScaledWindowConnect(t *testing.T) {
defer c.Cleanup()
// Set the window size greater than the maximum non-scaled window.
- opt := tcpip.ReceiveBufferSizeOption(65535 * 3)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 65535*3)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -1025,7 +1324,7 @@ func TestScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1098,7 +1397,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1167,8 +1466,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
// Set the window size such that a window scale of 4 will be used.
const wnd = 65535 * 10
const ws = uint32(4)
- opt := tcpip.ReceiveBufferSizeOption(wnd)
- c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -1273,7 +1571,7 @@ func TestSegmentMerging(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Prevent the endpoint from processing packets.
test.stop(c.EP)
@@ -1323,7 +1621,7 @@ func TestDelay(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.DelayOption(1))
@@ -1371,7 +1669,7 @@ func TestUndelay(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.DelayOption(1))
@@ -1453,7 +1751,7 @@ func TestMSSNotDelayed(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
@@ -1569,16 +1867,44 @@ func TestSendGreaterThanMTU(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
testBrokenUpWrite(t, c, maxPayload)
}
+// TestSetTTL runs one subtest per TTL value: it sets the TTL socket option
+// on a new endpoint, starts a connect, and checks that the first packet the
+// stack emits carries that TTL in its IPv4 header.
+func TestSetTTL(t *testing.T) {
+	ttls := []uint8{1, 2, 50, 64, 128, 254, 255}
+	for _, wantTTL := range ttls {
+		subtestName := fmt.Sprintf("TTL:%d", wantTTL)
+		t.Run(subtestName, func(t *testing.T) {
+			c := context.New(t, 65535)
+			defer c.Cleanup()
+
+			var err *tcpip.Error
+			c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+			if err != nil {
+				t.Fatalf("NewEndpoint failed: %v", err)
+			}
+
+			err = c.EP.SetSockOpt(tcpip.TTLOption(wantTTL))
+			if err != nil {
+				t.Fatalf("SetSockOpt failed: %v", err)
+			}
+
+			// Connect is expected to report that the handshake has started.
+			remote := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
+			err = c.EP.Connect(remote)
+			if err != tcpip.ErrConnectStarted {
+				t.Fatalf("Unexpected return value from Connect: %v", err)
+			}
+
+			// Receive SYN packet.
+			syn := c.GetPacket()
+
+			checker.IPv4(t, syn, checker.TTL(wantTTL))
+		})
+	}
+}
+
func TestActiveSendMSSLessThanMTU(t *testing.T) {
const maxPayload = 100
c := context.New(t, 65535)
defer c.Cleanup()
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
testBrokenUpWrite(t, c, maxPayload)
@@ -1601,7 +1927,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
// Set the buffer size to a deterministic size so that we can check the
// window scaling option.
const rcvBufferSize = 0x20000
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1745,7 +2071,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
// window scaling option.
const rcvBufferSize = 0x20000
const wndScale = 2
- if err := c.EP.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1847,7 +2173,7 @@ func TestReceiveOnResetConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send RST segment.
c.SendPacket(nil, &context.Headers{
@@ -1878,13 +2204,20 @@ loop:
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
}
}
+ // Expect the state to be StateError and subsequent Reads to fail with HardError.
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ }
+ if tcp.EndpointState(c.EP.State()) != tcp.StateError {
+ t.Fatalf("got EP state is not StateError")
+ }
}
func TestSendOnResetConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send RST segment.
c.SendPacket(nil, &context.Headers{
@@ -1909,7 +2242,7 @@ func TestFinImmediately(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -1952,7 +2285,7 @@ func TestFinRetransmit(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -2006,7 +2339,7 @@ func TestFinWithNoPendingData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and have it acknowledged.
view := buffer.NewView(10)
@@ -2077,7 +2410,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write enough segments to fill the congestion window before ACK'ing
// any of them.
@@ -2165,7 +2498,7 @@ func TestFinWithPendingData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2.
view := buffer.NewView(10)
@@ -2251,7 +2584,7 @@ func TestFinWithPartialAck(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2. Also send
// FIN from the test side.
@@ -2383,7 +2716,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
defer c.Cleanup()
maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
- c.CreateConnectedWithRawOptions(789, 0, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
header.TCPOptionWS, 3, scale, header.TCPOptionNOP,
})
@@ -2433,7 +2766,7 @@ func TestScaledSendWindow(t *testing.T) {
func TestReceivedValidSegmentCountIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.ValidSegmentsReceived.Value() + 1
@@ -2449,12 +2782,23 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) {
if got := stats.TCP.ValidSegmentsReceived.Value(); got != want {
t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want)
}
+ if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want {
+ t.Errorf("got EP stats Stats.SegmentsReceived = %v, want = %v", got, want)
+ }
+ // Ensure there were no errors during handshake. If these stats have
+ // incremented, then the connection should not have been established.
+ if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 {
+ t.Errorf("got EP stats Stats.SendErrors.NoRoute = %v, want = %v", got, 0)
+ }
+ if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 {
+ t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %v, want = %v", got, 0)
+ }
}
func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.InvalidSegmentsReceived.Value() + 1
vv := c.BuildSegment(nil, &context.Headers{
@@ -2473,12 +2817,15 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want {
t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want)
}
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
+ }
}
func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.ChecksumErrors.Value() + 1
vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{
@@ -2499,6 +2846,9 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
if got := stats.TCP.ChecksumErrors.Value(); got != want {
t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want)
}
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want)
+ }
}
func TestReceivedSegmentQueuing(t *testing.T) {
@@ -2509,7 +2859,7 @@ func TestReceivedSegmentQueuing(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send 200 segments.
data := []byte{1, 2, 3}
@@ -2555,7 +2905,7 @@ func TestReadAfterClosedState(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -2730,8 +3080,8 @@ func TestReusePort(t *testing.T) {
func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- var s tcpip.ReceiveBufferSizeOption
- if err := ep.GetSockOpt(&s); err != nil {
+ s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
@@ -2743,8 +3093,8 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- var s tcpip.SendBufferSizeOption
- if err := ep.GetSockOpt(&s); err != nil {
+ s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ if err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
@@ -2754,7 +3104,10 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
}
func TestDefaultBufferSizes(t *testing.T) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
@@ -2800,7 +3153,10 @@ func TestDefaultBufferSizes(t *testing.T) {
}
func TestMinMaxBufferSizes(t *testing.T) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
@@ -2819,37 +3175,96 @@ func TestMinMaxBufferSizes(t *testing.T) {
}
// Set values below the min.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(199)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkRecvBufferSize(t, ep, 200)
- if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(299)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkSendBufferSize(t, ep, 300)
// Set values above the max.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultReceiveBufferSize*20)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20)
- if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultSendBufferSize*30)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30)
}
+// TestBindToDeviceOption checks that BindToDeviceOption can be set to a named
+// NIC, is rejected with ErrUnknownDevice for an unknown name, can be cleared
+// with an empty name, and that GetSockOpt reports the currently bound name.
+func TestBindToDeviceOption(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}})
+
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
+ }
+ defer ep.Close()
+
+ if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
+ t.Errorf("CreateNamedNIC failed: %v", err)
+ }
+
+ // Make a nameless NIC.
+ if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ t.Errorf("CreateNIC failed: %v", err)
+ }
+
+ // strPtr is used instead of taking the address of string literals, which is
+ // a compiler error.
+ strPtr := func(s string) *string {
+ return &s
+ }
+
+ // Each action optionally sets the option (nil means "do not set") and then
+ // verifies the value read back. The actions share one endpoint, so they run
+ // in order and build on each other's state.
+ testActions := []struct {
+ name string
+ setBindToDevice *string
+ setBindToDeviceError *tcpip.Error
+ getBindToDevice tcpip.BindToDeviceOption
+ }{
+ {"GetDefaultValue", nil, nil, ""},
+ {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
+ {"BindToExistent", strPtr("my_device"), nil, "my_device"},
+ {"UnbindToDevice", strPtr(""), nil, ""},
+ }
+ for _, testAction := range testActions {
+ t.Run(testAction.name, func(t *testing.T) {
+ if testAction.setBindToDevice != nil {
+ bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
+ if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ }
+ }
+ bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
+ // Call GetSockOpt exactly once so the reported error is the one that
+ // actually failed the check.
+ if err := ep.GetSockOpt(&bindToDevice); err != nil {
+ t.Errorf("GetSockOpt got %v, want %v", err, nil)
+ }
+ if got, want := bindToDevice, testAction.getBindToDevice; got != want {
+ t.Errorf("bindToDevice got %q, want %q", got, want)
+ }
+ })
+ }
+}
+
+
func makeStack() (*stack.Stack, *tcpip.Error) {
- s := stack.New([]string{
- ipv4.ProtocolName,
- ipv6.ProtocolName,
- }, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{
+ ipv4.NewProtocol(),
+ ipv6.NewProtocol(),
+ },
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
id := loopback.New()
if testing.Verbose() {
@@ -3105,7 +3520,7 @@ func TestPathMTUDiscovery(t *testing.T) {
// Create new connection with MSS of 1460.
const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
@@ -3182,7 +3597,7 @@ func TestTCPEndpointProbe(t *testing.T) {
invoked <- struct{}{}
})
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -3356,7 +3771,7 @@ func TestKeepalive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond))
@@ -3459,8 +3874,8 @@ func TestKeepalive(t *testing.T) {
),
)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
}
}
@@ -3886,6 +4301,9 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want)
}
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
+ t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %v, want = %v", got, want)
+ }
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -3924,6 +4342,14 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
}
+ // Expect InvalidEndpointState errors on a read at this point.
+ if _, _, err := ep.Read(nil); err != tcpip.ErrInvalidEndpointState {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrInvalidEndpointState)
+ }
+ if got := ep.Stats().(*tcp.Stats).ReadErrors.InvalidEndpointState.Value(); got != 1 {
+ t.Fatalf("got EP stats Stats.ReadErrors.InvalidEndpointState got %v want %v", got, 1)
+ }
+
if err := ep.Listen(10); err != nil {
t.Fatalf("Listen failed: %v", err)
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 272481aa0..ef823e4ae 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -137,7 +137,10 @@ type Context struct {
// New allocates and initializes a test context containing a new
// stack and a link-layer endpoint.
func New(t *testing.T, mtu uint32) *Context {
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
// Allow minimum send/receive buffer sizes to be 1 during tests.
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil {
@@ -150,11 +153,19 @@ func New(t *testing.T, mtu uint32) *Context {
// Some of the congestion control tests send up to 640 packets, we so
// set the channel size to 1000.
- id, linkEP := channel.New(1000, mtu, "")
+ ep := channel.New(1000, mtu, "")
+ wep := stack.LinkEndpoint(ep)
+ if testing.Verbose() {
+ wep = sniffer.New(ep)
+ }
+ if err := s.CreateNamedNIC(1, "nic1", wep); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+ wep2 := stack.LinkEndpoint(channel.New(1000, mtu, ""))
if testing.Verbose() {
- id = sniffer.New(id)
+ wep2 = sniffer.New(channel.New(1000, mtu, ""))
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -180,7 +191,7 @@ func New(t *testing.T, mtu uint32) *Context {
return &Context{
t: t,
s: s,
- linkEP: linkEP,
+ linkEP: ep,
WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)),
}
}
@@ -267,7 +278,7 @@ func (c *Context) GetPacketNonBlocking() []byte {
// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) {
// Allocate a buffer data and headers.
- buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p1) + len(p2))
+ buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2))
if len(buf) > maxTotalSize {
buf = buf[:maxTotalSize]
}
@@ -286,9 +297,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
icmp := header.ICMPv4(buf[header.IPv4MinimumSize:])
icmp.SetType(typ)
icmp.SetCode(code)
-
- copy(icmp[header.ICMPv4PayloadOffset:], p1)
- copy(icmp[header.ICMPv4PayloadOffset+len(p1):], p2)
+ const icmpv4VariableHeaderOffset = 4
+ copy(icmp[icmpv4VariableHeaderOffset:], p1)
+ copy(icmp[header.ICMPv4PayloadOffset:], p2)
// Inject packet.
c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView())
@@ -511,7 +522,7 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
}
// CreateConnected creates a connected TCP endpoint.
-func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption) {
+func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) {
c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
}
@@ -584,12 +595,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
c.Port = tcpHdr.SourcePort()
}
-// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
-// the specified option bytes as the Option field in the initial SYN packet.
-//
-// It also sets the receive buffer for the endpoint to the specified
-// value in epRcvBuf.
-func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
+// Create creates a TCP endpoint.
+func (c *Context) Create(epRcvBuf int) {
// Create TCP endpoint.
var err *tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
@@ -597,11 +604,20 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.
c.t.Fatalf("NewEndpoint failed: %v", err)
}
- if epRcvBuf != nil {
- if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
+ if epRcvBuf != -1 {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil {
c.t.Fatalf("SetSockOpt failed failed: %v", err)
}
}
+}
+
+// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
+// the specified option bytes as the Option field in the initial SYN packet.
+//
+// It also sets the receive buffer for the endpoint to the specified
+// value in epRcvBuf.
+func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) {
+ c.Create(epRcvBuf)
c.Connect(iss, rcvWnd, options)
}
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
index 4bec48c0f..43fcc27f0 100644
--- a/pkg/tcpip/transport/tcpconntrack/BUILD
+++ b/pkg/tcpip/transport/tcpconntrack/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index ac2666f69..c9460aa0d 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -1,7 +1,8 @@
-package(licenses = ["notice"])
-
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
go_template_instance(
name = "udp_packet_list",
@@ -50,6 +51,7 @@ go_test(
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index ac5905772..6e87245b7 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,7 +15,6 @@
package udp
import (
- "math"
"sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -37,15 +36,35 @@ type udpPacket struct {
views [8]buffer.View `state:"nosave"`
}
-type endpointState int
+// EndpointState represents the state of a UDP endpoint.
+type EndpointState uint32
+// Endpoint states. Note that they are represented in a netstack-specific
+// manner and may not be meaningful externally. Specifically, they need to be
+// translated to Linux's representation for these states if presented to userspace.
const (
- stateInitial endpointState = iota
- stateBound
- stateConnected
- stateClosed
+ StateInitial EndpointState = iota
+ StateBound
+ StateConnected
+ StateClosed
)
+// String implements fmt.Stringer.String. It returns a fixed upper-case label
+// for each known state and "UNKNOWN" for any other value.
+func (s EndpointState) String() string {
+ switch s {
+ case StateInitial:
+ return "INITIAL"
+ case StateBound:
+ return "BOUND"
+ case StateConnected:
+ // StateConnected is the state set once Connect has completed, so the
+ // label is "CONNECTED" rather than "CONNECTING".
+ return "CONNECTED"
+ case StateClosed:
+ return "CLOSED"
+ default:
+ return "UNKNOWN"
+ }
+}
+
// endpoint represents a UDP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -55,10 +74,11 @@ const (
//
// +stateify savable
type endpoint struct {
+ stack.TransportEndpointInfo
+
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
- netProto tcpip.NetworkProtocolNumber
waiterQueue *waiter.Queue
// The following fields are used to manage the receive queue, and are
@@ -73,20 +93,23 @@ type endpoint struct {
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
sndBufSize int
- id stack.TransportEndpointID
- state endpointState
- bindNICID tcpip.NICID
- regNICID tcpip.NICID
+ state EndpointState
route stack.Route `state:"manual"`
dstPort uint16
v6only bool
+ ttl uint8
multicastTTL uint8
multicastAddr tcpip.Address
multicastNICID tcpip.NICID
multicastLoop bool
reusePort bool
+ bindToDevice tcpip.NICID
broadcast bool
+ // sendTOS represents IPv4 TOS or IPv6 TrafficClass,
+ // applied while sending packets. Defaults to 0 as on Linux.
+ sendTOS uint8
+
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -101,6 +124,9 @@ type endpoint struct {
// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
// address).
effectiveNetProtos []tcpip.NetworkProtocolNumber
+
+ // TODO(b/142022063): Add ability to save and restore per endpoint stats.
+ stats tcpip.TransportEndpointStats `state:"nosave"`
}
// +stateify savable
@@ -109,10 +135,13 @@ type multicastMembership struct {
multicastAddr tcpip.Address
}
-func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
return &endpoint{
- stack: stack,
- netProto: netProto,
+ stack: s,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ // Use this package's own transport protocol number (the same one
+ // used when registering the endpoint and sending packets) instead
+ // of the TCP protocol number.
+ TransProto: ProtocolNumber,
+ },
waiterQueue: waiterQueue,
// RFC 1075 section 5.4 recommends a TTL of 1 for membership
// requests.
@@ -130,6 +159,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite
multicastLoop: true,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
+ state: StateInitial,
}
}
@@ -140,13 +170,13 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
- case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ case StateBound, StateConnected:
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
}
for _, mem := range e.multicastMemberships {
- e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr)
+ e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr)
}
e.multicastMemberships = nil
@@ -163,7 +193,7 @@ func (e *endpoint) Close() {
e.route.Release()
// Update the state.
- e.state = stateClosed
+ e.state = StateClosed
e.mu.Unlock()
@@ -186,6 +216,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
if e.rcvList.Empty() {
err := tcpip.ErrWouldBlock
if e.rcvClosed {
+ e.stats.ReadErrors.ReadClosed.Increment()
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
@@ -211,11 +242,11 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
// Returns true for retry if preparation should be retried.
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
switch e.state {
- case stateInitial:
- case stateConnected:
+ case StateInitial:
+ case StateConnected:
return false, nil
- case stateBound:
+ case StateBound:
if to == nil {
return false, tcpip.ErrDestinationRequired
}
@@ -232,7 +263,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
- if e.state != stateInitial {
+ if e.state != StateInitial {
return true, nil
}
@@ -248,7 +279,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
- localAddr := e.id.LocalAddress
+ localAddr := e.ID.LocalAddress
if isBroadcastOrMulticast(localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
localAddr = ""
@@ -273,17 +304,35 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ n, ch, err := e.write(p, opts)
+ switch err {
+ case nil:
+ e.stats.PacketsSent.Increment()
+ case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ e.stats.WriteErrors.InvalidArgs.Increment()
+ case tcpip.ErrClosedForSend:
+ e.stats.WriteErrors.WriteClosed.Increment()
+ case tcpip.ErrInvalidEndpointState:
+ e.stats.WriteErrors.InvalidEndpointState.Increment()
+ case tcpip.ErrNoLinkAddress:
+ e.stats.SendErrors.NoLinkAddr.Increment()
+ case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ // Errors indicating any problem with IP routing of the packet.
+ e.stats.SendErrors.NoRoute.Increment()
+ default:
+ // For all other errors when writing to the network layer.
+ e.stats.SendErrors.SendToNetworkFailed.Increment()
+ }
+ return n, ch, err
+}
+
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
}
- if p.Size() > math.MaxUint16 {
- // Payload can't possibly fit in a packet.
- return 0, nil, tcpip.ErrMessageTooLong
- }
-
to := opts.To
e.mu.RLock()
@@ -322,7 +371,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
defer e.mu.Unlock()
// Recheck state after lock was re-acquired.
- if e.state != stateConnected {
+ if e.state != StateConnected {
return 0, nil, tcpip.ErrInvalidEndpointState
}
}
@@ -330,12 +379,12 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicid := to.NIC
- if e.bindNICID != 0 {
- if nicid != 0 && nicid != e.bindNICID {
+ if e.BindNICID != 0 {
+ if nicid != 0 && nicid != e.BindNICID {
return 0, nil, tcpip.ErrNoRoute
}
- nicid = e.bindNICID
+ nicid = e.BindNICID
}
if to.Addr == header.IPv4Broadcast && !e.broadcast {
@@ -366,17 +415,25 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
}
}
- v, err := p.Get(p.Size())
+ v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
+ if len(v) > header.UDPMaximumPacketSize {
+ // Payload can't possibly fit in a packet.
+ return 0, nil, tcpip.ErrMessageTooLong
+ }
+
+ ttl := e.ttl
+ useDefaultTTL := ttl == 0
- ttl := route.DefaultTTL()
if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
ttl = e.multicastTTL
+ // Multicast allows a 0 TTL.
+ useDefaultTTL = false
}
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil {
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS); err != nil {
return 0, nil, err
}
return int64(len(v)), nil, nil
@@ -387,12 +444,17 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
-// SetSockOpt sets a socket option. Currently not supported.
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
switch v := opt.(type) {
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
- if e.netProto != header.IPv6ProtocolNumber {
+ if e.NetProto != header.IPv6ProtocolNumber {
return tcpip.ErrInvalidEndpointState
}
@@ -400,12 +462,17 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
defer e.mu.Unlock()
// We only allow this to be set when we're in the initial state.
- if e.state != stateInitial {
+ if e.state != StateInitial {
return tcpip.ErrInvalidEndpointState
}
e.v6only = v != 0
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ e.ttl = uint8(v)
+ e.mu.Unlock()
+
case tcpip.MulticastTTLOption:
e.mu.Lock()
e.multicastTTL = uint8(v)
@@ -440,7 +507,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
}
- if e.bindNICID != 0 && e.bindNICID != nic {
+ if e.BindNICID != 0 && e.BindNICID != nic {
return tcpip.ErrInvalidEndpointState
}
@@ -467,7 +534,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
}
} else {
- nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
}
if nicID == 0 {
return tcpip.ErrUnknownDevice
@@ -484,7 +551,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
}
- if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
return err
}
@@ -505,7 +572,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
}
} else {
- nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
}
if nicID == 0 {
return tcpip.ErrUnknownDevice
@@ -527,7 +594,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrBadLocalAddress
}
- if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
return err
}
@@ -544,12 +611,39 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.reusePort = v != 0
e.mu.Unlock()
+ case tcpip.BindToDeviceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if v == "" {
+ e.bindToDevice = 0
+ return nil
+ }
+ for nicid, nic := range e.stack.NICInfo() {
+ if nic.Name == string(v) {
+ e.bindToDevice = nicid
+ return nil
+ }
+ }
+ return tcpip.ErrUnknownDevice
+
case tcpip.BroadcastOption:
e.mu.Lock()
e.broadcast = v != 0
e.mu.Unlock()
return nil
+
+ case tcpip.IPv4TOSOption:
+ e.mu.Lock()
+ e.sendTOS = uint8(v)
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.Lock()
+ e.sendTOS = uint8(v)
+ e.mu.Unlock()
+ return nil
}
return nil
}
@@ -566,7 +660,20 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
e.rcvMu.Unlock()
return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSize
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
+ return v, nil
}
+
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -576,21 +683,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- e.mu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
- e.rcvMu.Unlock()
- return nil
-
case *tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
- if e.netProto != header.IPv6ProtocolNumber {
+ if e.NetProto != header.IPv6ProtocolNumber {
return tcpip.ErrUnknownProtocolOption
}
@@ -604,6 +699,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.TTLOption:
+ e.mu.Lock()
+ *o = tcpip.TTLOption(e.ttl)
+ e.mu.Unlock()
+ return nil
+
case *tcpip.MulticastTTLOption:
e.mu.Lock()
*o = tcpip.MulticastTTLOption(e.multicastTTL)
@@ -638,6 +739,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.BindToDeviceOption:
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
+ *o = tcpip.BindToDeviceOption(nic.Name)
+ return nil
+ }
+ *o = tcpip.BindToDeviceOption("")
+ return nil
+
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
@@ -653,6 +764,18 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.IPv4TOSOption:
+ e.mu.RLock()
+ *o = tcpip.IPv4TOSOption(e.sendTOS)
+ e.mu.RUnlock()
+ return nil
+
+ case *tcpip.IPv6TrafficClassOption:
+ e.mu.RLock()
+ *o = tcpip.IPv6TrafficClassOption(e.sendTOS)
+ e.mu.RUnlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -660,7 +783,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error {
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8) *tcpip.Error {
// Allocate a buffer for the UDP header.
hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
@@ -683,14 +806,21 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
udp.SetChecksum(^udp.CalculateChecksum(xsum))
}
+ if useDefaultTTL {
+ ttl = r.DefaultTTL()
+ }
+ if err := r.WritePacket(nil /* gso */, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil {
+ r.Stats().UDP.PacketSendErrors.Increment()
+ return err
+ }
+
// Track count of packets sent.
r.Stats().UDP.PacketsSent.Increment()
-
- return r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl)
+ return nil
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.netProto
+ netProto := e.NetProto
if len(addr.Addr) == 0 {
return netProto, nil
}
@@ -707,14 +837,14 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
}
// Fail if we are bound to an IPv6 address.
- if !allowMismatch && len(e.id.LocalAddress) == 16 {
+ if !allowMismatch && len(e.ID.LocalAddress) == 16 {
return 0, tcpip.ErrNetworkUnreachable
}
}
// Fail if we're bound to an address length different from the one we're
// checking.
- if l := len(e.id.LocalAddress); l != 0 && l != len(addr.Addr) {
+ if l := len(e.ID.LocalAddress); l != 0 && l != len(addr.Addr) {
return 0, tcpip.ErrInvalidEndpointState
}
@@ -726,28 +856,32 @@ func (e *endpoint) Disconnect() *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if e.state != stateConnected {
+ if e.state != StateConnected {
return nil
}
id := stack.TransportEndpointID{}
// Exclude ephemerally bound endpoints.
- if e.bindNICID != 0 || e.id.LocalAddress == "" {
+ if e.BindNICID != 0 || e.ID.LocalAddress == "" {
var err *tcpip.Error
id = stack.TransportEndpointID{
- LocalPort: e.id.LocalPort,
- LocalAddress: e.id.LocalAddress,
+ LocalPort: e.ID.LocalPort,
+ LocalAddress: e.ID.LocalAddress,
}
- id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
+ id, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
if err != nil {
return err
}
- e.state = stateBound
+ e.state = StateBound
} else {
- e.state = stateInitial
+ if e.ID.LocalPort != 0 {
+ // Release the ephemeral port.
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
+ }
+ e.state = StateInitial
}
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
- e.id = id
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.ID = id
e.route.Release()
e.route = stack.Route{}
e.dstPort = 0
@@ -772,18 +906,18 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
nicid := addr.NIC
var localPort uint16
switch e.state {
- case stateInitial:
- case stateBound, stateConnected:
- localPort = e.id.LocalPort
- if e.bindNICID == 0 {
+ case StateInitial:
+ case StateBound, StateConnected:
+ localPort = e.ID.LocalPort
+ if e.BindNICID == 0 {
break
}
- if nicid != 0 && nicid != e.bindNICID {
+ if nicid != 0 && nicid != e.BindNICID {
return tcpip.ErrInvalidEndpointState
}
- nicid = e.bindNICID
+ nicid = e.BindNICID
default:
return tcpip.ErrInvalidEndpointState
}
@@ -795,13 +929,13 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
defer r.Release()
id := stack.TransportEndpointID{
- LocalAddress: e.id.LocalAddress,
+ LocalAddress: e.ID.LocalAddress,
LocalPort: localPort,
RemotePort: addr.Port,
RemoteAddress: r.RemoteAddress,
}
- if e.state == stateInitial {
+ if e.state == StateInitial {
id.LocalAddress = r.LocalAddress
}
@@ -822,17 +956,17 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// Remove the old registration.
- if e.id.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ if e.ID.LocalPort != 0 {
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
}
- e.id = id
+ e.ID = id
e.route = r.Clone()
e.dstPort = addr.Port
- e.regNICID = nicid
+ e.RegisterNICID = nicid
e.effectiveNetProtos = netProtos
- e.state = stateConnected
+ e.state = StateConnected
e.rcvMu.Lock()
e.rcvReady = true
@@ -854,7 +988,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// A socket in the bound state can still receive multicast messages,
// so we need to notify waiters on shutdown.
- if e.state != stateBound && e.state != stateConnected {
+ if e.state != StateBound && e.state != StateConnected {
return tcpip.ErrNotConnected
}
@@ -885,17 +1019,17 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
}
func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
- if e.id.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort)
+ if e.ID.LocalPort == 0 {
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice)
if err != nil {
return id, err
}
id.LocalPort = port
}
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice)
}
return id, err
}
@@ -903,7 +1037,7 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ
func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.state != stateInitial {
+ if e.state != StateInitial {
return tcpip.ErrInvalidEndpointState
}
@@ -941,12 +1075,12 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return err
}
- e.id = id
- e.regNICID = nicid
+ e.ID = id
+ e.RegisterNICID = nicid
e.effectiveNetProtos = netProtos
// Mark endpoint as bound.
- e.state = stateBound
+ e.state = StateBound
e.rcvMu.Lock()
e.rcvReady = true
@@ -967,7 +1101,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
// Save the effective NICID generated by bindLocked.
- e.bindNICID = e.regNICID
+ e.BindNICID = e.RegisterNICID
return nil
}
@@ -978,9 +1112,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
defer e.mu.RUnlock()
return tcpip.FullAddress{
- NIC: e.regNICID,
- Addr: e.id.LocalAddress,
- Port: e.id.LocalPort,
+ NIC: e.RegisterNICID,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
}, nil
}
@@ -989,14 +1123,14 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.state != stateConnected {
+ if e.state != StateConnected {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
return tcpip.FullAddress{
- NIC: e.regNICID,
- Addr: e.id.RemoteAddress,
- Port: e.id.RemotePort,
+ NIC: e.RegisterNICID,
+ Addr: e.ID.RemoteAddress,
+ Port: e.ID.RemotePort,
}, nil
}
@@ -1026,6 +1160,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
if int(hdr.Length()) > vv.Size() {
// Malformed packet.
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
+ e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
@@ -1033,11 +1168,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
e.rcvMu.Lock()
e.stack.Stats().UDP.PacketsReceived.Increment()
+ e.stats.PacketsReceived.Increment()
// Drop the packet if our buffer is currently full.
- if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
+ if !e.rcvReady || e.rcvClosed {
+ e.rcvMu.Unlock()
e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
+ e.stats.ReceiveErrors.ClosedReceiver.Increment()
+ return
+ }
+
+ if e.rcvBufSize >= e.rcvBufSizeMax {
e.rcvMu.Unlock()
+ e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
+ e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
return
}
@@ -1069,10 +1213,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
}
-// State implements socket.Socket.State.
+// State implements tcpip.Endpoint.State.
func (e *endpoint) State() uint32 {
- // TODO(b/112063468): Translate internal state to values returned by Linux.
- return 0
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return uint32(e.state)
+}
+
+// Info returns a copy of the endpoint info.
+func (e *endpoint) Info() tcpip.EndpointInfo {
+ e.mu.RLock()
+ // Make a copy of the endpoint info.
+ ret := e.TransportEndpointInfo
+ e.mu.RUnlock()
+ return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (e *endpoint) Stats() tcpip.EndpointStats {
+ return &e.stats
}
func isBroadcastOrMulticast(a tcpip.Address) bool {
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 5cbb56120..b227e353b 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -72,12 +72,12 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
for _, m := range e.multicastMemberships {
- if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
+ if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil {
panic(err)
}
}
- if e.state != stateBound && e.state != stateConnected {
+ if e.state != StateBound && e.state != StateConnected {
return
}
@@ -92,14 +92,14 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
var err *tcpip.Error
- if e.state == stateConnected {
- e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop)
+ if e.state == StateConnected {
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop)
if err != nil {
panic(err)
}
- } else if len(e.id.LocalAddress) != 0 && !isBroadcastOrMulticast(e.id.LocalAddress) { // stateBound
+ } else if len(e.ID.LocalAddress) != 0 && !isBroadcastOrMulticast(e.ID.LocalAddress) { // stateBound
// A local unicast address is specified, verify that it's valid.
- if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
+ if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 {
panic(tcpip.ErrBadLocalAddress)
}
}
@@ -107,9 +107,9 @@ func (e *endpoint) Resume(s *stack.Stack) {
// Our saved state had a port, but we don't actually have a
// reservation. We need to remove the port from our state, but still
// pass it to the reservation machinery.
- id := e.id
- e.id.LocalPort = 0
- e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
+ id := e.ID
+ e.ID.LocalPort = 0
+ e.ID, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
if err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index a874fc9fd..d399ec722 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -74,17 +74,17 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
ep := newEndpoint(r.stack, r.route.NetProto, queue)
- if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil {
+ if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil {
ep.Close()
return nil, err
}
- ep.id = r.id
+ ep.ID = r.id
ep.route = r.route.Clone()
ep.dstPort = r.id.RemotePort
- ep.regNICID = r.route.NICID()
+ ep.RegisterNICID = r.route.NICID()
- ep.state = stateConnected
+ ep.state = StateConnected
ep.rcvMu.Lock()
ep.rcvReady = true
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index f76e7fbe1..de026880f 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -14,7 +14,7 @@
// Package udp contains the implementation of the UDP transport protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing udp.ProtocolName (or "udp") as one of the
+// activated on the stack by passing udp.NewProtocol() as one of the
// transport protocols when calling stack.New(). Then endpoints can be created
// by passing udp.ProtocolNumber as the transport protocol number when calling
// Stack.NewEndpoint().
@@ -30,9 +30,6 @@ import (
)
const (
- // ProtocolName is the string representation of the udp protocol name.
- ProtocolName = "udp"
-
// ProtocolNumber is the udp protocol number.
ProtocolNumber = header.UDPProtocolNumber
)
@@ -69,7 +66,106 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
+func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
+ // Get the header then trim it from the view.
+ hdr := header.UDP(vv.First())
+ if int(hdr.Length()) > vv.Size() {
+ // Malformed packet.
+ r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
+ return true
+ }
+ // TODO(b/129426613): only send an ICMP message if UDP checksum is valid.
+
+ // Only send an ICMP error if the destination is not a multicast/broadcast
+ // v4/v6 address and the source is not the unspecified address.
+ //
+ // See: point e) in https://tools.ietf.org/html/rfc4443#section-2.4
+ if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) || id.RemoteAddress == header.IPv6Any || id.RemoteAddress == header.IPv4Any {
+ return true
+ }
+
+ // As per RFC 1122 Section 3.2.2.1, a host SHOULD generate Destination
+ // Unreachable messages with code:
+ //
+ // 2 (Protocol Unreachable), when the designated transport protocol
+ // is not supported; or
+ //
+ // 3 (Port Unreachable), when the designated transport protocol
+ // (e.g., UDP) is unable to demultiplex the datagram but has no
+ // protocol mechanism to inform the sender.
+ switch len(id.LocalAddress) {
+ case header.IPv4AddressSize:
+ if !r.Stack().AllowICMPMessage() {
+ r.Stack().Stats().ICMP.V4PacketsSent.RateLimited.Increment()
+ return true
+ }
+ // As per RFC 1812 Section 4.3.2.3
+ //
+ // ICMP datagram SHOULD contain as much of the original
+ // datagram as possible without the length of the ICMP
+ // datagram exceeding 576 bytes
+ //
+ // NOTE: The above RFC referenced is different from the original
+ // recommendation in RFC 1122 where it mentioned that at least 8
+ // bytes of the payload must be included. Today linux and other
+ // systems implement the RFC 1812 definition and not the original
+ // RFC 1122 requirement.
+ mtu := int(r.MTU())
+ if mtu > header.IPv4MinimumProcessableDatagramSize {
+ mtu = header.IPv4MinimumProcessableDatagramSize
+ }
+ headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
+ available := int(mtu) - headerLen
+ payloadLen := len(netHeader) + vv.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+
+ payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader})
+ payload.Append(vv)
+ payload.CapLength(payloadLen)
+
+ hdr := buffer.NewPrependable(headerLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4DstUnreachable)
+ pkt.SetCode(header.ICMPv4PortUnreachable)
+ pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
+ r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS})
+
+ case header.IPv6AddressSize:
+ if !r.Stack().AllowICMPMessage() {
+ r.Stack().Stats().ICMP.V6PacketsSent.RateLimited.Increment()
+ return true
+ }
+
+ // As per RFC 4443 section 2.4
+ //
+ // (c) Every ICMPv6 error message (type < 128) MUST include
+ // as much of the IPv6 offending (invoking) packet (the
+ // packet that caused the error) as possible without making
+ // the error message packet exceed the minimum IPv6 MTU
+ // [IPv6].
+ mtu := int(r.MTU())
+ if mtu > header.IPv6MinimumMTU {
+ mtu = header.IPv6MinimumMTU
+ }
+ headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize
+ available := int(mtu) - headerLen
+ payloadLen := len(netHeader) + vv.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+ payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader})
+ payload.Append(vv)
+ payload.CapLength(payloadLen)
+
+ hdr := buffer.NewPrependable(headerLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize))
+ pkt.SetType(header.ICMPv6DstUnreachable)
+ pkt.SetCode(header.ICMPv6PortUnreachable)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload))
+ r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS})
+ }
return true
}
@@ -83,8 +179,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
- return &protocol{}
- })
+// NewProtocol returns a UDP transport protocol.
+func NewProtocol() stack.TransportProtocol {
+ return &protocol{}
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 9da6edce2..b724d788c 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -17,7 +17,6 @@ package udp_test
import (
"bytes"
"fmt"
- "math"
"math/rand"
"testing"
"time"
@@ -27,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -274,13 +274,17 @@ type testContext struct {
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ ep := channel.New(256, mtu, "")
+ wep := stack.LinkEndpoint(ep)
- id, linkEP := channel.New(256, mtu, "")
if testing.Verbose() {
- id = sniffer.New(id)
+ wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, wep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -306,7 +310,7 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
return &testContext{
t: t,
s: s,
- linkEP: linkEP,
+ linkEP: ep,
}
}
@@ -380,15 +384,17 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) {
h := flow.header4Tuple(incoming)
if flow.isV4() {
- c.injectV4Packet(payload, &h)
+ c.injectV4Packet(payload, &h, true /* valid */)
} else {
- c.injectV6Packet(payload, &h)
+ c.injectV6Packet(payload, &h, true /* valid */)
}
}
// injectV6Packet creates a V6 test packet with the given payload and header
-// values, and injects it into the link endpoint.
-func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) {
+// values, and injects it into the link endpoint. valid indicates if the
+// caller intends to inject a packet with a valid or an invalid UDP header.
+// We can invalidate the header by corrupting the UDP payload length.
+func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -405,10 +411,16 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) {
// Initialize the UDP header.
u := header.UDP(buf[header.IPv6MinimumSize:])
+ l := uint16(header.UDPMinimumSize + len(payload))
+ if !valid {
+ // Change the UDP payload length to corrupt the header
+ // as requested by the caller.
+ l++
+ }
u.Encode(&header.UDPFields{
SrcPort: h.srcAddr.Port,
DstPort: h.dstAddr.Port,
- Length: uint16(header.UDPMinimumSize + len(payload)),
+ Length: l,
})
// Calculate the UDP pseudo-header checksum.
@@ -422,9 +434,11 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) {
c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
}
-// injectV6Packet creates a V4 test packet with the given payload and header
-// values, and injects it into the link endpoint.
-func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) {
+// injectV4Packet creates a V4 test packet with the given payload and header
+// values, and injects it into the link endpoint. valid indicates if the
+// caller intends to inject a packet with a valid or an invalid UDP header.
+// We can invalidate the header by corrupting the UDP payload length.
+func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -461,101 +475,78 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) {
}
func newPayload() []byte {
- b := make([]byte, 30+rand.Intn(100))
+ return newMinPayload(30)
+}
+
+func newMinPayload(minSize int) []byte {
+ b := make([]byte, minSize+rand.Intn(100))
for i := range b {
b[i] = byte(rand.Intn(256))
}
return b
}
-func TestBindPortReuse(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
+func TestBindToDeviceOption(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
- c.createEndpoint(ipv6.ProtocolNumber)
-
- var eps [5]tcpip.Endpoint
- reusePortOpt := tcpip.ReusePortOption(1)
-
- pollChannel := make(chan tcpip.Endpoint)
- for i := 0; i < len(eps); i++ {
- // Try to receive the data.
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
- defer close(ch)
-
- var err *tcpip.Error
- eps[i], err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- go func(ep tcpip.Endpoint) {
- for range ch {
- pollChannel <- ep
- }
- }(eps[i])
-
- defer eps[i].Close()
- if err := eps[i].SetSockOpt(reusePortOpt); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
- if err := eps[i].Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
- }
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
}
+ defer ep.Close()
- npackets := 100000
- nports := 10000
- ports := make(map[uint16]tcpip.Endpoint)
- stats := make(map[tcpip.Endpoint]int)
- for i := 0; i < npackets; i++ {
- // Send a packet.
- port := uint16(i % nports)
- payload := newPayload()
- c.injectV6Packet(payload, &header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port},
- dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
- })
+ if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
+ t.Errorf("CreateNamedNIC failed: %v", err)
+ }
- var addr tcpip.FullAddress
- ep := <-pollChannel
- _, _, err := ep.Read(&addr)
- if err != nil {
- c.t.Fatalf("Read failed: %v", err)
- }
- stats[ep]++
- if i < nports {
- ports[uint16(i)] = ep
- } else {
- // Check that all packets from one client are handled
- // by the same socket.
- if ports[port] != ep {
- t.Fatalf("Port mismatch")
- }
- }
+ // Make a nameless NIC.
+ if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ t.Errorf("CreateNIC failed: %v", err)
}
- if len(stats) != len(eps) {
- t.Fatalf("Only %d(expected %d) sockets received packets", len(stats), len(eps))
+ // strPtr is used instead of taking the address of string literals, which is
+ // a compiler error.
+ strPtr := func(s string) *string {
+ return &s
}
- // Check that a packet distribution is fair between sockets.
- for _, c := range stats {
- n := float64(npackets) / float64(len(eps))
- // The deviation is less than 10%.
- if math.Abs(float64(c)-n) > n/10 {
- t.Fatal(c, n)
- }
+ testActions := []struct {
+ name string
+ setBindToDevice *string
+ setBindToDeviceError *tcpip.Error
+ getBindToDevice tcpip.BindToDeviceOption
+ }{
+ {"GetDefaultValue", nil, nil, ""},
+ {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
+ {"BindToExistent", strPtr("my_device"), nil, "my_device"},
+ {"UnbindToDevice", strPtr(""), nil, ""},
+ }
+ for _, testAction := range testActions {
+ t.Run(testAction.name, func(t *testing.T) {
+ if testAction.setBindToDevice != nil {
+ bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
+ if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ }
+ }
+ bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
+ if ep.GetSockOpt(&bindToDevice) != nil {
+ t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ }
+ if got, want := bindToDevice, testAction.getBindToDevice; got != want {
+ t.Errorf("bindToDevice got %q, want %q", got, want)
+ }
+ })
}
}
-// testRead sends a packet of the given test flow into the stack by injecting it
-// into the link endpoint. It then reads it from the UDP endpoint and verifies
-// its correctness.
-func testRead(c *testContext, flow testFlow) {
+// testReadInternal sends a packet of the given test flow into the stack by
+// injecting it into the link endpoint. It then attempts to read it from the
+// UDP endpoint and depending on if this was expected to succeed verifies its
+// correctness.
+func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) {
c.t.Helper()
payload := newPayload()
@@ -566,6 +557,9 @@ func testRead(c *testContext, flow testFlow) {
c.wq.EventRegister(&we, waiter.EventIn)
defer c.wq.EventUnregister(&we)
+ // Take a snapshot of the stats to validate them at the end of the test.
+ epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+
var addr tcpip.FullAddress
v, _, err := c.ep.Read(&addr)
if err == tcpip.ErrWouldBlock {
@@ -573,25 +567,55 @@ func testRead(c *testContext, flow testFlow) {
select {
case <-ch:
v, _, err = c.ep.Read(&addr)
- if err != nil {
- c.t.Fatalf("Read failed: %v", err)
- }
- case <-time.After(1 * time.Second):
- c.t.Fatalf("Timed out waiting for data")
+ case <-time.After(300 * time.Millisecond):
+ if packetShouldBeDropped {
+ return // expected to time out
+ }
+ c.t.Fatal("timed out waiting for data")
}
}
+ if expectReadError && err != nil {
+ c.checkEndpointReadStats(1, epstats, err)
+ return
+ }
+
+ if err != nil {
+ c.t.Fatal("Read failed:", err)
+ }
+
+ if packetShouldBeDropped {
+ c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr)
+ }
+
// Check the peer address.
h := flow.header4Tuple(incoming)
if addr.Addr != h.srcAddr.Addr {
- c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, h.srcAddr)
+ c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr)
}
// Check the payload.
if !bytes.Equal(payload, v) {
- c.t.Fatalf("Bad payload: got %x, want %x", v, payload)
+ c.t.Fatalf("bad payload: got %x, want %x", v, payload)
}
+ c.checkEndpointReadStats(1, epstats, err)
+}
+
+// testRead sends a packet of the given test flow into the stack by injecting it
+// into the link endpoint. It then reads it from the UDP endpoint and verifies
+// its correctness.
+func testRead(c *testContext, flow testFlow) {
+ c.t.Helper()
+ testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */)
+}
+
+// testFailingRead sends a packet of the given test flow into the stack by
+// injecting it into the link endpoint. It then tries to read it from the UDP
+// endpoint and expects this to fail.
+func testFailingRead(c *testContext, flow testFlow, expectReadError bool) {
+ c.t.Helper()
+ testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError)
}
func TestBindEphemeralPort(t *testing.T) {
@@ -763,13 +787,17 @@ func TestReadOnBoundToMulticast(t *testing.T) {
c.t.Fatal("SetSockOpt failed:", err)
}
+ // Check that we receive multicast packets but not unicast or broadcast
+ // ones.
testRead(c, flow)
+ testFailingRead(c, broadcast, false /* expectReadError */)
+ testFailingRead(c, unicastV4, false /* expectReadError */)
})
}
}
// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast
-// address and receive broadcast data on it.
+// address and can receive only broadcast data.
func TestV4ReadOnBoundToBroadcast(t *testing.T) {
for _, flow := range []testFlow{broadcast, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -784,8 +812,31 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) {
c.t.Fatalf("Bind failed: %s", err)
}
- // Test acceptance.
+ // Check that we receive broadcast packets but not unicast ones.
testRead(c, flow)
+ testFailingRead(c, unicastV4, false /* expectReadError */)
+ })
+ }
+}
+
+// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
+// and receive broadcast and unicast data.
+func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
+ for _, flow := range []testFlow{broadcast, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s (", err)
+ }
+
+ // Check that we receive both broadcast and unicast packets.
+ testRead(c, flow)
+ testRead(c, unicastV4)
})
}
}
@@ -794,7 +845,8 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) {
// and verifies it fails with the provided error code.
func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
c.t.Helper()
-
+ // Take a snapshot of the stats to validate them at the end of the test.
+ epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
h := flow.header4Tuple(outgoing)
writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
@@ -802,6 +854,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
_, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
})
+ c.checkEndpointWriteStats(1, epstats, gotErr)
if gotErr != wantErr {
c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr)
}
@@ -827,6 +880,8 @@ func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...chec
func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
c.t.Helper()
+ // Take a snapshot of the stats to validate them at the end of the test.
+ epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
writeOpts := tcpip.WriteOptions{}
if setDest {
@@ -844,7 +899,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...
if n != int64(len(payload)) {
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
}
-
+ c.checkEndpointWriteStats(1, epstats, err)
// Received the packet and check the payload.
b := c.getPacketAndVerify(flow, checkers...)
var udp header.UDP
@@ -913,6 +968,10 @@ func TestDualWriteConnectedToV6(t *testing.T) {
// Write to V4 mapped address.
testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable)
+ const want = 1
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want {
+ c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want)
+ }
}
func TestDualWriteConnectedToV4Mapped(t *testing.T) {
@@ -1175,6 +1234,109 @@ func TestTTL(t *testing.T) {
}
}
+func TestSetTTL(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
+ t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+
+ var p stack.NetworkProtocol
+ if flow.isV4() {
+ p = ipv4.NewProtocol()
+ } else {
+ p = ipv6.NewProtocol()
+ }
+ ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ep.Close()
+
+ testWrite(c, flow, checker.TTL(wantTTL))
+ })
+ }
+ })
+ }
+}
+
+func TestTOSV4(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ const tos = 0xC0
+ var v tcpip.IPv4TOSOption
+ if err := c.ep.GetSockOpt(&v); err != nil {
+ c.t.Errorf("GetSockopt failed: %s", err)
+ }
+ // Test for expected default value.
+ if v != 0 {
+ c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0)
+ }
+
+ if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
+ c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err)
+ }
+
+ if err := c.ep.GetSockOpt(&v); err != nil {
+ c.t.Errorf("GetSockopt failed: %s", err)
+ }
+
+ if want := tcpip.IPv4TOSOption(tos); v != want {
+ c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ }
+
+ testWrite(c, flow, checker.TOS(tos, 0))
+ })
+ }
+}
+
+func TestTOSV6(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ const tos = 0xC0
+ var v tcpip.IPv6TrafficClassOption
+ if err := c.ep.GetSockOpt(&v); err != nil {
+ c.t.Errorf("GetSockopt failed: %s", err)
+ }
+ // Test for expected default value.
+ if v != 0 {
+ c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0)
+ }
+
+ if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil {
+ c.t.Errorf("SetSockOpt failed: %s", err)
+ }
+
+ if err := c.ep.GetSockOpt(&v); err != nil {
+ c.t.Errorf("GetSockopt failed: %s", err)
+ }
+
+ if want := tcpip.IPv6TrafficClassOption(tos); v != want {
+ c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ }
+
+ testWrite(c, flow, checker.TOS(tos, 0))
+ })
+ }
+}
+
func TestMulticastInterfaceOption(t *testing.T) {
for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -1238,3 +1400,267 @@ func TestMulticastInterfaceOption(t *testing.T) {
})
}
}
+
+// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination
+// Unreachable message when a udp datagram is received on ports for which there
+// is no bound udp socket.
+//
+// No endpoint is created in this test; packets are injected directly through
+// the test context. A single context is shared by all sub-tests.
+func TestV4UnknownDestination(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ testCases := []struct {
+ flow testFlow
+ icmpRequired bool
+ // largePayload if true, will result in a payload large enough
+ // so that the final generated IPv4 packet is larger than
+ // header.IPv4MinimumProcessableDatagramSize.
+ largePayload bool
+ }{
+ {unicastV4, true, false},
+ {unicastV4, true, true},
+ {multicastV4, false, false},
+ {multicastV4, false, true},
+ {broadcast, false, false},
+ {broadcast, false, true},
+ }
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
+ payload := newPayload()
+ if tc.largePayload {
+ payload = newMinPayload(576)
+ }
+ c.injectPacket(tc.flow, payload)
+ // When no ICMP reply is expected, the sub-test passes only if nothing
+ // shows up on the link endpoint within one second.
+ if !tc.icmpRequired {
+ select {
+ case p := <-c.linkEP.C:
+ t.Fatalf("unexpected packet received: %+v", p)
+ case <-time.After(1 * time.Second):
+ return
+ }
+ }
+
+ // An ICMP reply is expected: wait up to one second for it.
+ select {
+ case p := <-c.linkEP.C:
+ // Join the header and payload parts into one contiguous buffer.
+ var pkt []byte
+ pkt = append(pkt, p.Header...)
+ pkt = append(pkt, p.Payload...)
+ if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
+
+ hdr := header.IPv4(pkt)
+ checker.IPv4(t, hdr, checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+
+ // The ICMP payload is interpreted as the original IPv4 packet.
+ icmpPkt := header.ICMPv4(hdr.Payload())
+ payloadIPHeader := header.IPv4(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ // Space left for the original UDP payload once two IPv4 headers, the
+ // ICMPv4 header and the UDP header are subtracted from the size cap.
+ wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
+ }
+
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %d, want: %d", got, want)
+ }
+ case <-time.After(1 * time.Second):
+ t.Fatalf("packet wasn't written out")
+ }
+ })
+ }
+}
+
+// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination
+// Unreachable message when a udp datagram is received on ports for which there
+// is no bound udp socket.
+//
+// No endpoint is created in this test; packets are injected directly through
+// the test context. A single context is shared by all sub-tests.
+func TestV6UnknownDestination(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ testCases := []struct {
+ flow testFlow
+ icmpRequired bool
+ // largePayload if true will result in a payload large enough to
+ // create an IPv6 packet > header.IPv6MinimumMTU bytes.
+ largePayload bool
+ }{
+ {unicastV6, true, false},
+ {unicastV6, true, true},
+ {multicastV6, false, false},
+ {multicastV6, false, true},
+ }
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
+ payload := newPayload()
+ if tc.largePayload {
+ payload = newMinPayload(1280)
+ }
+ c.injectPacket(tc.flow, payload)
+ // When no ICMP reply is expected, the sub-test passes only if nothing
+ // shows up on the link endpoint within one second.
+ if !tc.icmpRequired {
+ select {
+ case p := <-c.linkEP.C:
+ t.Fatalf("unexpected packet received: %+v", p)
+ case <-time.After(1 * time.Second):
+ return
+ }
+ }
+
+ // An ICMP reply is expected: wait up to one second for it.
+ select {
+ case p := <-c.linkEP.C:
+ // Join the header and payload parts into one contiguous buffer.
+ var pkt []byte
+ pkt = append(pkt, p.Header...)
+ pkt = append(pkt, p.Payload...)
+ if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
+
+ hdr := header.IPv6(pkt)
+ checker.IPv6(t, hdr, checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
+
+ // The ICMP payload is interpreted as the original IPv6 packet.
+ icmpPkt := header.ICMPv6(hdr.Payload())
+ payloadIPHeader := header.IPv6(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ // Space left for the original UDP payload once two IPv6 headers, the
+ // ICMPv6 header and the UDP header are subtracted from the size cap.
+ wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
+ }
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
+
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %v, want: %v", got, want)
+ }
+ case <-time.After(1 * time.Second):
+ t.Fatalf("packet wasn't written out")
+ }
+ })
+ }
+}
+
+// TestIncrementMalformedPacketsReceived verifies that both the stack-wide and
+// the per-endpoint MalformedPacketsReceived counters equal one after a single
+// invalid IPv6 packet is injected at a wildcard-bound endpoint.
+func TestIncrementMalformedPacketsReceived(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+
+ payload := newPayload()
+ c.t.Helper()
+ h := unicastV6.header4Tuple(incoming)
+ // The final argument is the "valid" flag; false requests an invalid packet.
+ c.injectV6Packet(payload, &h, false /* !valid */)
+
+ // Exactly one packet was injected, so each counter must read one.
+ var want uint64 = 1
+ if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
+ }
+}
+
+// TestShutdownRead verifies endpoint read shutdown and error
+// stats increment on packet receive.
+//
+// The endpoint is bound, connected and then shut down for reading. After
+// testFailingRead runs, the stack's UDP ReceiveBufferErrors counter and the
+// endpoint's ReceiveErrors.ClosedReceiver counter must both equal one.
+func TestShutdownRead(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+ c.t.Fatalf("Connect failed: %v", err)
+ }
+
+ if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatalf("Shutdown failed: %v", err)
+ }
+
+ // NOTE(review): testFailingRead presumably injects one packet and expects
+ // the read to fail - confirm in its definition.
+ testFailingRead(c, unicastV6, true /* expectReadError */)
+
+ var want uint64 = 1
+ if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want)
+ }
+}
+
+// TestShutdownWrite verifies endpoint write shutdown and error
+// stats increment on packet write.
+//
+// The endpoint is connected and then shut down for writing; testFailingWrite
+// is called with tcpip.ErrClosedForSend as the expected error.
+func TestShutdownWrite(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+ c.t.Fatalf("Connect failed: %v", err)
+ }
+
+ if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Shutdown failed: %v", err)
+ }
+
+ // NOTE(review): the stats check named in the doc comment is not visible
+ // here; presumably testFailingWrite performs it - confirm.
+ testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend)
+}
+
+// checkEndpointWriteStats compares the endpoint's current stats against want
+// after bumping, by incr, the single counter in want that corresponds to err.
+// want is passed by value, so the caller's copy is not modified here. A
+// mismatch is reported through c.t.Errorf.
+func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
+ // Take a copy (Clone) of the endpoint's current stats for the comparison.
+ got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+ // Map the write result onto the counter it is expected to have incremented.
+ switch err {
+ case nil:
+ want.PacketsSent.IncrementBy(incr)
+ case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ want.WriteErrors.InvalidArgs.IncrementBy(incr)
+ case tcpip.ErrClosedForSend:
+ want.WriteErrors.WriteClosed.IncrementBy(incr)
+ case tcpip.ErrInvalidEndpointState:
+ want.WriteErrors.InvalidEndpointState.IncrementBy(incr)
+ case tcpip.ErrNoLinkAddress:
+ want.SendErrors.NoLinkAddr.IncrementBy(incr)
+ case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ want.SendErrors.NoRoute.IncrementBy(incr)
+ default:
+ // Any other error is counted as a failure to send to the network.
+ want.SendErrors.SendToNetworkFailed.IncrementBy(incr)
+ }
+ if got != want {
+ c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
+ }
+}
+
+// checkEndpointReadStats compares the endpoint's current stats against want
+// after adjusting want for the given read result. Only
+// tcpip.ErrClosedForReceive bumps a counter (ReadErrors.ReadClosed, by incr);
+// nil and tcpip.ErrWouldBlock leave want untouched, and any other error is
+// reported as having no stats mapping. A mismatch is reported through
+// c.t.Errorf.
+func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
+ // Take a copy (Clone) of the endpoint's current stats for the comparison.
+ got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+ switch err {
+ case nil, tcpip.ErrWouldBlock:
+ // No counter is expected to change for these results.
+ case tcpip.ErrClosedForReceive:
+ want.ReadErrors.ReadClosed.IncrementBy(incr)
+ default:
+ c.t.Errorf("Endpoint error missing stats update err %v", err)
+ }
+ if got != want {
+ c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
+ }
+}
diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD
index 98d51cc69..6afdb29b7 100644
--- a/pkg/tmutex/BUILD
+++ b/pkg/tmutex/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD
index cbd92fc05..8f6f180e5 100644
--- a/pkg/unet/BUILD
+++ b/pkg/unet/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD
index b7f505a84..b6bbb0ea2 100644
--- a/pkg/urpc/BUILD
+++ b/pkg/urpc/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD
index 9173dfd0f..1f7efb064 100644
--- a/pkg/waiter/BUILD
+++ b/pkg/waiter/BUILD
@@ -1,7 +1,8 @@
-package(licenses = ["notice"])
-
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
go_template_instance(
name = "waiter_list",
diff --git a/runsc/BUILD b/runsc/BUILD
index 6b8c92706..e4e8e64a3 100644
--- a/runsc/BUILD
+++ b/runsc/BUILD
@@ -1,7 +1,7 @@
package(licenses = ["notice"]) # Apache 2.0
load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-load("@bazel_tools//tools/build_defs/pkg:pkg.bzl", "pkg_deb", "pkg_tar")
+load("@rules_pkg//:pkg.bzl", "pkg_deb", "pkg_tar")
go_binary(
name = "runsc",
@@ -13,9 +13,10 @@ go_binary(
visibility = [
"//visibility:public",
],
- x_defs = {"main.version": "{VERSION}"},
+ x_defs = {"main.version": "{STABLE_VERSION}"},
deps = [
"//pkg/log",
+ "//pkg/refs",
"//pkg/sentry/platform",
"//runsc/boot",
"//runsc/cmd",
@@ -45,9 +46,10 @@ go_binary(
visibility = [
"//visibility:public",
],
- x_defs = {"main.version": "{VERSION}"},
+ x_defs = {"main.version": "{STABLE_VERSION}"},
deps = [
"//pkg/log",
+ "//pkg/refs",
"//pkg/sentry/platform",
"//runsc/boot",
"//runsc/cmd",
@@ -65,19 +67,10 @@ pkg_tar(
)
pkg_tar(
- name = "runsc-tools",
- srcs = ["//runsc/tools/dockercfg"],
- mode = "0755",
- package_dir = "/usr/libexec/runsc",
- strip_prefix = "/runsc/tools/dockercfg/linux_amd64_stripped",
-)
-
-pkg_tar(
name = "debian-data",
extension = "tar.gz",
deps = [
":runsc-bin",
- ":runsc-tools",
],
)
@@ -98,13 +91,15 @@ pkg_deb(
maintainer = "The gVisor Authors <gvisor-dev@googlegroups.com>",
package = "runsc",
postinst = "debian/postinst.sh",
- tags = [
- # TODO(b/135475885): pkg_deb requires python2:
- # https://github.com/bazelbuild/bazel/issues/8443
- "manual",
- ],
version_file = ":version.txt",
visibility = [
"//visibility:public",
],
)
+
+sh_test(
+ name = "version_test",
+ size = "small",
+ srcs = ["version_test.sh"],
+ data = [":runsc"],
+)
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index 588bb8851..6fe2b57de 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -57,10 +57,10 @@ go_library(
"//pkg/sentry/pgalloc",
"//pkg/sentry/platform",
"//pkg/sentry/sighandling",
- "//pkg/sentry/socket/epsocket",
"//pkg/sentry/socket/hostinet",
"//pkg/sentry/socket/netlink",
"//pkg/sentry/socket/netlink/route",
+ "//pkg/sentry/socket/netstack",
"//pkg/sentry/socket/unix",
"//pkg/sentry/state",
"//pkg/sentry/strace",
@@ -80,6 +80,7 @@ go_library(
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/urpc",
@@ -109,6 +110,7 @@ go_test(
"//pkg/sentry/arch:registers_go_proto",
"//pkg/sentry/context/contexttest",
"//pkg/sentry/fs",
+ "//pkg/sentry/kernel/auth",
"//pkg/unet",
"//runsc/fsgofer",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
diff --git a/runsc/boot/config.go b/runsc/boot/config.go
index 7ae0dd05d..38278d0a2 100644
--- a/runsc/boot/config.go
+++ b/runsc/boot/config.go
@@ -19,6 +19,7 @@ import (
"strconv"
"strings"
+ "gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/watchdog"
)
@@ -112,6 +113,34 @@ func MakeWatchdogAction(s string) (watchdog.Action, error) {
}
}
+// MakeRefsLeakMode converts a string to the matching refs.LeakMode. The
+// comparison is case-insensitive. Accepted values are "disabled", "log-names"
+// and "log-traces"; any other value yields an error.
+func MakeRefsLeakMode(s string) (refs.LeakMode, error) {
+ switch strings.ToLower(s) {
+ case "disabled":
+ return refs.NoLeakChecking, nil
+ case "log-names":
+ return refs.LeaksLogWarning, nil
+ case "log-traces":
+ return refs.LeaksLogTraces, nil
+ default:
+ return 0, fmt.Errorf("invalid refs leakmode %q", s)
+ }
+}
+
+// refsLeakModeToString returns the string form of mode, using the same
+// strings that MakeRefsLeakMode accepts. It panics if mode is not one of the
+// values handled below.
+func refsLeakModeToString(mode refs.LeakMode) string {
+ switch mode {
+ // If not set, default it to disabled.
+ case refs.UninitializedLeakChecking, refs.NoLeakChecking:
+ return "disabled"
+ case refs.LeaksLogWarning:
+ return "log-names"
+ case refs.LeaksLogTraces:
+ return "log-traces"
+ default:
+ panic(fmt.Sprintf("Invalid leakmode: %d", mode))
+ }
+}
+
// Config holds configuration that is not part of the runtime spec.
type Config struct {
// RootDir is the runtime root directory.
@@ -138,6 +167,9 @@ type Config struct {
// Overlay is whether to wrap the root filesystem in an overlay.
Overlay bool
+ // FSGoferHostUDS enables the gofer to mount a host UDS.
+ FSGoferHostUDS bool
+
// Network indicates what type of network to use.
Network NetworkType
@@ -182,12 +214,6 @@ type Config struct {
// RestoreFile is the path to the saved container image
RestoreFile string
- // TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in
- // tests. It allows runsc to start the sandbox process as the current
- // user, and without chrooting the sandbox process. This can be
- // necessary in test environments that have limited capabilities.
- TestOnlyAllowRunAsCurrentUserWithoutChroot bool
-
// NumNetworkChannels controls the number of AF_PACKET sockets that map
// to the same underlying network device. This allows netstack to better
// scale for high throughput use cases.
@@ -201,6 +227,22 @@ type Config struct {
// AlsoLogToStderr allows to send log messages to stderr.
AlsoLogToStderr bool
+
+ // ReferenceLeakMode sets reference leak check mode
+ ReferenceLeakMode refs.LeakMode
+
+ // TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in
+ // tests. It allows runsc to start the sandbox process as the current
+ // user, and without chrooting the sandbox process. This can be
+ // necessary in test environments that have limited capabilities.
+ TestOnlyAllowRunAsCurrentUserWithoutChroot bool
+
+ // TestOnlyTestNameEnv should only be used in tests. It looks up the
+ // test name in the container environment variables and adds it to the debug
+ // log file name. This is done to help identify the log with the test when
+ // multiple tests are run in parallel, since there is no way to pass
+ // parameters to the runtime from docker.
+ TestOnlyTestNameEnv string
}
// ToFlags returns a slice of flags that correspond to the given Config.
@@ -214,6 +256,7 @@ func (c *Config) ToFlags() []string {
"--debug-log-format=" + c.DebugLogFormat,
"--file-access=" + c.FileAccess.String(),
"--overlay=" + strconv.FormatBool(c.Overlay),
+ "--fsgofer-host-uds=" + strconv.FormatBool(c.FSGoferHostUDS),
"--network=" + c.Network.String(),
"--log-packets=" + strconv.FormatBool(c.LogPackets),
"--platform=" + c.Platform,
@@ -227,10 +270,14 @@ func (c *Config) ToFlags() []string {
"--num-network-channels=" + strconv.Itoa(c.NumNetworkChannels),
"--rootless=" + strconv.FormatBool(c.Rootless),
"--alsologtostderr=" + strconv.FormatBool(c.AlsoLogToStderr),
+ "--ref-leak-mode=" + refsLeakModeToString(c.ReferenceLeakMode),
}
+ // Only include these if set since it is never to be used by users.
if c.TestOnlyAllowRunAsCurrentUserWithoutChroot {
- // Only include if set since it is never to be used by users.
- f = append(f, "-TESTONLY-unsafe-nonroot=true")
+ f = append(f, "--TESTONLY-unsafe-nonroot=true")
+ }
+ if len(c.TestOnlyTestNameEnv) != 0 {
+ f = append(f, "--TESTONLY-test-name-env="+c.TestOnlyTestNameEnv)
}
return f
}
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
index 72cbabd16..a73c593ea 100644
--- a/runsc/boot/controller.go
+++ b/runsc/boot/controller.go
@@ -18,7 +18,6 @@ import (
"errors"
"fmt"
"os"
- "path"
"syscall"
specs "github.com/opencontainers/runtime-spec/specs-go"
@@ -27,7 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/socket/epsocket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
"gvisor.dev/gvisor/pkg/sentry/state"
"gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sentry/watchdog"
@@ -142,7 +141,7 @@ func newController(fd int, l *Loader) (*controller, error) {
}
srv.Register(manager)
- if eps, ok := l.k.NetworkStack().(*epsocket.Stack); ok {
+ if eps, ok := l.k.NetworkStack().(*netstack.Stack); ok {
net := &Network{
Stack: eps.Stack,
}
@@ -234,13 +233,6 @@ func (cm *containerManager) Start(args *StartArgs, _ *struct{}) error {
if args.CID == "" {
return errors.New("start argument missing container ID")
}
- // Prevent CIDs containing ".." from confusing the sentry when creating
- // /containers/<cid> directory.
- // TODO(b/129293409): Once we have multiple independent roots, this
- // check won't be necessary.
- if path.Clean(args.CID) != args.CID {
- return fmt.Errorf("container ID shouldn't contain directory traversals such as \"..\": %q", args.CID)
- }
if len(args.FilePayload.Files) < 4 {
return fmt.Errorf("start arguments must contain stdin, stderr, and stdout followed by at least one file for the container root gofer")
}
@@ -355,7 +347,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
fs.SetRestoreEnvironment(*renv)
// Prepare to load from the state file.
- if eps, ok := networkStack.(*epsocket.Stack); ok {
+ if eps, ok := networkStack.(*netstack.Stack); ok {
stack.StackFromEnv = eps.Stack // FIXME(b/36201077)
}
info, err := specFile.Stat()
diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go
index 7ca776b3a..a2ecc6bcb 100644
--- a/runsc/boot/filter/config.go
+++ b/runsc/boot/filter/config.go
@@ -88,14 +88,24 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.AllowValue(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG),
seccomp.AllowAny{},
seccomp.AllowAny{},
- seccomp.AllowValue(0),
},
{
seccomp.AllowAny{},
seccomp.AllowValue(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG),
seccomp.AllowAny{},
+ },
+ // Non-private variants are included for flipcall support. They are otherwise
+ // unnecessary, as the sentry will use only private futexes internally.
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAIT),
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAKE),
seccomp.AllowAny{},
- seccomp.AllowValue(0),
},
},
syscall.SYS_GETPID: {},
diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go
index b6eeacf98..393c2a88b 100644
--- a/runsc/boot/fs.go
+++ b/runsc/boot/fs.go
@@ -25,19 +25,21 @@ import (
// Include filesystem types that OCI spec might mount.
_ "gvisor.dev/gvisor/pkg/sentry/fs/dev"
- "gvisor.dev/gvisor/pkg/sentry/fs/gofer"
_ "gvisor.dev/gvisor/pkg/sentry/fs/host"
_ "gvisor.dev/gvisor/pkg/sentry/fs/proc"
- "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
_ "gvisor.dev/gvisor/pkg/sentry/fs/sys"
_ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
_ "gvisor.dev/gvisor/pkg/sentry/fs/tty"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/gofer"
+ "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/runsc/specutils"
)
@@ -62,6 +64,9 @@ const (
nonefs = "none"
)
+// tmpfs has some extra supported options that we must pass through.
+var tmpfsAllowedOptions = []string{"mode", "uid", "gid"}
+
func addOverlay(ctx context.Context, conf *Config, lower *fs.Inode, name string, lowerFlags fs.MountSourceFlags) (*fs.Inode, error) {
// Upper layer uses the same flags as lower, but it must be read-write.
upperFlags := lowerFlags
@@ -170,27 +175,25 @@ func p9MountOptions(fd int, fa FileAccessType) []string {
func parseAndFilterOptions(opts []string, allowedKeys ...string) ([]string, error) {
var out []string
for _, o := range opts {
- kv := strings.Split(o, "=")
- switch len(kv) {
- case 1:
- if specutils.ContainsStr(allowedKeys, o) {
- out = append(out, o)
- continue
- }
- log.Warningf("ignoring unsupported key %q", kv)
- case 2:
- if specutils.ContainsStr(allowedKeys, kv[0]) {
- out = append(out, o)
- continue
- }
- log.Warningf("ignoring unsupported key %q", kv[0])
- default:
- return nil, fmt.Errorf("invalid option %q", o)
+ ok, err := parseMountOption(o, allowedKeys...)
+ if err != nil {
+ return nil, err
+ }
+ if ok {
+ out = append(out, o)
}
}
return out, nil
}
+// parseMountOption reports whether the key of opt is one of allowedKeys. The
+// key is the text before the first "=", or all of opt when there is no "=".
+// An error is returned if opt contains more than one "=".
+func parseMountOption(opt string, allowedKeys ...string) (bool, error) {
+ // Split into at most three pieces: a third piece means a second "=" exists.
+ kv := strings.SplitN(opt, "=", 3)
+ if len(kv) > 2 {
+ return false, fmt.Errorf("invalid option %q", opt)
+ }
+ return specutils.ContainsStr(allowedKeys, kv[0]), nil
+}
+
// mountDevice returns a device string based on the fs type and target
// of the mount.
func mountDevice(m specs.Mount) string {
@@ -205,6 +208,8 @@ func mountDevice(m specs.Mount) string {
func mountFlags(opts []string) fs.MountSourceFlags {
mf := fs.MountSourceFlags{}
+ // Note: changes to supported options must be reflected in
+ // isSupportedMountFlag() as well.
for _, o := range opts {
switch o {
case "rw":
@@ -222,6 +227,18 @@ func mountFlags(opts []string) fs.MountSourceFlags {
return mf
}
+// isSupportedMountFlag reports whether opt is a recognized mount option for a
+// mount of type fstype. "rw", "ro", "noatime" and "noexec" are accepted for
+// every type. For tmpfs, an option whose key is listed in tmpfsAllowedOptions
+// is accepted as well. Options that fail to parse are treated as unsupported.
+func isSupportedMountFlag(fstype, opt string) bool {
+ switch opt {
+ case "rw", "ro", "noatime", "noexec":
+ return true
+ }
+ if fstype == tmpfs {
+ ok, err := parseMountOption(opt, tmpfsAllowedOptions...)
+ return ok && err == nil
+ }
+ return false
+}
+
func mustFindFilesystem(name string) fs.Filesystem {
fs, ok := fs.FindFilesystem(name)
if !ok {
@@ -261,6 +278,18 @@ func subtargets(root string, mnts []specs.Mount) []string {
return targets
}
+// setupContainerFS sets up the container's file system through mntr.setupFS,
+// stores the resulting mount namespace in procArgs.MountNamespace, and then
+// resolves the executable path with setExecutablePath. It returns the first
+// error encountered.
+func setupContainerFS(ctx context.Context, conf *Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error {
+ mns, err := mntr.setupFS(conf, procArgs)
+ if err != nil {
+ return err
+ }
+
+ // Set namespace here so that it can be found in ctx.
+ procArgs.MountNamespace = mns
+
+ return setExecutablePath(ctx, procArgs)
+}
+
// setExecutablePath sets the procArgs.Filename by searching the PATH for an
// executable matching the procArgs.Argv[0].
func setExecutablePath(ctx context.Context, procArgs *kernel.CreateProcessArgs) error {
@@ -413,6 +442,39 @@ func (m *mountHint) isSupported() bool {
return m.mount.Type == tmpfs && m.share == pod
}
+// checkCompatible verifies that shared mount is compatible with master.
+// For now enforce that all options are the same. Once bind mount is properly
+// supported, then we should ensure the master is less restrictive than the
+// container, e.g. master can be 'rw' while container mounts as 'ro'.
+//
+// Only options accepted by isSupportedMountFlag take part in the comparison,
+// and their order does not matter. A non-nil error describes the mismatch.
+func (m *mountHint) checkCompatible(mount specs.Mount) error {
+ // Remove options that don't affect the mount's behavior.
+ masterOpts := filterUnsupportedOptions(m.mount)
+ slaveOpts := filterUnsupportedOptions(mount)
+
+ if len(masterOpts) != len(slaveOpts) {
+ return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", masterOpts, slaveOpts)
+ }
+
+ // Sort both filtered lists so that they can be compared element by element.
+ sort.Strings(masterOpts)
+ sort.Strings(slaveOpts)
+ for i, opt := range masterOpts {
+ if opt != slaveOpts[i] {
+ return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", masterOpts, slaveOpts)
+ }
+ }
+ return nil
+}
+
+// filterUnsupportedOptions returns a new slice holding, in their original
+// order, the options of mount that isSupportedMountFlag accepts for the
+// mount's type. mount.Options itself is not modified.
+func filterUnsupportedOptions(mount specs.Mount) []string {
+ rv := make([]string, 0, len(mount.Options))
+ for _, o := range mount.Options {
+ if isSupportedMountFlag(mount.Type, o) {
+ rv = append(rv, o)
+ }
+ }
+ return rv
+}
+
// podMountHints contains a collection of mountHints for the pod.
type podMountHints struct {
mounts map[string]*mountHint
@@ -500,73 +562,95 @@ func newContainerMounter(spec *specs.Spec, goferFDs []int, k *kernel.Kernel, hin
}
}
-// setupChildContainer is used to set up the file system for non-root containers
-// and amend the procArgs accordingly. This is the main entry point for this
-// rest of functions in this file. procArgs are passed by reference and the
-// FDMap field is modified. It dups stdioFDs.
-func (c *containerMounter) setupChildContainer(conf *Config, procArgs *kernel.CreateProcessArgs) error {
- // Setup a child container.
- log.Infof("Creating new process in child container.")
-
- // Create a new root inode and mount namespace for the container.
- rootCtx := c.k.SupervisorContext()
- rootInode, err := c.createRootMount(rootCtx, conf)
- if err != nil {
- return fmt.Errorf("creating filesystem for container: %v", err)
+// processHints processes annotations that contain hints about how volumes
+// should be mounted (e.g. a volume shared between containers). It must be
+// called for the root container only.
+func (c *containerMounter) processHints(conf *Config) error {
+ ctx := c.k.SupervisorContext()
+ for _, hint := range c.hints.mounts {
+ log.Infof("Mounting master of shared mount %q from %q type %q", hint.name, hint.mount.Source, hint.mount.Type)
+ inode, err := c.mountSharedMaster(ctx, conf, hint)
+ if err != nil {
+ return fmt.Errorf("mounting shared master %q: %v", hint.name, err)
+ }
+ hint.root = inode
}
- mns, err := fs.NewMountNamespace(rootCtx, rootInode)
+ return nil
+}
+
+// setupFS is used to set up the file system for all containers. This is the
+// main entry point method, with most of the other being internal only. It
+// returns the mount namespace that is created for the container.
+func (c *containerMounter) setupFS(conf *Config, procArgs *kernel.CreateProcessArgs) (*fs.MountNamespace, error) {
+ log.Infof("Configuring container's file system")
+
+ // Create context with root credentials to mount the filesystem (the current
+ // user may not be privileged enough).
+ rootProcArgs := *procArgs
+ rootProcArgs.WorkingDirectory = "/"
+ rootProcArgs.Credentials = auth.NewRootCredentials(procArgs.Credentials.UserNamespace)
+ rootProcArgs.Umask = 0022
+ rootProcArgs.MaxSymlinkTraversals = linux.MaxSymlinkTraversals
+ rootCtx := rootProcArgs.NewContext(c.k)
+
+ mns, err := c.createMountNamespace(rootCtx, conf)
if err != nil {
- return fmt.Errorf("creating new mount namespace for container: %v", err)
+ return nil, err
}
- procArgs.MountNamespace = mns
- root := mns.Root()
- defer root.DecRef()
- // Mount all submounts.
- if err := c.mountSubmounts(rootCtx, conf, mns, root); err != nil {
- return err
+ // Set namespace here so that it can be found in rootCtx.
+ rootProcArgs.MountNamespace = mns
+
+ if err := c.mountSubmounts(rootCtx, conf, mns); err != nil {
+ return nil, err
}
- return c.checkDispenser()
+ return mns, nil
}
-func (c *containerMounter) checkDispenser() error {
- if !c.fds.empty() {
- return fmt.Errorf("not all gofer FDs were consumed, remaining: %v", c.fds)
+func (c *containerMounter) createMountNamespace(ctx context.Context, conf *Config) (*fs.MountNamespace, error) {
+ rootInode, err := c.createRootMount(ctx, conf)
+ if err != nil {
+ return nil, fmt.Errorf("creating filesystem for container: %v", err)
}
- return nil
+ mns, err := fs.NewMountNamespace(ctx, rootInode)
+ if err != nil {
+ return nil, fmt.Errorf("creating new mount namespace for container: %v", err)
+ }
+ return mns, nil
}
-// setupRootContainer creates a mount namespace containing the root filesystem
-// and all mounts. 'rootCtx' is used to walk directories to find mount points.
-// The 'setMountNS' callback is called after the mount namespace is created and
-// will get a reference on that namespace. The callback must ensure that the
-// rootCtx has the provided mount namespace.
-func (c *containerMounter) setupRootContainer(userCtx context.Context, rootCtx context.Context, conf *Config, setMountNS func(*fs.MountNamespace)) error {
- for _, hint := range c.hints.mounts {
- log.Infof("Mounting master of shared mount %q from %q type %q", hint.name, hint.mount.Source, hint.mount.Type)
- inode, err := c.mountSharedMaster(rootCtx, conf, hint)
- if err != nil {
- return fmt.Errorf("mounting shared master %q: %v", hint.name, err)
+func (c *containerMounter) mountSubmounts(ctx context.Context, conf *Config, mns *fs.MountNamespace) error {
+ root := mns.Root()
+ defer root.DecRef()
+
+ for _, m := range c.mounts {
+ log.Debugf("Mounting %q to %q, type: %s, options: %s", m.Source, m.Destination, m.Type, m.Options)
+ if hint := c.hints.findMount(m); hint != nil && hint.isSupported() {
+ if err := c.mountSharedSubmount(ctx, mns, root, m, hint); err != nil {
+ return fmt.Errorf("mount shared mount %q to %q: %v", hint.name, m.Destination, err)
+ }
+ } else {
+ if err := c.mountSubmount(ctx, conf, mns, root, m); err != nil {
+ return fmt.Errorf("mount submount %q: %v", m.Destination, err)
+ }
}
- hint.root = inode
}
- rootInode, err := c.createRootMount(rootCtx, conf)
- if err != nil {
- return fmt.Errorf("creating root mount: %v", err)
+ if err := c.mountTmp(ctx, conf, mns, root); err != nil {
+ return fmt.Errorf("mount submount %q: %v", "tmp", err)
}
- mns, err := fs.NewMountNamespace(userCtx, rootInode)
- if err != nil {
- return fmt.Errorf("creating root mount namespace: %v", err)
+
+ if err := c.checkDispenser(); err != nil {
+ return err
}
- setMountNS(mns)
+ return nil
+}
- root := mns.Root()
- defer root.DecRef()
- if err := c.mountSubmounts(rootCtx, conf, mns, root); err != nil {
- return fmt.Errorf("mounting submounts: %v", err)
+func (c *containerMounter) checkDispenser() error {
+ if !c.fds.empty() {
+ return fmt.Errorf("not all gofer FDs were consumed, remaining: %v", c.fds)
}
- return c.checkDispenser()
+ return nil
}
// mountSharedMaster mounts the master of a volume that is shared among
@@ -663,9 +747,7 @@ func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) (
fsName = sysfs
case tmpfs:
fsName = m.Type
-
- // tmpfs has some extra supported options that we must pass through.
- opts, err = parseAndFilterOptions(m.Options, "mode", "uid", "gid")
+ opts, err = parseAndFilterOptions(m.Options, tmpfsAllowedOptions...)
case bind:
fd := c.fds.remove()
@@ -684,25 +766,6 @@ func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) (
return fsName, opts, useOverlay, err
}
-func (c *containerMounter) mountSubmounts(ctx context.Context, conf *Config, mns *fs.MountNamespace, root *fs.Dirent) error {
- for _, m := range c.mounts {
- if hint := c.hints.findMount(m); hint != nil && hint.isSupported() {
- if err := c.mountSharedSubmount(ctx, mns, root, m, hint); err != nil {
- return fmt.Errorf("mount shared mount %q to %q: %v", hint.name, m.Destination, err)
- }
- } else {
- if err := c.mountSubmount(ctx, conf, mns, root, m); err != nil {
- return fmt.Errorf("mount submount %q: %v", m.Destination, err)
- }
- }
- }
-
- if err := c.mountTmp(ctx, conf, mns, root); err != nil {
- return fmt.Errorf("mount submount %q: %v", "tmp", err)
- }
- return nil
-}
-
// mountSubmount mounts volumes inside the container's root. Because mounts may
// be readonly, a lower ramfs overlay is added to create the mount point dir.
// Another overlay is added with tmpfs on top if Config.Overlay is true.
@@ -769,17 +832,8 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *Config, mns
// mountSharedSubmount binds mount to a previously mounted volume that is shared
// among containers in the same pod.
func (c *containerMounter) mountSharedSubmount(ctx context.Context, mns *fs.MountNamespace, root *fs.Dirent, mount specs.Mount, source *mountHint) error {
- // For now enforce that all options are the same. Once bind mount is properly
- // supported, then we should ensure the master is less restrictive than the
- // container, e.g. master can be 'rw' while container mounts as 'ro'.
- if len(mount.Options) != len(source.mount.Options) {
- return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", source.mount.Options, mount.Options)
- }
- sort.Strings(mount.Options)
- for i, opt := range mount.Options {
- if opt != source.mount.Options[i] {
- return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", source.mount.Options, mount.Options)
- }
+ if err := source.checkCompatible(mount); err != nil {
+ return err
}
maxTraversals := uint(0)
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index f91158027..c8e5e86ee 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -20,7 +20,6 @@ import (
mrand "math/rand"
"os"
"runtime"
- "strings"
"sync"
"sync/atomic"
"syscall"
@@ -33,7 +32,6 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/memutil"
"gvisor.dev/gvisor/pkg/rand"
- "gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -56,6 +54,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/raw"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/runsc/boot/filter"
@@ -63,10 +62,10 @@ import (
"gvisor.dev/gvisor/runsc/specutils"
// Include supported socket providers.
- "gvisor.dev/gvisor/pkg/sentry/socket/epsocket"
"gvisor.dev/gvisor/pkg/sentry/socket/hostinet"
_ "gvisor.dev/gvisor/pkg/sentry/socket/netlink"
_ "gvisor.dev/gvisor/pkg/sentry/socket/netlink/route"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
_ "gvisor.dev/gvisor/pkg/sentry/socket/unix"
)
@@ -527,34 +526,21 @@ func (l *Loader) run() error {
// Setup the root container file system.
l.startGoferMonitor(l.sandboxID, l.goferFDs)
+
mntr := newContainerMounter(l.spec, l.goferFDs, l.k, l.mountHints)
- if err := mntr.setupRootContainer(ctx, ctx, l.conf, func(mns *fs.MountNamespace) {
- l.rootProcArgs.MountNamespace = mns
- }); err != nil {
+ if err := mntr.processHints(l.conf); err != nil {
return err
}
-
- if err := setExecutablePath(ctx, &l.rootProcArgs); err != nil {
+ if err := setupContainerFS(ctx, l.conf, mntr, &l.rootProcArgs); err != nil {
return err
}
- // Read /etc/passwd for the user's HOME directory and set the HOME
- // environment variable as required by POSIX if it is not overridden by
- // the user.
- hasHomeEnvv := false
- for _, envv := range l.rootProcArgs.Envv {
- if strings.HasPrefix(envv, "HOME=") {
- hasHomeEnvv = true
- }
- }
- if !hasHomeEnvv {
- homeDir, err := getExecUserHome(ctx, l.rootProcArgs.MountNamespace, uint32(l.rootProcArgs.Credentials.RealKUID))
- if err != nil {
- return fmt.Errorf("error reading exec user: %v", err)
- }
-
- l.rootProcArgs.Envv = append(l.rootProcArgs.Envv, "HOME="+homeDir)
+ // Add the HOME environment variable if it is not already set.
+ envv, err := maybeAddExecUserHome(ctx, l.rootProcArgs.MountNamespace, l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv)
+ if err != nil {
+ return err
}
+ l.rootProcArgs.Envv = envv
// Create the root container init task. It will begin running
// when the kernel is started.
@@ -687,13 +673,10 @@ func (l *Loader) startContainer(spec *specs.Spec, conf *Config, cid string, file
// Setup the child container file system.
l.startGoferMonitor(cid, goferFDs)
- mntr := newContainerMounter(spec, goferFDs, l.k, l.mountHints)
- if err := mntr.setupChildContainer(conf, &procArgs); err != nil {
- return fmt.Errorf("configuring container FS: %v", err)
- }
- if err := setExecutablePath(ctx, &procArgs); err != nil {
- return fmt.Errorf("setting executable path for %+v: %v", procArgs, err)
+ mntr := newContainerMounter(spec, goferFDs, l.k, l.mountHints)
+ if err := setupContainerFS(ctx, conf, mntr, &procArgs); err != nil {
+ return err
}
// Create and start the new process.
@@ -766,26 +749,34 @@ func (l *Loader) destroyContainer(cid string) error {
if err := l.signalAllProcesses(cid, int32(linux.SIGKILL)); err != nil {
return fmt.Errorf("sending SIGKILL to all container processes: %v", err)
}
+ // Wait for all processes that belong to the container to exit (including
+ // exec'd processes).
+ for _, t := range l.k.TaskSet().Root.Tasks() {
+ if t.ContainerID() == cid {
+ t.ThreadGroup().WaitExited()
+ }
+ }
+
+ // At this point, all processes inside of the container have exited,
+ // releasing all references to the container's MountNamespace and
+ // causing all submounts and overlays to be unmounted.
+ //
+ // Since the container's MountNamespace has been released,
+ // MountNamespace.destroy() will have executed, but that function may
+ // trigger async close operations. We must wait for those to complete
+ // before returning, otherwise the caller may kill the gofer before
+ // they complete, causing a cascade of failing RPCs.
+ fs.AsyncBarrier()
}
- // Remove all container thread groups from the map.
+ // No more failure from this point on. Remove all container thread groups
+ // from the map.
for key := range l.processes {
if key.cid == cid {
delete(l.processes, key)
}
}
- // At this point, all processes inside of the container have exited,
- // releasing all references to the container's MountNamespace and
- // causing all submounts and overlays to be unmounted.
- //
- // Since the container's MountNamespace has been released,
- // MountNamespace.destroy() will have executed, but that function may
- // trigger async close operations. We must wait for those to complete
- // before returning, otherwise the caller may kill the gofer before
- // they complete, causing a cascade of failing RPCs.
- fs.AsyncBarrier()
-
log.Debugf("Container destroyed %q", cid)
return nil
}
@@ -813,6 +804,16 @@ func (l *Loader) executeAsync(args *control.ExecArgs) (kernel.ThreadID, error) {
})
defer args.MountNamespace.DecRef()
+ // Add the HOME environment variable if it is not already set.
+ root := args.MountNamespace.Root()
+ defer root.DecRef()
+ ctx := fs.WithRoot(l.k.SupervisorContext(), root)
+ envv, err := maybeAddExecUserHome(ctx, args.MountNamespace, args.KUID, args.Envv)
+ if err != nil {
+ return 0, err
+ }
+ args.Envv = envv
+
// Start the process.
proc := control.Proc{Kernel: l.k}
args.PIDNamespace = tg.PIDNamespace()
@@ -911,15 +912,17 @@ func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) {
case NetworkNone, NetworkSandbox:
// NetworkNone sets up loopback using netstack.
- netProtos := []string{ipv4.ProtocolName, ipv6.ProtocolName, arp.ProtocolName}
- protoNames := []string{tcp.ProtocolName, udp.ProtocolName, icmp.ProtocolName4}
- s := epsocket.Stack{stack.New(netProtos, protoNames, stack.Options{
- Clock: clock,
- Stats: epsocket.Metrics,
- HandleLocal: true,
+ netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}
+ transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()}
+ s := netstack.Stack{stack.New(stack.Options{
+ NetworkProtocols: netProtos,
+ TransportProtocols: transProtos,
+ Clock: clock,
+ Stats: netstack.Metrics,
+ HandleLocal: true,
// Enable raw sockets for users with sufficient
// privileges.
- Raw: true,
+ UnassociatedFactory: raw.EndpointFactory{},
})}
// Enable SACK Recovery.
@@ -927,6 +930,10 @@ func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) {
return nil, fmt.Errorf("failed to enable SACK: %v", err)
}
+ // Set default TTLs as required by socket/netstack.
+ s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
+ s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
+
// Enable Receive Buffer Auto-Tuning.
if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
return nil, fmt.Errorf("SetTransportProtocolOption failed: %v", err)
@@ -1043,21 +1050,8 @@ func (l *Loader) signalAllProcesses(cid string, signo int32) error {
// the signal is delivered. This prevents process leaks when SIGKILL is
// sent to the entire container.
l.k.Pause()
- if err := l.k.SendContainerSignal(cid, &arch.SignalInfo{Signo: signo}); err != nil {
- l.k.Unpause()
- return err
- }
- l.k.Unpause()
-
- // If SIGKILLing all processes, wait for them to exit.
- if linux.Signal(signo) == linux.SIGKILL {
- for _, t := range l.k.TaskSet().Root.Tasks() {
- if t.ContainerID() == cid {
- t.ThreadGroup().WaitExited()
- }
- }
- }
- return nil
+ defer l.k.Unpause()
+ return l.k.SendContainerSignal(cid, &arch.SignalInfo{Signo: signo})
}
// threadGroupFromID same as threadGroupFromIDLocked except that it acquires
@@ -1090,8 +1084,3 @@ func (l *Loader) threadGroupFromIDLocked(key execID) (*kernel.ThreadGroup, *host
}
return ep.tg, ep.tty, true, nil
}
-
-func init() {
- // TODO(gvisor.dev/issue/365): Make this configurable.
- refs.SetLeakMode(refs.NoLeakChecking)
-}
diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go
index e0e32b9d5..147ff7703 100644
--- a/runsc/boot/loader_test.go
+++ b/runsc/boot/loader_test.go
@@ -401,17 +401,16 @@ func TestCreateMountNamespace(t *testing.T) {
}
defer cleanup()
- // setupRootContainer needs to find root from the context after the
- // namespace is created.
- var mns *fs.MountNamespace
- setMountNS := func(m *fs.MountNamespace) {
- mns = m
- ctx.(*contexttest.TestContext).RegisterValue(fs.CtxRoot, mns.Root())
- }
mntr := newContainerMounter(&tc.spec, []int{sandEnd}, nil, &podMountHints{})
- if err := mntr.setupRootContainer(ctx, ctx, conf, setMountNS); err != nil {
- t.Fatalf("createMountNamespace test case %q failed: %v", tc.name, err)
+ mns, err := mntr.createMountNamespace(ctx, conf)
+ if err != nil {
+ t.Fatalf("failed to create mount namespace: %v", err)
}
+ ctx = fs.WithRoot(ctx, mns.Root())
+ if err := mntr.mountSubmounts(ctx, conf, mns); err != nil {
+ t.Fatalf("failed to create mount namespace: %v", err)
+ }
+
root := mns.Root()
defer root.DecRef()
for _, p := range tc.expectedPaths {
diff --git a/runsc/boot/network.go b/runsc/boot/network.go
index ea0d9f790..32cba5ac1 100644
--- a/runsc/boot/network.go
+++ b/runsc/boot/network.go
@@ -121,10 +121,10 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
nicID++
nicids[link.Name] = nicID
- linkEP := loopback.New()
+ ep := loopback.New()
log.Infof("Enabling loopback interface %q with id %d on addresses %+v", link.Name, nicID, link.Addresses)
- if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, true /* loopback */); err != nil {
+ if err := n.createNICWithAddrs(nicID, link.Name, ep, link.Addresses, true /* loopback */); err != nil {
return err
}
@@ -156,7 +156,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
}
mac := tcpip.LinkAddress(link.LinkAddress)
- linkEP, err := fdbased.New(&fdbased.Options{
+ ep, err := fdbased.New(&fdbased.Options{
FDs: FDs,
MTU: uint32(link.MTU),
EthernetHeader: true,
@@ -170,7 +170,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
}
log.Infof("Enabling interface %q with id %d on addresses %+v (%v) w/ %d channels", link.Name, nicID, link.Addresses, mac, link.NumChannels)
- if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, false /* loopback */); err != nil {
+ if err := n.createNICWithAddrs(nicID, link.Name, ep, link.Addresses, false /* loopback */); err != nil {
return err
}
@@ -203,14 +203,14 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
// createNICWithAddrs creates a NIC in the network stack and adds the given
// addresses.
-func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, addrs []net.IP, loopback bool) error {
+func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkEndpoint, addrs []net.IP, loopback bool) error {
if loopback {
- if err := n.Stack.CreateNamedLoopbackNIC(id, name, sniffer.New(linkEP)); err != nil {
- return fmt.Errorf("CreateNamedLoopbackNIC(%v, %v, %v) failed: %v", id, name, linkEP, err)
+ if err := n.Stack.CreateNamedLoopbackNIC(id, name, sniffer.New(ep)); err != nil {
+ return fmt.Errorf("CreateNamedLoopbackNIC(%v, %v) failed: %v", id, name, err)
}
} else {
- if err := n.Stack.CreateNamedNIC(id, name, sniffer.New(linkEP)); err != nil {
- return fmt.Errorf("CreateNamedNIC(%v, %v, %v) failed: %v", id, name, linkEP, err)
+ if err := n.Stack.CreateNamedNIC(id, name, sniffer.New(ep)); err != nil {
+ return fmt.Errorf("CreateNamedNIC(%v, %v) failed: %v", id, name, err)
}
}
diff --git a/runsc/boot/user.go b/runsc/boot/user.go
index d1d423a5c..56cc12ee0 100644
--- a/runsc/boot/user.go
+++ b/runsc/boot/user.go
@@ -16,6 +16,7 @@ package boot
import (
"bufio"
+ "fmt"
"io"
"strconv"
"strings"
@@ -23,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/usermem"
)
@@ -42,7 +44,7 @@ func (r *fileReader) Read(buf []byte) (int, error) {
// getExecUserHome returns the home directory of the executing user read from
// /etc/passwd as read from the container filesystem.
-func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid uint32) (string, error) {
+func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid auth.KUID) (string, error) {
// The default user home directory to return if no user matching the user
// if found in the /etc/passwd found in the image.
const defaultHome = "/"
@@ -82,7 +84,7 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid uint32
File: f,
}
- homeDir, err := findHomeInPasswd(uid, r, defaultHome)
+ homeDir, err := findHomeInPasswd(uint32(uid), r, defaultHome)
if err != nil {
return "", err
}
@@ -90,6 +92,28 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid uint32
return homeDir, nil
}
+// maybeAddExecUserHome returns a new slice with the HOME environment variable
+// set if the slice does not already contain it, otherwise it returns the
+// original slice unmodified.
+func maybeAddExecUserHome(ctx context.Context, mns *fs.MountNamespace, uid auth.KUID, envv []string) ([]string, error) {
+ // Check if the envv already contains HOME.
+ for _, env := range envv {
+ if strings.HasPrefix(env, "HOME=") {
+ // We have it. Return the original slice unmodified.
+ return envv, nil
+ }
+ }
+
+ // Read /etc/passwd for the user's HOME directory and set the HOME
+ // environment variable as required by POSIX if it is not overridden by
+ // the user.
+ homeDir, err := getExecUserHome(ctx, mns, uid)
+ if err != nil {
+ return nil, fmt.Errorf("error reading exec user: %v", err)
+ }
+ return append(envv, "HOME="+homeDir), nil
+}
+
// findHomeInPasswd parses a passwd file and returns the given user's home
// directory. This function does it's best to replicate the runc's behavior.
func findHomeInPasswd(uid uint32, passwd io.Reader, defaultHome string) (string, error) {
diff --git a/runsc/boot/user_test.go b/runsc/boot/user_test.go
index 01f666507..9aee2ad07 100644
--- a/runsc/boot/user_test.go
+++ b/runsc/boot/user_test.go
@@ -25,6 +25,7 @@ import (
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/sentry/context/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
func setupTempDir() (string, error) {
@@ -68,7 +69,7 @@ func setupPasswd(contents string, perms os.FileMode) func() (string, error) {
// TestGetExecUserHome tests the getExecUserHome function.
func TestGetExecUserHome(t *testing.T) {
tests := map[string]struct {
- uid uint32
+ uid auth.KUID
createRoot func() (string, error)
expected string
}{
@@ -164,13 +165,13 @@ func TestGetExecUserHome(t *testing.T) {
},
}
- var mns *fs.MountNamespace
- setMountNS := func(m *fs.MountNamespace) {
- mns = m
- ctx.(*contexttest.TestContext).RegisterValue(fs.CtxRoot, mns.Root())
- }
mntr := newContainerMounter(spec, []int{sandEnd}, nil, &podMountHints{})
- if err := mntr.setupRootContainer(ctx, ctx, conf, setMountNS); err != nil {
+ mns, err := mntr.createMountNamespace(ctx, conf)
+ if err != nil {
+ t.Fatalf("failed to create mount namespace: %v", err)
+ }
+ ctx = fs.WithRoot(ctx, mns.Root())
+ if err := mntr.mountSubmounts(ctx, conf, mns); err != nil {
t.Fatalf("failed to create mount namespace: %v", err)
}
diff --git a/runsc/cgroup/BUILD b/runsc/cgroup/BUILD
index ab2387614..d6165f9e5 100644
--- a/runsc/cgroup/BUILD
+++ b/runsc/cgroup/BUILD
@@ -6,9 +6,7 @@ go_library(
name = "cgroup",
srcs = ["cgroup.go"],
importpath = "gvisor.dev/gvisor/runsc/cgroup",
- visibility = [
- "//runsc:__subpackages__",
- ],
+ visibility = ["//:sandbox"],
deps = [
"//pkg/log",
"//runsc/specutils",
diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD
index 5223b9972..250845ad7 100644
--- a/runsc/cmd/BUILD
+++ b/runsc/cmd/BUILD
@@ -19,6 +19,7 @@ go_library(
"exec.go",
"gofer.go",
"help.go",
+ "install.go",
"kill.go",
"list.go",
"path.go",
@@ -81,7 +82,7 @@ go_test(
"//runsc/boot",
"//runsc/container",
"//runsc/specutils",
- "//runsc/test/testutil",
+ "//runsc/testutil",
"@com_github_google_go-cmp//cmp:go_default_library",
"@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
diff --git a/runsc/cmd/capability_test.go b/runsc/cmd/capability_test.go
index 3ae25a257..0c27f7313 100644
--- a/runsc/cmd/capability_test.go
+++ b/runsc/cmd/capability_test.go
@@ -15,6 +15,7 @@
package cmd
import (
+ "flag"
"fmt"
"os"
"testing"
@@ -25,7 +26,7 @@ import (
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
"gvisor.dev/gvisor/runsc/specutils"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
func init() {
@@ -121,6 +122,7 @@ func TestCapabilities(t *testing.T) {
}
func TestMain(m *testing.M) {
+ flag.Parse()
specutils.MaybeRunAsRoot()
os.Exit(m.Run())
}
diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go
index e817eff77..d1e99243b 100644
--- a/runsc/cmd/exec.go
+++ b/runsc/cmd/exec.go
@@ -105,11 +105,11 @@ func (ex *Exec) SetFlags(f *flag.FlagSet) {
// Execute implements subcommands.Command.Execute. It starts a process in an
// already created container.
func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
- e, id, err := ex.parseArgs(f)
+ conf := args[0].(*boot.Config)
+ e, id, err := ex.parseArgs(f, conf.EnableRaw)
if err != nil {
Fatalf("parsing process spec: %v", err)
}
- conf := args[0].(*boot.Config)
waitStatus := args[1].(*syscall.WaitStatus)
c, err := container.Load(conf.RootDir, id)
@@ -117,6 +117,9 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
Fatalf("loading sandbox: %v", err)
}
+ log.Debugf("Exec arguments: %+v", e)
+ log.Debugf("Exec capablities: %+v", e.Capabilities)
+
// Replace empty settings with defaults from container.
if e.WorkingDirectory == "" {
e.WorkingDirectory = c.Spec.Process.Cwd
@@ -127,15 +130,13 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
Fatalf("getting environment variables: %v", err)
}
}
+
if e.Capabilities == nil {
- // enableRaw is set to true to prevent the filtering out of
- // CAP_NET_RAW. This is the opposite of Create() because exec
- // requires the capability to be set explicitly, while 'docker
- // run' sets it by default.
- e.Capabilities, err = specutils.Capabilities(true /* enableRaw */, c.Spec.Process.Capabilities)
+ e.Capabilities, err = specutils.Capabilities(conf.EnableRaw, c.Spec.Process.Capabilities)
if err != nil {
Fatalf("creating capabilities: %v", err)
}
+ log.Infof("Using exec capabilities from container: %+v", e.Capabilities)
}
// containerd expects an actual process to represent the container being
@@ -282,14 +283,14 @@ func (ex *Exec) execChildAndWait(waitStatus *syscall.WaitStatus) subcommands.Exi
// parseArgs parses exec information from the command line or a JSON file
// depending on whether the --process flag was used. Returns an ExecArgs and
// the ID of the container to be used.
-func (ex *Exec) parseArgs(f *flag.FlagSet) (*control.ExecArgs, string, error) {
+func (ex *Exec) parseArgs(f *flag.FlagSet, enableRaw bool) (*control.ExecArgs, string, error) {
if ex.processPath == "" {
// Requires at least a container ID and command.
if f.NArg() < 2 {
f.Usage()
return nil, "", fmt.Errorf("both a container-id and command are required")
}
- e, err := ex.argsFromCLI(f.Args()[1:])
+ e, err := ex.argsFromCLI(f.Args()[1:], enableRaw)
return e, f.Arg(0), err
}
// Requires only the container ID.
@@ -297,11 +298,11 @@ func (ex *Exec) parseArgs(f *flag.FlagSet) (*control.ExecArgs, string, error) {
f.Usage()
return nil, "", fmt.Errorf("a container-id is required")
}
- e, err := ex.argsFromProcessFile()
+ e, err := ex.argsFromProcessFile(enableRaw)
return e, f.Arg(0), err
}
-func (ex *Exec) argsFromCLI(argv []string) (*control.ExecArgs, error) {
+func (ex *Exec) argsFromCLI(argv []string, enableRaw bool) (*control.ExecArgs, error) {
extraKGIDs := make([]auth.KGID, 0, len(ex.extraKGIDs))
for _, s := range ex.extraKGIDs {
kgid, err := strconv.Atoi(s)
@@ -314,7 +315,7 @@ func (ex *Exec) argsFromCLI(argv []string) (*control.ExecArgs, error) {
var caps *auth.TaskCapabilities
if len(ex.caps) > 0 {
var err error
- caps, err = capabilities(ex.caps)
+ caps, err = capabilities(ex.caps, enableRaw)
if err != nil {
return nil, fmt.Errorf("capabilities error: %v", err)
}
@@ -332,7 +333,7 @@ func (ex *Exec) argsFromCLI(argv []string) (*control.ExecArgs, error) {
}, nil
}
-func (ex *Exec) argsFromProcessFile() (*control.ExecArgs, error) {
+func (ex *Exec) argsFromProcessFile(enableRaw bool) (*control.ExecArgs, error) {
f, err := os.Open(ex.processPath)
if err != nil {
return nil, fmt.Errorf("error opening process file: %s, %v", ex.processPath, err)
@@ -342,21 +343,21 @@ func (ex *Exec) argsFromProcessFile() (*control.ExecArgs, error) {
if err := json.NewDecoder(f).Decode(&p); err != nil {
return nil, fmt.Errorf("error parsing process file: %s, %v", ex.processPath, err)
}
- return argsFromProcess(&p)
+ return argsFromProcess(&p, enableRaw)
}
// argsFromProcess performs all the non-IO conversion from the Process struct
// to ExecArgs.
-func argsFromProcess(p *specs.Process) (*control.ExecArgs, error) {
+func argsFromProcess(p *specs.Process, enableRaw bool) (*control.ExecArgs, error) {
// Create capabilities.
var caps *auth.TaskCapabilities
if p.Capabilities != nil {
var err error
- // enableRaw is set to true to prevent the filtering out of
- // CAP_NET_RAW. This is the opposite of Create() because exec
- // requires the capability to be set explicitly, while 'docker
- // run' sets it by default.
- caps, err = specutils.Capabilities(true /* enableRaw */, p.Capabilities)
+ // Starting from Docker 19, capabilities are explicitly set for exec (instead
+ // of nil like before). So we can't distinguish 'exec' from
+ // 'exec --privileged', as both specify CAP_NET_RAW. Therefore, filter
+ // CAP_NET_RAW in the same way as container start.
+ caps, err = specutils.Capabilities(enableRaw, p.Capabilities)
if err != nil {
return nil, fmt.Errorf("error creating capabilities: %v", err)
}
@@ -409,7 +410,7 @@ func resolveEnvs(envs ...[]string) ([]string, error) {
// capabilities takes a list of capabilities as strings and returns an
// auth.TaskCapabilities struct with those capabilities in every capability set.
// This mimics runc's behavior.
-func capabilities(cs []string) (*auth.TaskCapabilities, error) {
+func capabilities(cs []string, enableRaw bool) (*auth.TaskCapabilities, error) {
var specCaps specs.LinuxCapabilities
for _, cap := range cs {
specCaps.Ambient = append(specCaps.Ambient, cap)
@@ -418,11 +419,11 @@ func capabilities(cs []string) (*auth.TaskCapabilities, error) {
specCaps.Inheritable = append(specCaps.Inheritable, cap)
specCaps.Permitted = append(specCaps.Permitted, cap)
}
- // enableRaw is set to true to prevent the filtering out of
- // CAP_NET_RAW. This is the opposite of Create() because exec requires
- // the capability to be set explicitly, while 'docker run' sets it by
- // default.
- return specutils.Capabilities(true /* enableRaw */, &specCaps)
+ // Starting from Docker 19, capabilities are explicitly set for exec (instead
+ // of nil like before). So we can't distinguish 'exec' from
+ // 'exec --privileged', as both specify CAP_NET_RAW. Therefore, filter
+ // CAP_NET_RAW in the same way as container start.
+ return specutils.Capabilities(enableRaw, &specCaps)
}
// stringSlice allows a flag to be used multiple times, where each occurrence
diff --git a/runsc/cmd/exec_test.go b/runsc/cmd/exec_test.go
index eb38a431f..a1e980d08 100644
--- a/runsc/cmd/exec_test.go
+++ b/runsc/cmd/exec_test.go
@@ -91,7 +91,7 @@ func TestCLIArgs(t *testing.T) {
}
for _, tc := range testCases {
- e, err := tc.ex.argsFromCLI(tc.argv)
+ e, err := tc.ex.argsFromCLI(tc.argv, true)
if err != nil {
t.Errorf("argsFromCLI(%+v): got error: %+v", tc.ex, err)
} else if !cmp.Equal(*e, tc.expected, cmpopts.IgnoreUnexported(os.File{})) {
@@ -144,7 +144,7 @@ func TestJSONArgs(t *testing.T) {
}
for _, tc := range testCases {
- e, err := argsFromProcess(&tc.p)
+ e, err := argsFromProcess(&tc.p, true)
if err != nil {
t.Errorf("argsFromProcess(%+v): got error: %+v", tc.p, err)
} else if !cmp.Equal(*e, tc.expected, cmpopts.IgnoreUnexported(os.File{})) {
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
index 9faabf494..4c2fb80bf 100644
--- a/runsc/cmd/gofer.go
+++ b/runsc/cmd/gofer.go
@@ -182,6 +182,7 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
cfg := fsgofer.Config{
ROMount: isReadonlyMount(m.Options),
PanicOnWrite: g.panicOnWrite,
+ HostUDS: conf.FSGoferHostUDS,
}
ap, err := fsgofer.NewAttachPoint(m.Destination, cfg)
if err != nil {
@@ -200,6 +201,10 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
Fatalf("too many FDs passed for mounts. mounts: %d, FDs: %d", mountIdx, len(g.ioFDs))
}
+ if conf.FSGoferHostUDS {
+ filter.InstallUDSFilters()
+ }
+
if err := filter.Install(); err != nil {
Fatalf("installing seccomp filters: %v", err)
}
@@ -418,7 +423,7 @@ func resolveSymlinksImpl(root, base, rel string, followCount uint) (string, erro
path := filepath.Join(base, name)
if !strings.HasPrefix(path, root) {
// One cannot '..' their way out of root.
- path = root
+ base = root
continue
}
fi, err := os.Lstat(path)
diff --git a/runsc/cmd/install.go b/runsc/cmd/install.go
new file mode 100644
index 000000000..441c1db0d
--- /dev/null
+++ b/runsc/cmd/install.go
@@ -0,0 +1,210 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "os"
+ "path"
+
+ "flag"
+ "github.com/google/subcommands"
+)
+
+// Install implements subcommands.Command.
+type Install struct {
+ ConfigFile string
+ Runtime string
+ Experimental bool
+}
+
+// Name implements subcommands.Command.Name.
+func (*Install) Name() string {
+ return "install"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Install) Synopsis() string {
+ return "adds a runtime to docker daemon configuration"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Install) Usage() string {
+ return `install [flags] <name> [-- [args...]] -- if provided, args are passed to the runtime
+`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (i *Install) SetFlags(fs *flag.FlagSet) {
+ fs.StringVar(&i.ConfigFile, "config_file", "/etc/docker/daemon.json", "path to Docker daemon config file")
+ fs.StringVar(&i.Runtime, "runtime", "runsc", "runtime name")
+ fs.BoolVar(&i.Experimental, "experimental", false, "enable experimental features")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (i *Install) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ // Grab the name and arguments.
+ runtimeArgs := f.Args()
+
+ // Extract the executable.
+ path, err := os.Executable()
+ if err != nil {
+ log.Fatalf("Error reading current exectuable: %v", err)
+ }
+
+ // Load the configuration file.
+ c, err := readConfig(i.ConfigFile)
+ if err != nil {
+ log.Fatalf("Error reading config file %q: %v", i.ConfigFile, err)
+ }
+
+ // Add the given runtime.
+ var rts map[string]interface{}
+ if i, ok := c["runtimes"]; ok {
+ rts = i.(map[string]interface{})
+ } else {
+ rts = make(map[string]interface{})
+ c["runtimes"] = rts
+ }
+ rts[i.Runtime] = struct {
+ Path string `json:"path,omitempty"`
+ RuntimeArgs []string `json:"runtimeArgs,omitempty"`
+ }{
+ Path: path,
+ RuntimeArgs: runtimeArgs,
+ }
+
+ // Set experimental if required.
+ if i.Experimental {
+ c["experimental"] = true
+ }
+
+ // Write out the runtime.
+ if err := writeConfig(c, i.ConfigFile); err != nil {
+ log.Fatalf("Error writing config file %q: %v", i.ConfigFile, err)
+ }
+
+ // Success.
+ log.Printf("Added runtime %q with arguments %v to %q.", i.Runtime, runtimeArgs, i.ConfigFile)
+ return subcommands.ExitSuccess
+}
+
+// Uninstall implements subcommands.Command.
+type Uninstall struct {
+ ConfigFile string
+ Runtime string
+}
+
+// Name implements subcommands.Command.Name.
+func (*Uninstall) Name() string {
+ return "uninstall"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Uninstall) Synopsis() string {
+ return "removes a runtime from docker daemon configuration"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Uninstall) Usage() string {
+ return `uninstall [flags] <name>
+`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (u *Uninstall) SetFlags(fs *flag.FlagSet) {
+ fs.StringVar(&u.ConfigFile, "config_file", "/etc/docker/daemon.json", "path to Docker daemon config file")
+ fs.StringVar(&u.Runtime, "runtime", "runsc", "runtime name")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (u *Uninstall) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ log.Printf("Removing runtime %q from %q.", u.Runtime, u.ConfigFile)
+
+ c, err := readConfig(u.ConfigFile)
+ if err != nil {
+ log.Fatalf("Error reading config file %q: %v", u.ConfigFile, err)
+ }
+
+ var rts map[string]interface{}
+ if i, ok := c["runtimes"]; ok {
+ rts = i.(map[string]interface{})
+ } else {
+ log.Fatalf("runtime %q not found", u.Runtime)
+ }
+ if _, ok := rts[u.Runtime]; !ok {
+ log.Fatalf("runtime %q not found", u.Runtime)
+ }
+ delete(rts, u.Runtime)
+
+ if err := writeConfig(c, u.ConfigFile); err != nil {
+ log.Fatalf("Error writing config file %q: %v", u.ConfigFile, err)
+ }
+ return subcommands.ExitSuccess
+}
+
+func readConfig(path string) (map[string]interface{}, error) {
+ // Read the configuration data.
+ configBytes, err := ioutil.ReadFile(path)
+ if err != nil && !os.IsNotExist(err) {
+ return nil, err
+ }
+
+ // Unmarshal the configuration.
+ c := make(map[string]interface{})
+ if len(configBytes) > 0 {
+ if err := json.Unmarshal(configBytes, &c); err != nil {
+ return nil, err
+ }
+ }
+
+ return c, nil
+}
+
+func writeConfig(c map[string]interface{}, filename string) error {
+ // Marshal the configuration.
+ b, err := json.MarshalIndent(c, "", " ")
+ if err != nil {
+ return err
+ }
+
+ // Copy the old configuration.
+ old, err := ioutil.ReadFile(filename)
+ if err != nil {
+ if !os.IsNotExist(err) {
+ return fmt.Errorf("error reading config file %q: %v", filename, err)
+ }
+ } else {
+ if err := ioutil.WriteFile(filename+"~", old, 0644); err != nil {
+ return fmt.Errorf("error backing up config file %q: %v", filename, err)
+ }
+ }
+
+ // Make the necessary directories.
+ if err := os.MkdirAll(path.Dir(filename), 0755); err != nil {
+ return fmt.Errorf("error creating config directory for %q: %v", filename, err)
+ }
+
+ // Write the new configuration.
+ if err := ioutil.WriteFile(filename, b, 0644); err != nil {
+ return fmt.Errorf("error writing config file %q: %v", filename, err)
+ }
+
+ return nil
+}
diff --git a/runsc/container/BUILD b/runsc/container/BUILD
index de8202bb1..26d1cd5ab 100644
--- a/runsc/container/BUILD
+++ b/runsc/container/BUILD
@@ -47,6 +47,7 @@ go_test(
],
deps = [
"//pkg/abi/linux",
+ "//pkg/bits",
"//pkg/log",
"//pkg/sentry/control",
"//pkg/sentry/kernel",
@@ -56,7 +57,7 @@ go_test(
"//runsc/boot",
"//runsc/boot/platforms",
"//runsc/specutils",
- "//runsc/test/testutil",
+ "//runsc/testutil",
"@com_github_cenkalti_backoff//:go_default_library",
"@com_github_kr_pty//:go_default_library",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go
index e9372989f..7d67c3a75 100644
--- a/runsc/container/console_test.go
+++ b/runsc/container/console_test.go
@@ -30,7 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/pkg/urpc"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
// socketPath creates a path inside bundleDir and ensures that the returned
diff --git a/runsc/container/container.go b/runsc/container/container.go
index bbb364214..a721c1c31 100644
--- a/runsc/container/container.go
+++ b/runsc/container/container.go
@@ -513,9 +513,16 @@ func (c *Container) Start(conf *boot.Config) error {
return err
}
- // Adjust the oom_score_adj for sandbox and gofers. This must be done after
+ // Adjust the oom_score_adj for sandbox. This must be done after
// save().
- return c.adjustOOMScoreAdj(conf)
+ err = adjustSandboxOOMScoreAdj(c.Sandbox, c.RootContainerDir, false)
+ if err != nil {
+ return err
+ }
+
+ // Set container's oom_score_adj to the gofer since it is dedicated to
+ // the container, in case the gofer uses up too much memory.
+ return c.adjustGoferOOMScoreAdj()
}
// Restore takes a container and replaces its kernel and file system
@@ -782,6 +789,9 @@ func (c *Container) Destroy() error {
}
defer unlock()
+ // Stored for later use as stop() sets c.Sandbox to nil.
+ sb := c.Sandbox
+
if err := c.stop(); err != nil {
err = fmt.Errorf("stopping container: %v", err)
log.Warningf("%v", err)
@@ -796,6 +806,16 @@ func (c *Container) Destroy() error {
c.changeStatus(Stopped)
+ // Adjust oom_score_adj for the sandbox. This must be done after the
+ // container is stopped and the directory at c.Root is removed.
+ // We must test if the sandbox is nil because Destroy should be
+ // idempotent.
+ if sb != nil {
+ if err := adjustSandboxOOMScoreAdj(sb, c.RootContainerDir, true); err != nil {
+ errs = append(errs, err.Error())
+ }
+ }
+
// "If any poststop hook fails, the runtime MUST log a warning, but the
// remaining hooks and lifecycle continue as if the hook had succeeded" -OCI spec.
// Based on the OCI, "The post-stop hooks MUST be called after the container is
@@ -926,7 +946,14 @@ func (c *Container) createGoferProcess(spec *specs.Spec, conf *boot.Config, bund
}
if conf.DebugLog != "" {
- debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "gofer")
+ test := ""
+ if len(conf.TestOnlyTestNameEnv) != 0 {
+ // Fetch test name if one is provided and the test only flag was set.
+ if t, ok := specutils.EnvVar(spec.Process.Env, conf.TestOnlyTestNameEnv); ok {
+ test = t
+ }
+ }
+ debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "gofer", test)
if err != nil {
return nil, nil, fmt.Errorf("opening debug log file in %q: %v", conf.DebugLog, err)
}
@@ -1139,35 +1166,82 @@ func runInCgroup(cg *cgroup.Cgroup, fn func() error) error {
return fn()
}
-// adjustOOMScoreAdj sets the oom_score_adj for the sandbox and all gofers.
+// adjustGoferOOMScoreAdj sets the oom_score_adj for the container's gofer.
+func (c *Container) adjustGoferOOMScoreAdj() error {
+ if c.GoferPid != 0 && c.Spec.Process.OOMScoreAdj != nil {
+ if err := setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj); err != nil {
+ return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err)
+ }
+ }
+
+ return nil
+}
+
+// adjustSandboxOOMScoreAdj sets the oom_score_adj for the sandbox.
// oom_score_adj is set to the lowest oom_score_adj among the containers
// running in the sandbox.
//
// TODO(gvisor.dev/issue/512): This call could race with other containers being
// created at the same time and end up setting the wrong oom_score_adj to the
// sandbox.
-func (c *Container) adjustOOMScoreAdj(conf *boot.Config) error {
- // If this container's OOMScoreAdj is nil then we can exit early as no
- // change should be made to oom_score_adj for the sandbox.
- if c.Spec.Process.OOMScoreAdj == nil {
- return nil
- }
-
- containers, err := loadSandbox(conf.RootDir, c.Sandbox.ID)
+func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) error {
+ containers, err := loadSandbox(rootDir, s.ID)
if err != nil {
return fmt.Errorf("loading sandbox containers: %v", err)
}
+ // Do nothing if the sandbox has been terminated.
+ if len(containers) == 0 {
+ return nil
+ }
+
// Get the lowest score for all containers.
var lowScore int
scoreFound := false
- for _, container := range containers {
- if container.Spec.Process.OOMScoreAdj != nil && (!scoreFound || *container.Spec.Process.OOMScoreAdj < lowScore) {
+ if len(containers) == 1 && len(containers[0].Spec.Annotations[specutils.ContainerdContainerTypeAnnotation]) == 0 {
+ // This is a single-container sandbox. Set the oom_score_adj to
+ // the value specified in the OCI bundle.
+ if containers[0].Spec.Process.OOMScoreAdj != nil {
scoreFound = true
- lowScore = *container.Spec.Process.OOMScoreAdj
+ lowScore = *containers[0].Spec.Process.OOMScoreAdj
+ }
+ } else {
+ for _, container := range containers {
+ // Special multi-container support for CRI. Ignore the root
+ // container when calculating oom_score_adj for the sandbox because
+ // it is the infrastructure (pause) container and always has a very
+ // low oom_score_adj.
+ //
+ // We will use OOMScoreAdj in the single-container case where the
+ // containerd container-type annotation is not present.
+ if container.Spec.Annotations[specutils.ContainerdContainerTypeAnnotation] == specutils.ContainerdContainerTypeSandbox {
+ continue
+ }
+
+ if container.Spec.Process.OOMScoreAdj != nil && (!scoreFound || *container.Spec.Process.OOMScoreAdj < lowScore) {
+ scoreFound = true
+ lowScore = *container.Spec.Process.OOMScoreAdj
+ }
}
}
+ // If the container is destroyed and remaining containers have no
+ // oomScoreAdj specified then we must revert to the oom_score_adj of the
+ // parent process.
+ if !scoreFound && destroy {
+ ppid, err := specutils.GetParentPid(s.Pid)
+ if err != nil {
+ return fmt.Errorf("getting parent pid of sandbox pid %d: %v", s.Pid, err)
+ }
+ pScore, err := specutils.GetOOMScoreAdj(ppid)
+ if err != nil {
+ return fmt.Errorf("getting oom_score_adj of parent %d: %v", ppid, err)
+ }
+
+ scoreFound = true
+ lowScore = pScore
+ }
+
// Only set oom_score_adj if one of the containers has oom_score_adj set
// in the OCI bundle. If not, we need to inherit the parent process's
// oom_score_adj.
@@ -1177,15 +1251,10 @@ func (c *Container) adjustOOMScoreAdj(conf *boot.Config) error {
}
// Set the lowest of all containers oom_score_adj to the sandbox.
- if err := setOOMScoreAdj(c.Sandbox.Pid, lowScore); err != nil {
- return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", c.Sandbox.ID, err)
+ if err := setOOMScoreAdj(s.Pid, lowScore); err != nil {
+ return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", s.ID, err)
}
- // Set container's oom_score_adj to the gofer since it is dedicated to the
- // container, in case the gofer uses up too much memory.
- if err := setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj); err != nil {
- return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err)
- }
return nil
}
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index af128bf1c..519f5ed9b 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -16,6 +16,7 @@ package container
import (
"bytes"
+ "flag"
"fmt"
"io"
"io/ioutil"
@@ -33,13 +34,14 @@ import (
"github.com/cenkalti/backoff"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/boot/platforms"
"gvisor.dev/gvisor/runsc/specutils"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
// waitForProcessList waits for the given process list to show up in the container.
@@ -155,12 +157,7 @@ func waitForFile(f *os.File) error {
return nil
}
- timeout := 5 * time.Second
- if testutil.RaceEnabled {
- // Race makes slow things even slow, so bump the timeout.
- timeout = 3 * timeout
- }
- return testutil.Poll(op, timeout)
+ return testutil.Poll(op, 30*time.Second)
}
// readOutputNum reads a file at given filepath and returns the int at the
@@ -254,10 +251,6 @@ func configs(opts ...configOption) []*boot.Config {
// TODO(b/112165693): KVM tests are flaky. Disable until fixed.
continue
- // TODO(b/68787993): KVM doesn't work with --race.
- if testutil.RaceEnabled {
- continue
- }
c.Platform = platforms.KVM
case nonExclusiveFS:
c.FileAccess = boot.FileAccessShared
@@ -1310,10 +1303,13 @@ func TestRunNonRoot(t *testing.T) {
t.Logf("Running test with conf: %+v", conf)
spec := testutil.NewSpecWithArgs("/bin/true")
+
+ // Set a random user/group with no access to "blocked" dir.
spec.Process.User.UID = 343
spec.Process.User.GID = 2401
+ spec.Process.Capabilities = nil
- // User that container runs as can't list '$TMP/blocked' and would fail to
+ // User running inside container can't list '$TMP/blocked' and would fail to
// mount it.
dir, err := ioutil.TempDir(testutil.TmpDir(), "blocked")
if err != nil {
@@ -1327,6 +1323,17 @@ func TestRunNonRoot(t *testing.T) {
t.Fatalf("os.MkDir(%q) failed: %v", dir, err)
}
+ src, err := ioutil.TempDir(testutil.TmpDir(), "src")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: dir,
+ Source: src,
+ Type: "bind",
+ })
+
if err := run(spec, conf); err != nil {
t.Fatalf("error running sandbox: %v", err)
}
@@ -1637,22 +1644,27 @@ func TestGoferExits(t *testing.T) {
}
func TestRootNotMount(t *testing.T) {
- if testutil.RaceEnabled {
- // Requires statically linked binary, since it's mapping the root to a
- // random dir, libs cannot be located.
- t.Skip("race makes test_app not statically linked")
- }
-
appSym, err := testutil.FindFile("runsc/container/test_app/test_app")
if err != nil {
t.Fatal("error finding test_app:", err)
}
+
app, err := filepath.EvalSymlinks(appSym)
if err != nil {
t.Fatalf("error resolving %q symlink: %v", appSym, err)
}
log.Infof("App path %q is a symlink to %q", appSym, app)
+ static, err := testutil.IsStatic(app)
+ if err != nil {
+ t.Fatalf("error reading application binary: %v", err)
+ }
+ if !static {
+ // This happens during race builds; we cannot map in shared
+ // libraries also, so we need to skip the test.
+ t.Skip()
+ }
+
root := filepath.Dir(app)
exe := "/" + filepath.Base(app)
log.Infof("Executing %q in %q", exe, root)
@@ -2038,6 +2050,30 @@ func TestMountSymlink(t *testing.T) {
}
}
+// Check that CAP_NET_RAW is granted with --net-raw and dropped without it.
+func TestNetRaw(t *testing.T) {
+ capNetRaw := strconv.FormatUint(bits.MaskOf64(int(linux.CAP_NET_RAW)), 10)
+ app, err := testutil.FindFile("runsc/container/test_app/test_app")
+ if err != nil {
+ t.Fatal("error finding test_app:", err)
+ }
+
+ for _, enableRaw := range []bool{true, false} {
+ conf := testutil.TestConfig()
+ conf.EnableRaw = enableRaw
+
+ test := "--enabled"
+ if !enableRaw {
+ test = "--disabled"
+ }
+
+ spec := testutil.NewSpecWithArgs(app, "capability", test, capNetRaw)
+ if err := run(spec, conf); err != nil {
+ t.Fatalf("Error running container: %v", err)
+ }
+ }
+}
+
// executeSync synchronously executes a new process.
func (cont *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus, error) {
pid, err := cont.Execute(args)
@@ -2053,10 +2089,10 @@ func (cont *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus,
func TestMain(m *testing.M) {
log.SetLevel(log.Debug)
+ flag.Parse()
if err := testutil.ConfigureExePath(); err != nil {
panic(err.Error())
}
specutils.MaybeRunAsRoot()
-
os.Exit(m.Run())
}
diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go
index 2d51fecc6..9e02a825e 100644
--- a/runsc/container/multi_container_test.go
+++ b/runsc/container/multi_container_test.go
@@ -32,7 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/specutils"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) {
@@ -549,10 +549,16 @@ func TestMultiContainerDestroy(t *testing.T) {
t.Logf("Running test with conf: %+v", conf)
// First container will remain intact while the second container is killed.
- specs, ids := createSpecs(
- []string{app, "reaper"},
+ podSpecs, ids := createSpecs(
+ []string{"sleep", "100"},
[]string{app, "fork-bomb"})
- containers, cleanup, err := startContainers(conf, specs, ids)
+
+ // Run the fork bomb in a PID namespace to prevent processes from being
+ // re-parented to PID=1 in the root container.
+ podSpecs[1].Linux = &specs.Linux{
+ Namespaces: []specs.LinuxNamespace{{Type: "pid"}},
+ }
+ containers, cleanup, err := startContainers(conf, podSpecs, ids)
if err != nil {
t.Fatalf("error starting containers: %v", err)
}
@@ -580,7 +586,7 @@ func TestMultiContainerDestroy(t *testing.T) {
if err != nil {
t.Fatalf("error getting process data from sandbox: %v", err)
}
- expectedPL := []*control.Process{{PID: 1, Cmd: "test_app"}}
+ expectedPL := []*control.Process{{PID: 1, Cmd: "sleep"}}
if !procListsEqual(pss, expectedPL) {
t.Errorf("container got process list: %s, want: %s", procListToString(pss), procListToString(expectedPL))
}
@@ -1291,6 +1297,53 @@ func TestMultiContainerSharedMountRestart(t *testing.T) {
}
}
+// Test that unsupported pod mount options are ignored when matching master and
+// slave mounts.
+func TestMultiContainerSharedMountUnsupportedOptions(t *testing.T) {
+ conf := testutil.TestConfig()
+ t.Logf("Running test with conf: %+v", conf)
+
+ // Setup the containers.
+ sleep := []string{"/bin/sleep", "100"}
+ podSpec, ids := createSpecs(sleep, sleep)
+ mnt0 := specs.Mount{
+ Destination: "/mydir/test",
+ Source: "/some/dir",
+ Type: "tmpfs",
+ Options: []string{"rw", "rbind", "relatime"},
+ }
+ podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0)
+
+ mnt1 := mnt0
+ mnt1.Destination = "/mydir2/test2"
+ mnt1.Options = []string{"rw", "nosuid"}
+ podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1)
+
+ createSharedMount(mnt0, "test-mount", podSpec...)
+
+ containers, cleanup, err := startContainers(conf, podSpec, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ execs := []execDesc{
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "-d", mnt0.Destination},
+ desc: "directory is mounted in container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "-d", mnt1.Destination},
+ desc: "directory is mounted in container1",
+ },
+ }
+ if err := execMany(execs); err != nil {
+ t.Fatal(err.Error())
+ }
+}
+
// Test that one container can send an FD to another container, even though
// they have distinct MountNamespaces.
func TestMultiContainerMultiRootCanHandleFDs(t *testing.T) {
@@ -1485,3 +1538,58 @@ func TestMultiContainerLoadSandbox(t *testing.T) {
t.Errorf("containers not found: %v", wantIDs)
}
}
+
+// TestMultiContainerRunNonRoot checks that a child container can be configured
+// when running as a non-privileged user.
+func TestMultiContainerRunNonRoot(t *testing.T) {
+ cmdRoot := []string{"/bin/sleep", "100"}
+ cmdSub := []string{"/bin/true"}
+ podSpecs, ids := createSpecs(cmdRoot, cmdSub)
+
+ // User running inside container can't list '$TMP/blocked' and would fail to
+ // mount it.
+ blocked, err := ioutil.TempDir(testutil.TmpDir(), "blocked")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+ if err := os.Chmod(blocked, 0700); err != nil {
+ t.Fatalf("os.Chmod(%q) failed: %v", blocked, err)
+ }
+ dir := path.Join(blocked, "test")
+ if err := os.Mkdir(dir, 0755); err != nil {
+ t.Fatalf("os.MkDir(%q) failed: %v", dir, err)
+ }
+
+ src, err := ioutil.TempDir(testutil.TmpDir(), "src")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+
+ // Set a random user/group with no access to "blocked" dir.
+ podSpecs[1].Process.User.UID = 343
+ podSpecs[1].Process.User.GID = 2401
+ podSpecs[1].Process.Capabilities = nil
+
+ podSpecs[1].Mounts = append(podSpecs[1].Mounts, specs.Mount{
+ Destination: dir,
+ Source: src,
+ Type: "bind",
+ })
+
+ conf := testutil.TestConfig()
+ pod, cleanup, err := startContainers(conf, podSpecs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ // Once all containers are started, wait for the child container to exit.
+ // This means that the volume was mounted properly.
+ ws, err := pod[1].Wait()
+ if err != nil {
+ t.Fatalf("running child container: %v", err)
+ }
+ if !ws.Exited() || ws.ExitStatus() != 0 {
+ t.Fatalf("child container failed, waitStatus: %v", ws)
+ }
+}
diff --git a/runsc/container/shared_volume_test.go b/runsc/container/shared_volume_test.go
index 1f90d2462..dc4194134 100644
--- a/runsc/container/shared_volume_test.go
+++ b/runsc/container/shared_volume_test.go
@@ -25,7 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/runsc/boot"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
// TestSharedVolume checks that modifications to a volume mount are propagated
diff --git a/runsc/container/test_app/BUILD b/runsc/container/test_app/BUILD
index 82dbd54d2..9bf9e6e9d 100644
--- a/runsc/container/test_app/BUILD
+++ b/runsc/container/test_app/BUILD
@@ -13,7 +13,7 @@ go_binary(
visibility = ["//runsc/container:__pkg__"],
deps = [
"//pkg/unet",
- "//runsc/test/testutil",
+ "//runsc/testutil",
"@com_github_google_subcommands//:go_default_library",
],
)
diff --git a/runsc/container/test_app/fds.go b/runsc/container/test_app/fds.go
index c12809cab..a90cc1662 100644
--- a/runsc/container/test_app/fds.go
+++ b/runsc/container/test_app/fds.go
@@ -24,7 +24,7 @@ import (
"flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/unet"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
const fileContents = "foobarbaz"
@@ -60,7 +60,7 @@ func (fds *fdSender) Execute(ctx context.Context, f *flag.FlagSet, args ...inter
log.Fatalf("socket flag must be set")
}
- dir, err := ioutil.TempDir(testutil.TmpDir(), "")
+ dir, err := ioutil.TempDir("", "")
if err != nil {
log.Fatalf("TempDir failed: %v", err)
}
diff --git a/runsc/container/test_app/test_app.go b/runsc/container/test_app/test_app.go
index 6578c7b41..913d781c6 100644
--- a/runsc/container/test_app/test_app.go
+++ b/runsc/container/test_app/test_app.go
@@ -19,22 +19,25 @@ package main
import (
"context"
"fmt"
+ "io/ioutil"
"log"
"net"
"os"
"os/exec"
+ "regexp"
"strconv"
sys "syscall"
"time"
"flag"
"github.com/google/subcommands"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
func main() {
subcommands.Register(subcommands.HelpCommand(), "")
subcommands.Register(subcommands.FlagsCommand(), "")
+ subcommands.Register(new(capability), "")
subcommands.Register(new(fdReceiver), "")
subcommands.Register(new(fdSender), "")
subcommands.Register(new(forkBomb), "")
@@ -287,3 +290,65 @@ func (s *syscall) Execute(ctx context.Context, f *flag.FlagSet, args ...interfac
}
return subcommands.ExitSuccess
}
+
+type capability struct {
+ enabled uint64
+ disabled uint64
+}
+
+// Name implements subcommands.Command.
+func (*capability) Name() string {
+ return "capability"
+}
+
+// Synopsis implements subcommands.Command.
+func (*capability) Synopsis() string {
+ return "checks if effective capabilities are set/unset"
+}
+
+// Usage implements subcommands.Command.
+func (*capability) Usage() string {
+ return "capability [--enabled=number] [--disabled=number]"
+}
+
+// SetFlags implements subcommands.Command.
+func (c *capability) SetFlags(f *flag.FlagSet) {
+ f.Uint64Var(&c.enabled, "enabled", 0, "")
+ f.Uint64Var(&c.disabled, "disabled", 0, "")
+}
+
+// Execute implements subcommands.Command. It reads CapEff from /proc/self/status and fails if any --enabled bit is missing or any --disabled bit is set.
+func (c *capability) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if c.enabled == 0 && c.disabled == 0 {
+ fmt.Println("One of the flags must be set")
+ return subcommands.ExitUsageError
+ }
+
+ status, err := ioutil.ReadFile("/proc/self/status")
+ if err != nil {
+ fmt.Printf("Error reading %q: %v\n", "/proc/self/status", err)
+ return subcommands.ExitFailure
+ }
+ re := regexp.MustCompile("CapEff:\t([0-9a-f]+)\n")
+ matches := re.FindStringSubmatch(string(status))
+ if len(matches) != 2 {
+ fmt.Printf("Effective capabilities not found in\n%s\n", status)
+ return subcommands.ExitFailure
+ }
+ caps, err := strconv.ParseUint(matches[1], 16, 64)
+ if err != nil {
+ fmt.Printf("failed to convert capabilities %q: %v\n", matches[1], err)
+ return subcommands.ExitFailure
+ }
+
+ if c.enabled != 0 && (caps&c.enabled) != c.enabled {
+ fmt.Printf("Missing capabilities, want: %#x: got: %#x\n", c.enabled, caps)
+ return subcommands.ExitFailure
+ }
+ if c.disabled != 0 && (caps&c.disabled) != 0 {
+ fmt.Printf("Extra capabilities found, dont_want: %#x: got: %#x\n", c.disabled, caps)
+ return subcommands.ExitFailure
+ }
+
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/criutil/BUILD b/runsc/criutil/BUILD
new file mode 100644
index 000000000..558133a0e
--- /dev/null
+++ b/runsc/criutil/BUILD
@@ -0,0 +1,12 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "criutil",
+ testonly = 1,
+ srcs = ["criutil.go"],
+ importpath = "gvisor.dev/gvisor/runsc/criutil",
+ visibility = ["//:sandbox"],
+ deps = ["//runsc/testutil"],
+)
diff --git a/runsc/test/testutil/crictl.go b/runsc/criutil/criutil.go
index 4f9ee0c05..773f5a1c4 100644
--- a/runsc/test/testutil/crictl.go
+++ b/runsc/criutil/criutil.go
@@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package testutil
+// Package criutil contains utility functions for interacting with the
+// Container Runtime Interface (CRI), principally via the crictl command line
+// tool. This requires critools to be installed on the local system.
+package criutil
import (
"encoding/json"
@@ -21,6 +24,8 @@ import (
"os/exec"
"strings"
"time"
+
+ "gvisor.dev/gvisor/runsc/testutil"
)
const endpointPrefix = "unix://"
@@ -152,19 +157,61 @@ func (cc *Crictl) RmPod(podID string) error {
return err
}
-// StartPodAndContainer pulls an image, then starts a sandbox and container in
-// that sandbox. It returns the pod ID and container ID.
-func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string, string, error) {
+// StartContainer pulls the given image and starts the container in the
+// sandbox with the given podID.
+func (cc *Crictl) StartContainer(podID, image, sbSpec, contSpec string) (string, error) {
+ // Write the specs to files that can be read by crictl.
+ sbSpecFile, err := testutil.WriteTmpFile("sbSpec", sbSpec)
+ if err != nil {
+ return "", fmt.Errorf("failed to write sandbox spec: %v", err)
+ }
+ contSpecFile, err := testutil.WriteTmpFile("contSpec", contSpec)
+ if err != nil {
+ return "", fmt.Errorf("failed to write container spec: %v", err)
+ }
+
+ return cc.startContainer(podID, image, sbSpecFile, contSpecFile)
+}
+
+func (cc *Crictl) startContainer(podID, image, sbSpecFile, contSpecFile string) (string, error) {
if err := cc.Pull(image); err != nil {
- return "", "", fmt.Errorf("failed to pull %s: %v", image, err)
+ return "", fmt.Errorf("failed to pull %s: %v", image, err)
+ }
+
+ contID, err := cc.Create(podID, contSpecFile, sbSpecFile)
+ if err != nil {
+ return "", fmt.Errorf("failed to create container in pod %q: %v", podID, err)
}
+ if _, err := cc.Start(contID); err != nil {
+ return "", fmt.Errorf("failed to start container %q in pod %q: %v", contID, podID, err)
+ }
+
+ return contID, nil
+}
+
+// StopContainer stops and deletes the container with the given container ID.
+func (cc *Crictl) StopContainer(contID string) error {
+ if err := cc.Stop(contID); err != nil {
+ return fmt.Errorf("failed to stop container %q: %v", contID, err)
+ }
+
+ if err := cc.Rm(contID); err != nil {
+ return fmt.Errorf("failed to remove container %q: %v", contID, err)
+ }
+
+ return nil
+}
+
+// StartPodAndContainer pulls an image, then starts a sandbox and container in
+// that sandbox. It returns the pod ID and container ID.
+func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string, string, error) {
// Write the specs to files that can be read by crictl.
- sbSpecFile, err := WriteTmpFile("sbSpec", sbSpec)
+ sbSpecFile, err := testutil.WriteTmpFile("sbSpec", sbSpec)
if err != nil {
return "", "", fmt.Errorf("failed to write sandbox spec: %v", err)
}
- contSpecFile, err := WriteTmpFile("contSpec", contSpec)
+ contSpecFile, err := testutil.WriteTmpFile("contSpec", contSpec)
if err != nil {
return "", "", fmt.Errorf("failed to write container spec: %v", err)
}
@@ -174,28 +221,17 @@ func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string,
return "", "", err
}
- contID, err := cc.Create(podID, contSpecFile, sbSpecFile)
- if err != nil {
- return "", "", fmt.Errorf("failed to create container in pod %q: %v", podID, err)
- }
-
- if _, err := cc.Start(contID); err != nil {
- return "", "", fmt.Errorf("failed to start container %q in pod %q: %v", contID, podID, err)
- }
+ contID, err := cc.startContainer(podID, image, sbSpecFile, contSpecFile)
- return podID, contID, nil
+ return podID, contID, err
}
// StopPodAndContainer stops a container and pod.
func (cc *Crictl) StopPodAndContainer(podID, contID string) error {
- if err := cc.Stop(contID); err != nil {
+ if err := cc.StopContainer(contID); err != nil {
return fmt.Errorf("failed to stop container %q in pod %q: %v", contID, podID, err)
}
- if err := cc.Rm(contID); err != nil {
- return fmt.Errorf("failed to remove container %q in pod %q: %v", contID, podID, err)
- }
-
if err := cc.StopPod(podID); err != nil {
return fmt.Errorf("failed to stop pod %q: %v", podID, err)
}
@@ -233,7 +269,7 @@ func (cc *Crictl) run(args ...string) (string, error) {
case err := <-errCh:
return "", err
case <-time.After(cc.timeout):
- if err := KillCommand(cmd); err != nil {
+ if err := testutil.KillCommand(cmd); err != nil {
return "", fmt.Errorf("timed out, then couldn't kill process %+v: %v", cmd, err)
}
return "", fmt.Errorf("timed out: %+v", cmd)
diff --git a/runsc/debian/postinst.sh b/runsc/debian/postinst.sh
index 03a5ff524..dc7aeee87 100755
--- a/runsc/debian/postinst.sh
+++ b/runsc/debian/postinst.sh
@@ -15,10 +15,10 @@
# limitations under the License.
if [ "$1" != configure ]; then
- exit 0
+ exit 0
fi
if [ -f /etc/docker/daemon.json ]; then
- /usr/libexec/runsc/dockercfg runtime-add runsc /usr/bin/runsc
- systemctl restart docker
+ runsc install
+ systemctl restart docker || echo "unable to restart docker; you must do so manually." >&2
fi
diff --git a/runsc/dockerutil/BUILD b/runsc/dockerutil/BUILD
new file mode 100644
index 000000000..0e0423504
--- /dev/null
+++ b/runsc/dockerutil/BUILD
@@ -0,0 +1,15 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "dockerutil",
+ testonly = 1,
+ srcs = ["dockerutil.go"],
+ importpath = "gvisor.dev/gvisor/runsc/dockerutil",
+ visibility = ["//:sandbox"],
+ deps = [
+ "//runsc/testutil",
+ "@com_github_kr_pty//:go_default_library",
+ ],
+)
diff --git a/runsc/test/testutil/docker.go b/runsc/dockerutil/dockerutil.go
index 3f3e191b0..57f6ae8de 100644
--- a/runsc/test/testutil/docker.go
+++ b/runsc/dockerutil/dockerutil.go
@@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package testutil
+// Package dockerutil is a collection of utility functions, primarily for
+// testing.
+package dockerutil
import (
+ "encoding/json"
"flag"
"fmt"
"io/ioutil"
@@ -29,26 +32,13 @@ import (
"time"
"github.com/kr/pty"
+ "gvisor.dev/gvisor/runsc/testutil"
)
-var runtimeType = flag.String("runtime-type", "", "specify which runtime to use: kvm, hostnet, overlay")
-
-func getRuntime() string {
- r, ok := os.LookupEnv("RUNSC_RUNTIME")
- if !ok {
- r = "runsc-test"
- }
- if *runtimeType != "" {
- r += "-" + *runtimeType
- }
- return r
-}
-
-// IsPauseResumeSupported returns true if Pause/Resume is supported by runtime.
-func IsPauseResumeSupported() bool {
- // Native host network stack can't be saved.
- return !strings.Contains(getRuntime(), "hostnet")
-}
+var (
+ runtime = flag.String("runtime", "runsc", "specify which runtime to use")
+ config = flag.String("config_path", "/etc/docker/daemon.json", "configuration file for reading paths")
+)
// EnsureSupportedDockerVersion checks if correct docker is installed.
func EnsureSupportedDockerVersion() {
@@ -69,6 +59,48 @@ func EnsureSupportedDockerVersion() {
}
}
+// RuntimePath returns the binary path for the current runtime.
+func RuntimePath() (string, error) {
+ // Read the configuration data; the file must exist.
+ configBytes, err := ioutil.ReadFile(*config)
+ if err != nil {
+ return "", err
+ }
+
+ // Unmarshal the configuration.
+ c := make(map[string]interface{})
+ if err := json.Unmarshal(configBytes, &c); err != nil {
+ return "", err
+ }
+
+ // Decode the expected configuration.
+ r, ok := c["runtimes"]
+ if !ok {
+ return "", fmt.Errorf("no runtimes declared: %v", c)
+ }
+ rs, ok := r.(map[string]interface{})
+ if !ok {
+ // The runtimes are not a map.
+ return "", fmt.Errorf("unexpected format: %v", c)
+ }
+ r, ok = rs[*runtime]
+ if !ok {
+ // The expected runtime is not declared.
+ return "", fmt.Errorf("runtime %q not found: %v", *runtime, c)
+ }
+ rs, ok = r.(map[string]interface{})
+ if !ok {
+ // The runtime is not a map.
+ return "", fmt.Errorf("unexpected format: %v", c)
+ }
+ p, ok := rs["path"].(string)
+ if !ok {
+ // The runtime does not declare a path.
+ return "", fmt.Errorf("unexpected format: %v", c)
+ }
+ return p, nil
+}
+
// MountMode describes if the mount should be ro or rw.
type MountMode int
@@ -113,7 +145,7 @@ func PrepareFiles(names ...string) (string, error) {
for _, name := range names {
src := getLocalPath(name)
dst := path.Join(dir, name)
- if err := Copy(src, dst); err != nil {
+ if err := testutil.Copy(src, dst); err != nil {
return "", fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err)
}
}
@@ -163,7 +195,10 @@ type Docker struct {
// MakeDocker sets up the struct for a Docker container.
// Names of containers will be unique.
func MakeDocker(namePrefix string) Docker {
- return Docker{Name: RandomName(namePrefix), Runtime: getRuntime()}
+ return Docker{
+ Name: testutil.RandomName(namePrefix),
+ Runtime: *runtime,
+ }
}
// logDockerID logs a container id, which is needed to find container runsc logs.
@@ -205,7 +240,7 @@ func (d *Docker) Stop() error {
// Run calls 'docker run' with the arguments provided. The container starts
// running in the background and the call returns immediately.
func (d *Docker) Run(args ...string) error {
- a := []string{"run", "--runtime", d.Runtime, "--name", d.Name, "-d"}
+ a := d.runArgs("-d")
a = append(a, args...)
_, err := do(a...)
if err == nil {
@@ -216,7 +251,7 @@ func (d *Docker) Run(args ...string) error {
// RunWithPty is like Run but with an attached pty.
func (d *Docker) RunWithPty(args ...string) (*exec.Cmd, *os.File, error) {
- a := []string{"run", "--runtime", d.Runtime, "--name", d.Name, "-it"}
+ a := d.runArgs("-it")
a = append(a, args...)
return doWithPty(a...)
}
@@ -224,8 +259,7 @@ func (d *Docker) RunWithPty(args ...string) (*exec.Cmd, *os.File, error) {
// RunFg calls 'docker run' with the arguments provided in the foreground. It
// blocks until the container exits and returns the output.
func (d *Docker) RunFg(args ...string) (string, error) {
- a := []string{"run", "--runtime", d.Runtime, "--name", d.Name}
- a = append(a, args...)
+ a := d.runArgs(args...)
out, err := do(a...)
if err == nil {
d.logDockerID()
@@ -233,6 +267,14 @@ func (d *Docker) RunFg(args ...string) (string, error) {
return string(out), err
}
+func (d *Docker) runArgs(args ...string) []string {
+ // Environment variable RUNSC_TEST_NAME is picked up by the runtime and added
+ // to the log name, so one can easily identify the corresponding logs for
+ // this test.
+ rv := []string{"run", "--runtime", d.Runtime, "--name", d.Name, "-e", "RUNSC_TEST_NAME=" + d.Name}
+ return append(rv, args...)
+}
+
// Logs calls 'docker logs'.
func (d *Docker) Logs() (string, error) {
return do("logs", d.Name)
@@ -240,7 +282,22 @@ func (d *Docker) Logs() (string, error) {
// Exec calls 'docker exec' with the arguments provided.
func (d *Docker) Exec(args ...string) (string, error) {
- a := []string{"exec", d.Name}
+ return d.ExecWithFlags(nil, args...)
+}
+
+// ExecWithFlags calls 'docker exec <flags> name <args>'.
+func (d *Docker) ExecWithFlags(flags []string, args ...string) (string, error) {
+ a := []string{"exec"}
+ a = append(a, flags...)
+ a = append(a, d.Name)
+ a = append(a, args...)
+ return do(a...)
+}
+
+// ExecAsUser calls 'docker exec' as the given user with the arguments
+// provided.
+func (d *Docker) ExecAsUser(user string, args ...string) (string, error) {
+ a := []string{"exec", "--user", user, d.Name}
a = append(a, args...)
return do(a...)
}
@@ -297,7 +354,11 @@ func (d *Docker) Remove() error {
func (d *Docker) CleanUp() {
d.logDockerID()
if _, err := do("kill", d.Name); err != nil {
- log.Printf("error killing container %q: %v", d.Name, err)
+ if strings.Contains(err.Error(), "is not running") {
+ // Nothing to kill. Don't log the error in this case.
+ } else {
+ log.Printf("error killing container %q: %v", d.Name, err)
+ }
}
if err := d.Remove(); err != nil {
log.Print(err)
diff --git a/runsc/fsgofer/filter/BUILD b/runsc/fsgofer/filter/BUILD
index e2318a978..02168ad1b 100644
--- a/runsc/fsgofer/filter/BUILD
+++ b/runsc/fsgofer/filter/BUILD
@@ -17,6 +17,7 @@ go_library(
],
deps = [
"//pkg/abi/linux",
+ "//pkg/flipcall",
"//pkg/log",
"//pkg/seccomp",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go
index 8ddfa77d6..0bf7507b7 100644
--- a/runsc/fsgofer/filter/config.go
+++ b/runsc/fsgofer/filter/config.go
@@ -83,6 +83,11 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.AllowAny{},
seccomp.AllowValue(syscall.F_GETFD),
},
+ // Used by flipcall.PacketWindowAllocator.Init().
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(unix.F_ADD_SEALS),
+ },
},
syscall.SYS_FSTAT: {},
syscall.SYS_FSTATFS: {},
@@ -103,6 +108,19 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.AllowAny{},
seccomp.AllowValue(0),
},
+ // Non-private futex used for flipcall.
+ seccomp.Rule{
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAIT),
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ },
+ seccomp.Rule{
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAKE),
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ },
},
syscall.SYS_GETDENTS64: {},
syscall.SYS_GETPID: {},
@@ -112,6 +130,7 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_LINKAT: {},
syscall.SYS_LSEEK: {},
syscall.SYS_MADVISE: {},
+ unix.SYS_MEMFD_CREATE: {}, /// Used by flipcall.PacketWindowAllocator.Init().
syscall.SYS_MKDIRAT: {},
syscall.SYS_MMAP: []seccomp.Rule{
{
@@ -158,8 +177,16 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_RENAMEAT: {},
syscall.SYS_RESTART_SYSCALL: {},
syscall.SYS_RT_SIGPROCMASK: {},
+ syscall.SYS_RT_SIGRETURN: {},
syscall.SYS_SCHED_YIELD: {},
syscall.SYS_SENDMSG: []seccomp.Rule{
+ // Used by fdchannel.Endpoint.SendFD().
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(0),
+ },
+ // Used by unet.SocketWriter.WriteVec().
{
seccomp.AllowAny{},
seccomp.AllowAny{},
@@ -170,7 +197,15 @@ var allowedSyscalls = seccomp.SyscallRules{
{seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RDWR)},
},
syscall.SYS_SIGALTSTACK: {},
- syscall.SYS_SYMLINKAT: {},
+ // Used by fdchannel.NewConnectedSockets().
+ syscall.SYS_SOCKETPAIR: {
+ {
+ seccomp.AllowValue(syscall.AF_UNIX),
+ seccomp.AllowValue(syscall.SOCK_SEQPACKET | syscall.SOCK_CLOEXEC),
+ seccomp.AllowValue(0),
+ },
+ },
+ syscall.SYS_SYMLINKAT: {},
syscall.SYS_TGKILL: []seccomp.Rule{
{
seccomp.AllowValue(uint64(os.Getpid())),
@@ -180,3 +215,16 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_UTIMENSAT: {},
syscall.SYS_WRITE: {},
}
+
+var udsSyscalls = seccomp.SyscallRules{
+ syscall.SYS_SOCKET: []seccomp.Rule{
+ {
+ seccomp.AllowValue(syscall.AF_UNIX),
+ },
+ },
+ syscall.SYS_CONNECT: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ },
+ },
+}
diff --git a/runsc/fsgofer/filter/filter.go b/runsc/fsgofer/filter/filter.go
index 65053415f..289886720 100644
--- a/runsc/fsgofer/filter/filter.go
+++ b/runsc/fsgofer/filter/filter.go
@@ -23,11 +23,16 @@ import (
// Install installs seccomp filters.
func Install() error {
- s := allowedSyscalls
-
// Set of additional filters used by -race and -msan. Returns empty
// when not enabled.
- s.Merge(instrumentationFilters())
+ allowedSyscalls.Merge(instrumentationFilters())
+
+ return seccomp.Install(allowedSyscalls)
+}
- return seccomp.Install(s)
+// InstallUDSFilters extends the allowed syscalls to include those necessary for
+// connecting to a host UDS.
+func InstallUDSFilters() {
+ // Add additional filters required for connecting to the host's sockets.
+ allowedSyscalls.Merge(udsSyscalls)
}
diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go
index 7c4d2b94e..ed8b02cf0 100644
--- a/runsc/fsgofer/fsgofer.go
+++ b/runsc/fsgofer/fsgofer.go
@@ -54,6 +54,7 @@ const (
regular fileType = iota
directory
symlink
+ socket
unknown
)
@@ -66,6 +67,8 @@ func (f fileType) String() string {
return "directory"
case symlink:
return "symlink"
+ case socket:
+ return "socket"
}
return "unknown"
}
@@ -82,6 +85,9 @@ type Config struct {
// PanicOnWrite panics on attempts to write to RO mounts.
PanicOnWrite bool
+
+ // HostUDS signals whether the gofer can mount a host's UDS.
+ HostUDS bool
}
type attachPoint struct {
@@ -119,35 +125,31 @@ func NewAttachPoint(prefix string, c Config) (p9.Attacher, error) {
// Attach implements p9.Attacher.
func (a *attachPoint) Attach() (p9.File, error) {
- // dirFD (1st argument) is ignored because 'prefix' is always absolute.
- stat, err := statAt(-1, a.prefix)
- if err != nil {
- return nil, fmt.Errorf("stat file %q, err: %v", a.prefix, err)
- }
- mode := syscall.O_RDWR
- if a.conf.ROMount || (stat.Mode&syscall.S_IFMT) == syscall.S_IFDIR {
- mode = syscall.O_RDONLY
+ a.attachedMu.Lock()
+ defer a.attachedMu.Unlock()
+
+ if a.attached {
+ return nil, fmt.Errorf("attach point already attached, prefix: %s", a.prefix)
}
- // Open the root directory.
- f, err := fd.Open(a.prefix, openFlags|mode, 0)
+ f, err := openAnyFile(a.prefix, func(mode int) (*fd.FD, error) {
+ return fd.Open(a.prefix, openFlags|mode, 0)
+ })
if err != nil {
- return nil, fmt.Errorf("unable to open file %q, err: %v", a.prefix, err)
+ return nil, fmt.Errorf("unable to open %q: %v", a.prefix, err)
}
- a.attachedMu.Lock()
- defer a.attachedMu.Unlock()
- if a.attached {
- f.Close()
- return nil, fmt.Errorf("attach point already attached, prefix: %s", a.prefix)
+ stat, err := stat(f.FD())
+ if err != nil {
+ return nil, fmt.Errorf("unable to stat %q: %v", a.prefix, err)
}
- rv, err := newLocalFile(a, f, a.prefix, stat)
+ lf, err := newLocalFile(a, f, a.prefix, stat)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("unable to create localFile %q: %v", a.prefix, err)
}
a.attached = true
- return rv, nil
+ return lf, nil
}
// makeQID returns a unique QID for the given stat buffer.
@@ -295,7 +297,7 @@ func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, error)
return file, nil
}
-func getSupportedFileType(stat syscall.Stat_t) (fileType, error) {
+func getSupportedFileType(stat syscall.Stat_t, permitSocket bool) (fileType, error) {
var ft fileType
switch stat.Mode & syscall.S_IFMT {
case syscall.S_IFREG:
@@ -304,6 +306,11 @@ func getSupportedFileType(stat syscall.Stat_t) (fileType, error) {
ft = directory
case syscall.S_IFLNK:
ft = symlink
+ case syscall.S_IFSOCK:
+ if !permitSocket {
+ return unknown, syscall.EPERM
+ }
+ ft = socket
default:
return unknown, syscall.EPERM
}
@@ -311,7 +318,7 @@ func getSupportedFileType(stat syscall.Stat_t) (fileType, error) {
}
func newLocalFile(a *attachPoint, file *fd.FD, path string, stat syscall.Stat_t) (*localFile, error) {
- ft, err := getSupportedFileType(stat)
+ ft, err := getSupportedFileType(stat, a.conf.HostUDS)
if err != nil {
return nil, err
}
@@ -1026,7 +1033,11 @@ func (l *localFile) Flush() error {
// Connect implements p9.File.
func (l *localFile) Connect(p9.ConnectFlags) (*fd.FD, error) {
- return nil, syscall.ECONNREFUSED
+ // Check to see if the CLI option has been set to allow the UDS mount.
+ if !l.attachPoint.conf.HostUDS {
+ return nil, syscall.ECONNREFUSED
+ }
+ return fd.DialUnix(l.hostPath)
}
// Close implements p9.File.
diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go
index c86beaef1..05af7e397 100644
--- a/runsc/fsgofer/fsgofer_test.go
+++ b/runsc/fsgofer/fsgofer_test.go
@@ -635,7 +635,15 @@ func TestAttachInvalidType(t *testing.T) {
t.Fatalf("Mkfifo(%q): %v", fifo, err)
}
- socket := filepath.Join(dir, "socket")
+ dirFile, err := os.Open(dir)
+ if err != nil {
+ t.Fatalf("Open(%s): %v", dir, err)
+ }
+ defer dirFile.Close()
+
+ // Bind the socket via /proc to be sure that the length of the socket path
+ // is less than UNIX_PATH_MAX.
+ socket := filepath.Join(fmt.Sprintf("/proc/self/fd/%d", dirFile.Fd()), "socket")
l, err := net.Listen("unix", socket)
if err != nil {
t.Fatalf("net.Listen(unix, %q): %v", socket, err)
@@ -657,7 +665,7 @@ func TestAttachInvalidType(t *testing.T) {
}
f, err := a.Attach()
if f != nil || err == nil {
- t.Fatalf("Attach should have failed, got (%v, nil)", f)
+ t.Fatalf("Attach should have failed, got (%v, %v)", f, err)
}
})
}
diff --git a/runsc/main.go b/runsc/main.go
index e864118b2..7dce9dc00 100644
--- a/runsc/main.go
+++ b/runsc/main.go
@@ -31,6 +31,7 @@ import (
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/cmd"
@@ -67,6 +68,7 @@ var (
network = flag.String("network", "sandbox", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.")
gso = flag.Bool("gso", true, "enable generic segmenation offload")
fileAccess = flag.String("file-access", "exclusive", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.")
+ fsGoferHostUDS = flag.Bool("fsgofer-host-uds", false, "Allow the gofer to mount Unix Domain Sockets.")
overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.")
watchdogAction = flag.String("watchdog-action", "log", "sets what action the watchdog takes when triggered: log (default), panic.")
panicSignal = flag.Int("panic-signal", -1, "register signal handling that panics. Usually set to SIGUSR2(12) to troubleshoot hangs. -1 disables it.")
@@ -74,9 +76,11 @@ var (
netRaw = flag.Bool("net-raw", false, "enable raw sockets. When false, raw sockets are disabled by removing CAP_NET_RAW from containers (`runsc exec` will still be able to utilize raw sockets). Raw sockets allow malicious containers to craft packets and potentially attack the network.")
numNetworkChannels = flag.Int("num-network-channels", 1, "number of underlying channels(FDs) to use for network link endpoints.")
rootless = flag.Bool("rootless", false, "it allows the sandbox to be started with a user that is not root. Sandbox and Gofer processes may run with same privileges as current user.")
+ referenceLeakMode = flag.String("ref-leak-mode", "disabled", "sets reference leak check mode: disabled (default), log-names, log-traces.")
// Test flags, not to be used outside tests, ever.
testOnlyAllowRunAsCurrentUserWithoutChroot = flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! This skips many security measures that isolate the host from the sandbox.")
+ testOnlyTestNameEnv = flag.String("TESTONLY-test-name-env", "", "TEST ONLY; do not ever use! Used for automated tests to improve logging.")
)
func main() {
@@ -86,6 +90,11 @@ func main() {
subcommands.Register(help, "")
subcommands.Register(subcommands.FlagsCommand(), "")
+ // Installation helpers.
+ const helperGroup = "helpers"
+ subcommands.Register(new(cmd.Install), helperGroup)
+ subcommands.Register(new(cmd.Uninstall), helperGroup)
+
// Register user-facing runsc commands.
subcommands.Register(new(cmd.Checkpoint), "")
subcommands.Register(new(cmd.Create), "")
@@ -117,13 +126,6 @@ func main() {
// All subcommands must be registered before flag parsing.
flag.Parse()
- if *testOnlyAllowRunAsCurrentUserWithoutChroot {
- // SIGTERM is sent to all processes if a test exceeds its
- // timeout and this case is handled by syscall_test_runner.
- log.Warningf("Block the TERM signal. This is only safe in tests!")
- signal.Ignore(syscall.SIGTERM)
- }
-
// Are we showing the version?
if *showVersion {
// The format here is the same as runc.
@@ -176,6 +178,15 @@ func main() {
cmd.Fatalf("num_network_channels must be > 0, got: %d", *numNetworkChannels)
}
+ refsLeakMode, err := boot.MakeRefsLeakMode(*referenceLeakMode)
+ if err != nil {
+ cmd.Fatalf("%v", err)
+ }
+
+ // Set the reference leak check mode. Also set it in config below to
+ // propagate it to child processes.
+ refs.SetLeakMode(refsLeakMode)
+
// Create a new Config from the flags.
conf := &boot.Config{
RootDir: *rootDir,
@@ -185,6 +196,7 @@ func main() {
DebugLog: *debugLog,
DebugLogFormat: *debugLogFormat,
FileAccess: fsAccess,
+ FSGoferHostUDS: *fsGoferHostUDS,
Overlay: *overlay,
Network: netType,
GSO: *gso,
@@ -199,8 +211,10 @@ func main() {
NumNetworkChannels: *numNetworkChannels,
Rootless: *rootless,
AlsoLogToStderr: *alsoLogToStderr,
+ ReferenceLeakMode: refsLeakMode,
TestOnlyAllowRunAsCurrentUserWithoutChroot: *testOnlyAllowRunAsCurrentUserWithoutChroot,
+ TestOnlyTestNameEnv: *testOnlyTestNameEnv,
}
if len(*straceSyscalls) != 0 {
conf.StraceSyscalls = strings.Split(*straceSyscalls, ",")
@@ -227,14 +241,14 @@ func main() {
// want with them. Since Docker and Containerd both eat boot's stderr, we
// dup our stderr to the provided log FD so that panics will appear in the
// logs, rather than just disappear.
- if err := syscall.Dup2(int(f.Fd()), int(os.Stderr.Fd())); err != nil {
+ if err := syscall.Dup3(int(f.Fd()), int(os.Stderr.Fd()), 0); err != nil {
cmd.Fatalf("error dup'ing fd %d to stderr: %v", f.Fd(), err)
}
e = newEmitter(*debugLogFormat, f)
} else if *debugLog != "" {
- f, err := specutils.DebugLogFile(*debugLog, subcommand)
+ f, err := specutils.DebugLogFile(*debugLog, subcommand, "" /* name */)
if err != nil {
cmd.Fatalf("error opening debug log file in %q: %v", *debugLog, err)
}
@@ -265,6 +279,13 @@ func main() {
log.Infof("\t\tStrace: %t, max size: %d, syscalls: %s", conf.Strace, conf.StraceLogSize, conf.StraceSyscalls)
log.Infof("***************************")
+ if *testOnlyAllowRunAsCurrentUserWithoutChroot {
+ // SIGTERM is sent to all processes if a test exceeds its
+ // timeout and this case is handled by syscall_test_runner.
+ log.Warningf("Block the TERM signal. This is only safe in tests!")
+ signal.Ignore(syscall.SIGTERM)
+ }
+
// Call the subcommand and pass in the configuration.
var ws syscall.WaitStatus
subcmdCode := subcommands.Execute(context.Background(), conf, &ws)
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index df3c0c5ef..ee9327fc8 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -351,7 +351,15 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
nextFD++
}
if conf.DebugLog != "" {
- debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "boot")
+ test := ""
+ if len(conf.TestOnlyTestNameEnv) != 0 {
+ // Fetch test name if one is provided and the test only flag was set.
+ if t, ok := specutils.EnvVar(args.Spec.Process.Env, conf.TestOnlyTestNameEnv); ok {
+ test = t
+ }
+ }
+
+ debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "boot", test)
if err != nil {
return fmt.Errorf("opening debug log file in %q: %v", conf.DebugLog, err)
}
diff --git a/runsc/specutils/BUILD b/runsc/specutils/BUILD
index fbfb8e2f8..fa58313a0 100644
--- a/runsc/specutils/BUILD
+++ b/runsc/specutils/BUILD
@@ -13,6 +13,7 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
+ "//pkg/bits",
"//pkg/log",
"//pkg/sentry/kernel/auth",
"@com_github_cenkalti_backoff//:go_default_library",
diff --git a/runsc/specutils/namespace.go b/runsc/specutils/namespace.go
index d441419cb..c7dd3051c 100644
--- a/runsc/specutils/namespace.go
+++ b/runsc/specutils/namespace.go
@@ -33,19 +33,19 @@ import (
func nsCloneFlag(nst specs.LinuxNamespaceType) uintptr {
switch nst {
case specs.IPCNamespace:
- return syscall.CLONE_NEWIPC
+ return unix.CLONE_NEWIPC
case specs.MountNamespace:
- return syscall.CLONE_NEWNS
+ return unix.CLONE_NEWNS
case specs.NetworkNamespace:
- return syscall.CLONE_NEWNET
+ return unix.CLONE_NEWNET
case specs.PIDNamespace:
- return syscall.CLONE_NEWPID
+ return unix.CLONE_NEWPID
case specs.UTSNamespace:
- return syscall.CLONE_NEWUTS
+ return unix.CLONE_NEWUTS
case specs.UserNamespace:
- return syscall.CLONE_NEWUSER
+ return unix.CLONE_NEWUSER
case specs.CgroupNamespace:
- panic("cgroup namespace has no associated clone flag")
+ return unix.CLONE_NEWCGROUP
default:
panic(fmt.Sprintf("unknown namespace %v", nst))
}
diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go
index 2eec92349..3d9ced1b6 100644
--- a/runsc/specutils/specutils.go
+++ b/runsc/specutils/specutils.go
@@ -23,6 +23,7 @@ import (
"os"
"path"
"path/filepath"
+ "strconv"
"strings"
"syscall"
"time"
@@ -30,6 +31,7 @@ import (
"github.com/cenkalti/backoff"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
@@ -90,7 +92,7 @@ func ValidateSpec(spec *specs.Spec) error {
log.Warningf("AppArmor profile %q is being ignored", spec.Process.ApparmorProfile)
}
- // TODO(b/72226747): Apply seccomp to application inside sandbox.
+ // TODO(gvisor.dev/issue/510): Apply seccomp to application inside sandbox.
if spec.Linux != nil && spec.Linux.Seccomp != nil {
log.Warningf("Seccomp spec is being ignored")
}
@@ -240,6 +242,15 @@ func AllCapabilities() *specs.LinuxCapabilities {
}
}
+// AllCapabilitiesUint64 returns a bitmask containing all capabilities set.
+func AllCapabilitiesUint64() uint64 {
+ var rv uint64
+ for _, cap := range capFromName {
+ rv |= bits.MaskOf64(int(cap))
+ }
+ return rv
+}
+
var capFromName = map[string]linux.Capability{
"CAP_CHOWN": linux.CAP_CHOWN,
"CAP_DAC_OVERRIDE": linux.CAP_DAC_OVERRIDE,
@@ -398,13 +409,15 @@ func WaitForReady(pid int, timeout time.Duration, ready func() (bool, error)) er
// - %TIMESTAMP%: is replaced with a timestamp using the following format:
// <yyyymmdd-hhmmss.uuuuuu>
// - %COMMAND%: is replaced with 'command'
-func DebugLogFile(logPattern, command string) (*os.File, error) {
+// - %TEST%: is replaced with 'test' (omitted by default)
+func DebugLogFile(logPattern, command, test string) (*os.File, error) {
if strings.HasSuffix(logPattern, "/") {
// Default format: <debug-log>/runsc.log.<yyyymmdd-hhmmss.uuuuuu>.<command>
logPattern += "runsc.log.%TIMESTAMP%.%COMMAND%"
}
logPattern = strings.Replace(logPattern, "%TIMESTAMP%", time.Now().Format("20060102-150405.000000"), -1)
logPattern = strings.Replace(logPattern, "%COMMAND%", command, -1)
+ logPattern = strings.Replace(logPattern, "%TEST%", test, -1)
dir := filepath.Dir(logPattern)
if err := os.MkdirAll(dir, 0775); err != nil {
@@ -503,3 +516,53 @@ func RetryEintr(f func() (uintptr, uintptr, error)) (uintptr, uintptr, error) {
}
}
}
+
+// GetOOMScoreAdj reads the given process's oom_score_adj.
+func GetOOMScoreAdj(pid int) (int, error) {
+ data, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/oom_score_adj", pid))
+ if err != nil {
+ return 0, err
+ }
+ return strconv.Atoi(strings.TrimSpace(string(data)))
+}
+
+// GetParentPid gets the parent process ID of the specified PID.
+func GetParentPid(pid int) (int, error) {
+ data, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/stat", pid))
+ if err != nil {
+ return 0, err
+ }
+
+ var cpid string
+ var name string
+ var state string
+ var ppid int
+ // Parse the leading fields: pid, name, state and ppid.
+ _, err = fmt.Sscanf(string(data),
+ "%v %v %v %d",
+ // cpid is ignored.
+ &cpid,
+ // name is ignored.
+ &name,
+ // state is ignored.
+ &state,
+ &ppid)
+
+ if err != nil {
+ return 0, err
+ }
+
+ return ppid, nil
+}
+
+// EnvVar looks for a variable value in the env slice assuming the following
+// format: "NAME=VALUE".
+func EnvVar(env []string, name string) (string, bool) {
+ prefix := name + "="
+ for _, e := range env {
+ if strings.HasPrefix(e, prefix) {
+ return strings.TrimPrefix(e, prefix), true
+ }
+ }
+ return "", false
+}
diff --git a/runsc/test/BUILD b/runsc/test/BUILD
deleted file mode 100644
index e69de29bb..000000000
--- a/runsc/test/BUILD
+++ /dev/null
diff --git a/runsc/test/README.md b/runsc/test/README.md
deleted file mode 100644
index f22a8e017..000000000
--- a/runsc/test/README.md
+++ /dev/null
@@ -1,24 +0,0 @@
-# Tests
-
-The tests defined under this path are verifying functionality beyond what unit
-tests can cover, e.g. integration and end to end tests. Due to their nature,
-they may need extra setup in the test machine and extra configuration to run.
-
-- **integration:** defines integration tests that uses `docker run` to test
- functionality.
-- **image:** basic end to end test for popular images.
-- **root:** tests that require to be run as root.
-- **testutil:** utilities library to support the tests.
-
-The following setup steps are required in order to run these tests:
-
- `./runsc/test/install.sh [--runtime <name>]`
-
-The tests expect the runtime name to be provided in the `RUNSC_RUNTIME`
-environment variable (default: `runsc-test`). To run the tests execute:
-
-```
-bazel test --test_env=RUNSC_RUNTIME=runsc-test \
- //runsc/test/image:image_test \
- //runsc/test/integration:integration_test
-```
diff --git a/runsc/test/build_defs.bzl b/runsc/test/build_defs.bzl
deleted file mode 100644
index ac28cc037..000000000
--- a/runsc/test/build_defs.bzl
+++ /dev/null
@@ -1,19 +0,0 @@
-"""Defines a rule for runsc test targets."""
-
-load("@io_bazel_rules_go//go:def.bzl", _go_test = "go_test")
-
-# runtime_test is a macro that will create targets to run the given test target
-# with different runtime options.
-def runtime_test(**kwargs):
- """Runs the given test target with different runtime options."""
- name = kwargs["name"]
- _go_test(**kwargs)
- kwargs["name"] = name + "_hostnet"
- kwargs["args"] = ["--runtime-type=hostnet"]
- _go_test(**kwargs)
- kwargs["name"] = name + "_kvm"
- kwargs["args"] = ["--runtime-type=kvm"]
- _go_test(**kwargs)
- kwargs["name"] = name + "_overlay"
- kwargs["args"] = ["--runtime-type=overlay"]
- _go_test(**kwargs)
diff --git a/runsc/test/install.sh b/runsc/test/install.sh
deleted file mode 100755
index 8f05dea20..000000000
--- a/runsc/test/install.sh
+++ /dev/null
@@ -1,93 +0,0 @@
-#!/bin/bash
-
-# Copyright 2018 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Fail on any error
-set -e
-
-# Defaults
-declare runtime=runsc-test
-declare uninstall=0
-
-function findExe() {
- local exe=${1}
-
- local path=$(find bazel-bin/runsc -type f -executable -name "${exe}" | head -n1)
- if [[ "${path}" == "" ]]; then
- echo "Location of ${exe} not found in bazel-bin" >&2
- exit 1
- fi
- echo "${path}"
-}
-
-while [[ $# -gt 0 ]]; do
- case "$1" in
- --runtime)
- shift
- [ "$#" -le 0 ] && echo "No runtime provided" && exit 1
- runtime=$1
- ;;
- -u)
- uninstall=1
- ;;
- *)
- echo "Unknown option: ${1}"
- echo ""
- echo "Usage: ${0} [--runtime <name>] [-u]"
- echo " --runtime sets the runtime name, default: runsc-test"
- echo " -u uninstall the runtime"
- exit 1
- esac
- shift
-done
-
-# Find location of executables.
-declare -r dockercfg=$(findExe dockercfg)
-[[ "${dockercfg}" == "" ]] && exit 1
-
-declare runsc=$(findExe runsc)
-[[ "${runsc}" == "" ]] && exit 1
-
-if [[ ${uninstall} == 0 ]]; then
- rm -rf /tmp/${runtime}
- mkdir -p /tmp/${runtime}
- cp "${runsc}" /tmp/${runtime}/runsc
- runsc=/tmp/${runtime}/runsc
-
- # Make tmp dir and runsc binary readable and executable to all users, since it
- # will run in an empty user namespace.
- chmod a+rx "${runsc}" $(dirname "${runsc}")
-
- # Make log dir executable and writable to all users for the same reason.
- declare logdir=/tmp/"${runtime?}/logs"
- mkdir -p "${logdir}"
- sudo -n chmod a+wx "${logdir}"
-
- declare -r args="--debug-log '${logdir}/' --debug --strace --log-packets"
- # experimental is needed to checkpoint/restore.
- sudo -n "${dockercfg}" --experimental=true runtime-add "${runtime}" "${runsc}" ${args}
- sudo -n "${dockercfg}" runtime-add "${runtime}"-kvm "${runsc}" --platform=kvm ${args}
- sudo -n "${dockercfg}" runtime-add "${runtime}"-hostnet "${runsc}" --network=host ${args}
- sudo -n "${dockercfg}" runtime-add "${runtime}"-overlay "${runsc}" --overlay ${args}
-
-else
- sudo -n "${dockercfg}" runtime-rm "${runtime}"
- sudo -n "${dockercfg}" runtime-rm "${runtime}"-kvm
- sudo -n "${dockercfg}" runtime-rm "${runtime}"-hostnet
- sudo -n "${dockercfg}" runtime-rm "${runtime}"-overlay
-fi
-
-echo "Restarting docker service..."
-sudo -n /etc/init.d/docker restart
diff --git a/runsc/test/integration/exec_test.go b/runsc/test/integration/exec_test.go
deleted file mode 100644
index 993136f96..000000000
--- a/runsc/test/integration/exec_test.go
+++ /dev/null
@@ -1,161 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package image provides end-to-end integration tests for runsc. These tests require
-// docker and runsc to be installed on the machine. To set it up, run:
-//
-// ./runsc/test/install.sh [--runtime <name>]
-//
-// The tests expect the runtime name to be provided in the RUNSC_RUNTIME
-// environment variable (default: runsc-test).
-//
-// Each test calls docker commands to start up a container, and tests that it is
-// behaving properly, with various runsc commands. The container is killed and deleted
-// at the end.
-
-package integration
-
-import (
- "fmt"
- "strconv"
- "strings"
- "syscall"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/runsc/test/testutil"
-)
-
-func TestExecCapabilities(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := testutil.MakeDocker("exec-test")
-
- // Start the container.
- if err := d.Run("alpine", "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil {
- t.Fatalf("docker run failed: %v", err)
- }
- defer d.CleanUp()
-
- matches, err := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second)
- if err != nil {
- t.Fatalf("WaitForOutputSubmatch() timeout: %v", err)
- }
- if len(matches) != 2 {
- t.Fatalf("There should be a match for the whole line and the capability bitmask")
- }
- capString := matches[1]
- t.Log("Root capabilities:", capString)
-
- // CAP_NET_RAW was in the capability set for the container, but was
- // removed. However, `exec` does not remove it. Verify that it's not
- // set in the container, then re-add it for comparison.
- caps, err := strconv.ParseUint(capString, 16, 64)
- if err != nil {
- t.Fatalf("failed to convert capabilities %q: %v", capString, err)
- }
- if caps&(1<<uint64(linux.CAP_NET_RAW)) != 0 {
- t.Fatalf("CAP_NET_RAW should be filtered, but is set in the container: %x", caps)
- }
- caps |= 1 << uint64(linux.CAP_NET_RAW)
- want := fmt.Sprintf("CapEff:\t%016x\n", caps)
-
- // Now check that exec'd process capabilities match the root.
- got, err := d.Exec("grep", "CapEff:", "/proc/self/status")
- if err != nil {
- t.Fatalf("docker exec failed: %v", err)
- }
- if got != want {
- t.Errorf("wrong capabilities, got: %q, want: %q", got, want)
- }
-}
-
-func TestExecJobControl(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := testutil.MakeDocker("exec-job-control-test")
-
- // Start the container.
- if err := d.Run("alpine", "sleep", "1000"); err != nil {
- t.Fatalf("docker run failed: %v", err)
- }
- defer d.CleanUp()
-
- // Exec 'sh' with an attached pty.
- cmd, ptmx, err := d.ExecWithTerminal("sh")
- if err != nil {
- t.Fatalf("docker exec failed: %v", err)
- }
- defer ptmx.Close()
-
- // Call "sleep 100 | cat" in the shell. We pipe to cat so that there
- // will be two processes in the foreground process group.
- if _, err := ptmx.Write([]byte("sleep 100 | cat\n")); err != nil {
- t.Fatalf("error writing to pty: %v", err)
- }
-
- // Give shell a few seconds to start executing the sleep.
- time.Sleep(2 * time.Second)
-
- // Send a ^C to the pty, which should kill sleep and cat, but not the
- // shell. \x03 is ASCII "end of text", which is the same as ^C.
- if _, err := ptmx.Write([]byte{'\x03'}); err != nil {
- t.Fatalf("error writing to pty: %v", err)
- }
-
- // The shell should still be alive at this point. Sleep should have
- // exited with code 2+128=130. We'll exit with 10 plus that number, so
- // that we can be sure that the shell did not get signalled.
- if _, err := ptmx.Write([]byte("exit $(expr $? + 10)\n")); err != nil {
- t.Fatalf("error writing to pty: %v", err)
- }
-
- // Exec process should exit with code 10+130=140.
- ps, err := cmd.Process.Wait()
- if err != nil {
- t.Fatalf("error waiting for exec process: %v", err)
- }
- ws := ps.Sys().(syscall.WaitStatus)
- if !ws.Exited() {
- t.Errorf("ws.Exited got false, want true")
- }
- if got, want := ws.ExitStatus(), 140; got != want {
- t.Errorf("ws.ExitedStatus got %d, want %d", got, want)
- }
-}
-
-// Test that failure to exec returns proper error message.
-func TestExecError(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := testutil.MakeDocker("exec-error-test")
-
- // Start the container.
- if err := d.Run("alpine", "sleep", "1000"); err != nil {
- t.Fatalf("docker run failed: %v", err)
- }
- defer d.CleanUp()
-
- _, err := d.Exec("no_can_find")
- if err == nil {
- t.Fatalf("docker exec didn't fail")
- }
- if want := `error finding executable "no_can_find" in PATH`; !strings.Contains(err.Error(), want) {
- t.Fatalf("docker exec wrong error, got: %s, want: .*%s.*", err.Error(), want)
- }
-}
diff --git a/runsc/test/testutil/BUILD b/runsc/testutil/BUILD
index 327e7ca4d..d44ebc906 100644
--- a/runsc/test/testutil/BUILD
+++ b/runsc/testutil/BUILD
@@ -4,19 +4,14 @@ package(licenses = ["notice"])
go_library(
name = "testutil",
- srcs = [
- "crictl.go",
- "docker.go",
- "testutil.go",
- "testutil_race.go",
- ],
- importpath = "gvisor.dev/gvisor/runsc/test/testutil",
+ testonly = 1,
+ srcs = ["testutil.go"],
+ importpath = "gvisor.dev/gvisor/runsc/testutil",
visibility = ["//:sandbox"],
deps = [
"//runsc/boot",
"//runsc/specutils",
"@com_github_cenkalti_backoff//:go_default_library",
- "@com_github_kr_pty//:go_default_library",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
],
)
diff --git a/runsc/test/testutil/testutil.go b/runsc/testutil/testutil.go
index e288c7758..edf8b126c 100644
--- a/runsc/test/testutil/testutil.go
+++ b/runsc/testutil/testutil.go
@@ -18,18 +18,22 @@ package testutil
import (
"bufio"
"context"
+ "debug/elf"
"encoding/base32"
"encoding/json"
+ "flag"
"fmt"
"io"
"io/ioutil"
"log"
+ "math"
"math/rand"
"net/http"
"os"
"os/exec"
"os/signal"
"path/filepath"
+ "strconv"
"strings"
"sync"
"sync/atomic"
@@ -42,12 +46,18 @@ import (
"gvisor.dev/gvisor/runsc/specutils"
)
+var (
+ checkpoint = flag.Bool("checkpoint", true, "control checkpoint/restore support")
+)
+
func init() {
rand.Seed(time.Now().UnixNano())
}
-// RaceEnabled is set to true if it was built with '--race' option.
-var RaceEnabled = false
+// IsCheckpointSupported returns the relevant command line flag.
+func IsCheckpointSupported() bool {
+ return *checkpoint
+}
// TmpDir returns the absolute path to a writable directory that can be used as
// scratch by the test.
@@ -127,13 +137,15 @@ func FindFile(path string) (string, error) {
// 'RootDir' must be set by caller if required.
func TestConfig() *boot.Config {
return &boot.Config{
- Debug: true,
- LogFormat: "text",
- LogPackets: true,
- Network: boot.NetworkNone,
- Strace: true,
- Platform: "ptrace",
- FileAccess: boot.FileAccessExclusive,
+ Debug: true,
+ LogFormat: "text",
+ DebugLogFormat: "text",
+ AlsoLogToStderr: true,
+ LogPackets: true,
+ Network: boot.NetworkNone,
+ Strace: true,
+ Platform: "ptrace",
+ FileAccess: boot.FileAccessExclusive,
TestOnlyAllowRunAsCurrentUserWithoutChroot: true,
NumNetworkChannels: 1,
}
@@ -189,14 +201,11 @@ func SetupRootDir() (string, error) {
// SetupContainer creates a bundle and root dir for the container, generates a
// test config, and writes the spec to config.json in the bundle dir.
func SetupContainer(spec *specs.Spec, conf *boot.Config) (rootDir, bundleDir string, err error) {
- // Setup root dir if one hasn't been provided.
- if len(conf.RootDir) == 0 {
- rootDir, err = SetupRootDir()
- if err != nil {
- return "", "", err
- }
- conf.RootDir = rootDir
+ rootDir, err = SetupRootDir()
+ if err != nil {
+ return "", "", err
}
+ conf.RootDir = rootDir
bundleDir, err = SetupBundleDir(spec)
return rootDir, bundleDir, err
}
@@ -417,3 +426,58 @@ func WriteTmpFile(pattern, text string) (string, error) {
func RandomName(prefix string) string {
return fmt.Sprintf("%s-%06d", prefix, rand.Int31n(1000000))
}
+
+// IsStatic returns true iff the given file is a static binary.
+func IsStatic(filename string) (bool, error) {
+ f, err := elf.Open(filename)
+ if err != nil {
+ return false, err
+ }
+ for _, prog := range f.Progs {
+ if prog.Type == elf.PT_INTERP {
+ return false, nil // Has interpreter.
+ }
+ }
+ return true, nil
+}
+
+// TestBoundsForShard calculates the beginning and end indices for the test
+// based on the TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars. The
+// returned ints are the beginning (inclusive) and end (exclusive) of the
+// subslice corresponding to the shard. If either of the env vars are not
+// present, then the function will return bounds that include all tests. If
+// there are more shards than there are tests, then the returned list may be
+// empty.
+func TestBoundsForShard(numTests int) (int, int, error) {
+ var (
+ begin = 0
+ end = numTests
+ )
+ indexStr, totalStr := os.Getenv("TEST_SHARD_INDEX"), os.Getenv("TEST_TOTAL_SHARDS")
+ if indexStr == "" || totalStr == "" {
+ return begin, end, nil
+ }
+
+ // Parse index and total to ints.
+ shardIndex, err := strconv.Atoi(indexStr)
+ if err != nil {
+ return 0, 0, fmt.Errorf("invalid TEST_SHARD_INDEX %q: %v", indexStr, err)
+ }
+ shardTotal, err := strconv.Atoi(totalStr)
+ if err != nil {
+ return 0, 0, fmt.Errorf("invalid TEST_TOTAL_SHARDS %q: %v", totalStr, err)
+ }
+
+ // Calculate!
+ shardSize := int(math.Ceil(float64(numTests) / float64(shardTotal)))
+ begin = shardIndex * shardSize
+ end = ((shardIndex + 1) * shardSize)
+ if begin > numTests {
+ // Nothing to run.
+ return 0, 0, nil
+ }
+ if end > numTests {
+ end = numTests
+ }
+ return begin, end, nil
+}
diff --git a/runsc/tools/dockercfg/BUILD b/runsc/tools/dockercfg/BUILD
deleted file mode 100644
index 5cff917ed..000000000
--- a/runsc/tools/dockercfg/BUILD
+++ /dev/null
@@ -1,10 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "dockercfg",
- srcs = ["dockercfg.go"],
- visibility = ["//visibility:public"],
- deps = ["@com_github_google_subcommands//:go_default_library"],
-)
diff --git a/runsc/tools/dockercfg/dockercfg.go b/runsc/tools/dockercfg/dockercfg.go
deleted file mode 100644
index eb9dbd421..000000000
--- a/runsc/tools/dockercfg/dockercfg.go
+++ /dev/null
@@ -1,193 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Helper tool to configure Docker daemon.
-package main
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "io/ioutil"
- "log"
- "os"
-
- "flag"
- "github.com/google/subcommands"
-)
-
-var (
- configFile = flag.String("config_file", "/etc/docker/daemon.json", "path to Docker daemon config file")
- experimental = flag.Bool("experimental", false, "enable experimental features")
-)
-
-func main() {
- subcommands.Register(subcommands.HelpCommand(), "")
- subcommands.Register(subcommands.FlagsCommand(), "")
- subcommands.Register(&runtimeAdd{}, "")
- subcommands.Register(&runtimeRemove{}, "")
-
- // All subcommands must be registered before flag parsing.
- flag.Parse()
-
- exitCode := subcommands.Execute(context.Background())
- os.Exit(int(exitCode))
-}
-
-type runtime struct {
- Path string `json:"path,omitempty"`
- RuntimeArgs []string `json:"runtimeArgs,omitempty"`
-}
-
-// runtimeAdd implements subcommands.Command.
-type runtimeAdd struct {
-}
-
-// Name implements subcommands.Command.Name.
-func (*runtimeAdd) Name() string {
- return "runtime-add"
-}
-
-// Synopsis implements subcommands.Command.Synopsis.
-func (*runtimeAdd) Synopsis() string {
- return "adds a runtime to docker daemon configuration"
-}
-
-// Usage implements subcommands.Command.Usage.
-func (*runtimeAdd) Usage() string {
- return `runtime-add [flags] <name> <path> [args...] -- if provided, args are passed as arguments to the runtime
-`
-}
-
-// SetFlags implements subcommands.Command.SetFlags.
-func (*runtimeAdd) SetFlags(*flag.FlagSet) {
-}
-
-// Execute implements subcommands.Command.Execute.
-func (r *runtimeAdd) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
- if f.NArg() < 2 {
- f.Usage()
- return subcommands.ExitUsageError
- }
- name := f.Arg(0)
- path := f.Arg(1)
- runtimeArgs := f.Args()[2:]
-
- fmt.Printf("Adding runtime %q to file %q\n", name, *configFile)
- c, err := readConfig(*configFile)
- if err != nil {
- log.Fatalf("Error reading config file %q: %v", *configFile, err)
- }
-
- var rts map[string]interface{}
- if i, ok := c["runtimes"]; ok {
- rts = i.(map[string]interface{})
- } else {
- rts = make(map[string]interface{})
- c["runtimes"] = rts
- }
- if *experimental {
- c["experimental"] = true
- }
- rts[name] = runtime{Path: path, RuntimeArgs: runtimeArgs}
-
- if err := writeConfig(c, *configFile); err != nil {
- log.Fatalf("Error writing config file %q: %v", *configFile, err)
- }
- return subcommands.ExitSuccess
-}
-
-// runtimeRemove implements subcommands.Command.
-type runtimeRemove struct {
-}
-
-// Name implements subcommands.Command.Name.
-func (*runtimeRemove) Name() string {
- return "runtime-rm"
-}
-
-// Synopsis implements subcommands.Command.Synopsis.
-func (*runtimeRemove) Synopsis() string {
- return "removes a runtime from docker daemon configuration"
-}
-
-// Usage implements subcommands.Command.Usage.
-func (*runtimeRemove) Usage() string {
- return `runtime-rm [flags] <name>
-`
-}
-
-// SetFlags implements subcommands.Command.SetFlags.
-func (*runtimeRemove) SetFlags(*flag.FlagSet) {
-}
-
-// Execute implements subcommands.Command.Execute.
-func (r *runtimeRemove) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
- if f.NArg() != 1 {
- f.Usage()
- return subcommands.ExitUsageError
- }
- name := f.Arg(0)
-
- fmt.Printf("Removing runtime %q from file %q\n", name, *configFile)
- c, err := readConfig(*configFile)
- if err != nil {
- log.Fatalf("Error reading config file %q: %v", *configFile, err)
- }
-
- var rts map[string]interface{}
- if i, ok := c["runtimes"]; ok {
- rts = i.(map[string]interface{})
- } else {
- log.Fatalf("runtime %q not found", name)
- }
- if _, ok := rts[name]; !ok {
- log.Fatalf("runtime %q not found", name)
- }
- delete(rts, name)
-
- if err := writeConfig(c, *configFile); err != nil {
- log.Fatalf("Error writing config file %q: %v", *configFile, err)
- }
- return subcommands.ExitSuccess
-}
-
-func readConfig(path string) (map[string]interface{}, error) {
- configBytes, err := ioutil.ReadFile(path)
- if err != nil && !os.IsNotExist(err) {
- return nil, err
- }
- c := make(map[string]interface{})
- if len(configBytes) > 0 {
- if err := json.Unmarshal(configBytes, &c); err != nil {
- return nil, err
- }
- }
- return c, nil
-}
-
-func writeConfig(c map[string]interface{}, path string) error {
- b, err := json.MarshalIndent(c, "", " ")
- if err != nil {
- return err
- }
-
- if err := os.Rename(path, path+"~"); err != nil && !os.IsNotExist(err) {
- return fmt.Errorf("error renaming config file %q: %v", path, err)
- }
- if err := ioutil.WriteFile(path, b, 0644); err != nil {
- return fmt.Errorf("error writing config file %q: %v", path, err)
- }
- return nil
-}
diff --git a/runsc/version.go b/runsc/version.go
index ce0573a9b..ab9194b9d 100644
--- a/runsc/version.go
+++ b/runsc/version.go
@@ -15,4 +15,4 @@
package main
// version is set during linking.
-var version = ""
+var version = "VERSION_MISSING"
diff --git a/runsc/version_test.sh b/runsc/version_test.sh
new file mode 100755
index 000000000..cc0ca3f05
--- /dev/null
+++ b/runsc/version_test.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -euf -x -o pipefail
+
+readonly runsc="${TEST_SRCDIR}/__main__/runsc/linux_amd64_pure_stripped/runsc"
+readonly version=$($runsc --version)
+
+# Version should not match VERSION, which is the default and which will
+# also appear if something is wrong with workspace_status.sh script.
+if [[ $version =~ "VERSION" ]]; then
+ echo "FAIL: Got bad version $version"
+ exit 1
+fi
+
+# Version should contain at least one number.
+if [[ ! $version =~ [0-9] ]]; then
+ echo "FAIL: Got bad version $version"
+ exit 1
+fi
+
+echo "PASS: Got OK version $version"
+exit 0
diff --git a/scripts/build.sh b/scripts/build.sh
new file mode 100755
index 000000000..0b3d1b316
--- /dev/null
+++ b/scripts/build.sh
@@ -0,0 +1,79 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Install required packages for make_repository.sh et al.
+sudo apt-get update && sudo apt-get install -y dpkg-sig coreutils apt-utils
+
+# Build runsc.
+runsc=$(build -c opt //runsc)
+
+# Build packages.
+pkg=$(build -c opt //runsc:runsc-debian)
+
+# Build a repository, if the key is available.
+if [[ -v KOKORO_REPO_KEY ]]; then
+ repo=$(tools/make_repository.sh "${KOKORO_KEYSTORE_DIR}/${KOKORO_REPO_KEY}" gvisor-bot@google.com main ${pkg})
+fi
+
+# Install installs artifacts.
+install() {
+ local -r binaries_dir="$1"
+ local -r repo_dir="$2"
+ mkdir -p "${binaries_dir}"
+ cp -f "${runsc}" "${binaries_dir}"/runsc
+ sha512sum "${binaries_dir}"/runsc | awk '{print $1 " runsc"}' > "${binaries_dir}"/runsc.sha512
+ if [[ -v repo ]]; then
+ rm -rf "${repo_dir}" && mkdir -p "$(dirname "${repo_dir}")"
+ cp -a "${repo}" "${repo_dir}"
+ fi
+}
+
+# Move the runsc binary into "latest" directory, and also a directory with the
+# current date. If the current commit happens to correspond to a tag, then we
+# will also move everything into a directory named after the given tag.
+if [[ -v KOKORO_ARTIFACTS_DIR ]]; then
+ if [[ "${KOKORO_BUILD_NIGHTLY:-false}" == "true" ]]; then
+ # The "latest" directory and current date.
+ stamp="$(date -Idate)"
+ install "${KOKORO_ARTIFACTS_DIR}/nightly/latest" \
+ "${KOKORO_ARTIFACTS_DIR}/dists/nightly/latest"
+ install "${KOKORO_ARTIFACTS_DIR}/nightly/${stamp}" \
+ "${KOKORO_ARTIFACTS_DIR}/dists/nightly/${stamp}"
+ else
+ # Is it a tagged release? Build that instead. In that case, we also try to
+ # update the base release directory, in case this is an update. Finally, we
+ # update the "release" directory, which has the last released version.
+ tags="$(git tag --points-at HEAD)"
+ if ! [[ -z "${tags}" ]]; then
+ # Note that a given commit can match any number of tags. We have to
+ # iterate through all possible tags and produce associated artifacts.
+ for tag in ${tags}; do
+ name=$(echo "${tag}" | cut -d'-' -f2)
+ base=$(echo "${name}" | cut -d'.' -f1)
+ install "${KOKORO_ARTIFACTS_DIR}/release/${name}" \
+ "${KOKORO_ARTIFACTS_DIR}/dists/${name}"
+ if [[ "${base}" != "${tag}" ]]; then
+ install "${KOKORO_ARTIFACTS_DIR}/release/${base}" \
+ "${KOKORO_ARTIFACTS_DIR}/dists/${base}"
+ fi
+ install "${KOKORO_ARTIFACTS_DIR}/release/latest" \
+ "${KOKORO_ARTIFACTS_DIR}/dists/latest"
+ done
+ fi
+ fi
+fi
diff --git a/scripts/common.sh b/scripts/common.sh
new file mode 100755
index 000000000..6dabad141
--- /dev/null
+++ b/scripts/common.sh
@@ -0,0 +1,80 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeou pipefail
+
+if [[ -f $(dirname $0)/common_google.sh ]]; then
+ source $(dirname $0)/common_google.sh
+else
+ source $(dirname $0)/common_bazel.sh
+fi
+
+# Ensure it attempts to collect logs in all cases.
+trap collect_logs EXIT
+
+function set_runtime() {
+ RUNTIME=${1:-runsc}
+ RUNSC_BIN=/tmp/"${RUNTIME}"/runsc
+ RUNSC_LOGS_DIR="$(dirname ${RUNSC_BIN})"/logs
+ RUNSC_LOGS="${RUNSC_LOGS_DIR}"/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%
+}
+
+function test_runsc() {
+ test --test_arg=--runtime=${RUNTIME} "$@"
+}
+
+function install_runsc_for_test() {
+ local -r test_name=$1
+ shift
+ if [[ -z "${test_name}" ]]; then
+ echo "Missing mandatory test name"
+ exit 1
+ fi
+
+ # Add test to the name, so it doesn't conflict with other runtimes.
+ set_runtime $(find_branch_name)_"${test_name}"
+
+ # ${RUNSC_TEST_NAME} is set by tests (see dockerutil) to pass the test name
+ # down to the runtime.
+ install_runsc "${RUNTIME}" \
+ --TESTONLY-test-name-env=RUNSC_TEST_NAME \
+ --debug \
+ --strace \
+ --log-packets \
+ "$@"
+}
+
+# Installs the runsc with given runtime name. set_runtime must have been called
+# to set runtime and logs location.
+function install_runsc() {
+ local -r runtime=$1
+ shift
+
+ # Prepare the runtime binary.
+ local -r output=$(build //runsc)
+ mkdir -p "$(dirname ${RUNSC_BIN})"
+ cp -f "${output}" "${RUNSC_BIN}"
+ chmod 0755 "${RUNSC_BIN}"
+
+ # Install the runtime.
+ sudo "${RUNSC_BIN}" install --experimental=true --runtime="${runtime}" -- --debug-log "${RUNSC_LOGS}" "$@"
+
+ # Clear old logs files that may exist.
+ sudo rm -f "${RUNSC_LOGS_DIR}"/*
+
+ # Restart docker to pick up the new runtime configuration.
+ sudo systemctl restart docker
+}
diff --git a/scripts/common_bazel.sh b/scripts/common_bazel.sh
new file mode 100755
index 000000000..f8ec967b1
--- /dev/null
+++ b/scripts/common_bazel.sh
@@ -0,0 +1,99 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Install the latest version of Bazel and log the version.
+(which use_bazel.sh && use_bazel.sh latest) || which bazel
+bazel version
+
+# Switch into the workspace; only necessary if run with kokoro.
+if [[ -v KOKORO_GIT_COMMIT ]] && [[ -d git/repo ]]; then
+ cd git/repo
+elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then
+ cd github/repo
+fi
+
+# Set the standard bazel flags.
+declare -r BAZEL_FLAGS=(
+ "--show_timestamps"
+ "--test_output=errors"
+ "--keep_going"
+ "--verbose_failures=true"
+)
+if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]] || [[ -v RBE_PROJECT_ID ]]; then
+ declare -r RBE_PROJECT_ID="${RBE_PROJECT_ID:-gvisor-rbe}"
+ declare -r BAZEL_RBE_FLAGS=(
+ "--config=remote"
+ "--project_id=${RBE_PROJECT_ID}"
+ "--remote_instance_name=projects/${RBE_PROJECT_ID}/instances/default_instance"
+ )
+fi
+if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]]; then
+ declare -r BAZEL_RBE_AUTH_FLAGS=(
+ "--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}"
+ )
+fi
+
+# Wrap bazel.
+function build() {
+ bazel build "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" 2>&1 |
+ tee /dev/fd/2 | grep -E '^ bazel-bin/' | awk '{ print $1; }'
+}
+
+function test() {
+ bazel test "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@"
+}
+
+function run() {
+ local binary=$1
+ shift
+ bazel run "${binary}" -- "$@"
+}
+
+function run_as_root() {
+ local binary=$1
+ shift
+ bazel run --run_under="sudo" "${binary}" -- "$@"
+}
+
+function collect_logs() {
+ # Zip out everything into a convenient form.
+ if [[ -v KOKORO_ARTIFACTS_DIR ]] && [[ -e bazel-testlogs ]]; then
+ # Move test logs to Kokoro directory. tar is used to conveniently perform
+ # renames while moving files.
+ find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" |
+ tar --create --files-from - --transform 's/test\./sponge_log./' |
+ tar --extract --directory ${KOKORO_ARTIFACTS_DIR}
+
+ # Collect sentry logs, if any.
+ if [[ -v RUNSC_LOGS_DIR ]] && [[ -d "${RUNSC_LOGS_DIR}" ]]; then
+ # Check if the directory is empty or not (only the first line is needed).
+ local -r logs=$(ls "${RUNSC_LOGS_DIR}" | head -n1)
+ if [[ "${logs}" ]]; then
+ local -r archive=runsc_logs_"${RUNTIME}".tar.gz
+ if [[ -v KOKORO_BUILD_ARTIFACTS_SUBDIR ]]; then
+ echo "runsc logs will be uploaded to:"
+ echo " gsutil cp gs://gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive} /tmp"
+ echo " https://storage.cloud.google.com/gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive}"
+ fi
+ tar --create --gzip --file="${KOKORO_ARTIFACTS_DIR}/${archive}" -C "${RUNSC_LOGS_DIR}" .
+ fi
+ fi
+ fi
+}
+
+function find_branch_name() {
+ git branch --show-current || git rev-parse HEAD || bazel info workspace | xargs basename
+}
diff --git a/scripts/dev.sh b/scripts/dev.sh
new file mode 100755
index 000000000..c67003018
--- /dev/null
+++ b/scripts/dev.sh
@@ -0,0 +1,73 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# common.sh sets '-x', but it's annoying to see so much output.
+set +x
+
+# Defaults
+declare -i REFRESH=0
+declare NAME=$(find_branch_name)
+
+while [[ $# -gt 0 ]]; do
+ case "$1" in
+ --refresh)
+ REFRESH=1
+ ;;
+ --help)
+ echo "Use this script to build and install runsc with Docker."
+ echo
+ echo "usage: $0 [--refresh] [runtime_name]"
+ exit 1
+ ;;
+ *)
+ NAME=$1
+ ;;
+ esac
+ shift
+done
+
+set_runtime "${NAME}"
+echo
+echo "Using runtime=${RUNTIME}"
+echo
+
+echo Building runsc...
+# Build first and fail on error. $() prevents "set -e" from reporting errors.
+build //runsc
+declare OUTPUT="$(build //runsc)"
+
+if [[ ${REFRESH} -eq 0 ]]; then
+ install_runsc "${RUNTIME}" --net-raw
+ install_runsc "${RUNTIME}-d" --net-raw --debug --strace --log-packets
+
+ echo
+ echo "Runtimes ${RUNTIME} and ${RUNTIME}-d (debug enabled) setup."
+ echo "Use --runtime="${RUNTIME}" with your Docker command."
+ echo " docker run --rm --runtime="${RUNTIME}" hello-world"
+ echo
+ echo "If you rebuild, use $0 --refresh."
+
+else
+ mkdir -p "$(dirname ${RUNSC_BIN})"
+ cp -f ${OUTPUT} "${RUNSC_BIN}"
+
+ echo
+ echo "Runtime ${RUNTIME} refreshed."
+fi
+
+echo "Logs are in: ${RUNSC_LOGS_DIR}"
diff --git a/scripts/do_tests.sh b/scripts/do_tests.sh
new file mode 100755
index 000000000..a3a387c37
--- /dev/null
+++ b/scripts/do_tests.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Build runsc.
+build //runsc
+
+# run runsc do without root privileges.
+run //runsc --rootless do true
+run //runsc --rootless --network=none do true
+
+# run runsc do with root privileges.
+run_as_root //runsc do true
diff --git a/scripts/docker_tests.sh b/scripts/docker_tests.sh
new file mode 100755
index 000000000..72ba05260
--- /dev/null
+++ b/scripts/docker_tests.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+install_runsc_for_test docker
+test_runsc //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/go.sh b/scripts/go.sh
new file mode 100755
index 000000000..0dbfb7747
--- /dev/null
+++ b/scripts/go.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Build the go path.
+build :gopath
+
+# Build the synthetic branch.
+tools/go_branch.sh
+
+# Checkout the new branch.
+git checkout go && git clean -f
+
+# Build everything.
+go build ./...
+
+# Push, if required.
+if [[ -v KOKORO_GO_PUSH ]] && [[ "${KOKORO_GO_PUSH}" == "true" ]]; then
+ if [[ -v KOKORO_GITHUB_ACCESS_TOKEN ]]; then
+ git config --global credential.helper cache
+ git credential approve <<EOF
+protocol=https
+host=github.com
+username=$(cat "${KOKORO_KEYSTORE_DIR}/${KOKORO_GITHUB_ACCESS_TOKEN}")
+password=x-oauth-basic
+EOF
+ fi
+ git push origin go:go
+fi
diff --git a/scripts/hostnet_tests.sh b/scripts/hostnet_tests.sh
new file mode 100755
index 000000000..41298293d
--- /dev/null
+++ b/scripts/hostnet_tests.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Install the runtime and perform basic tests.
+install_runsc_for_test hostnet --network=host
+test_runsc --test_arg=-checkpoint=false //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/kvm_tests.sh b/scripts/kvm_tests.sh
new file mode 100755
index 000000000..5662401df
--- /dev/null
+++ b/scripts/kvm_tests.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Ensure that KVM is loaded, and we can use it.
+(lsmod | grep -E '^(kvm_intel|kvm_amd)') || sudo modprobe kvm
+sudo chmod a+rw /dev/kvm
+
+# Run all KVM platform tests (locally).
+run_as_root //pkg/sentry/platform/kvm:kvm_test
+
+# Install the KVM runtime and run all integration tests.
+install_runsc_for_test kvm --platform=kvm
+test_runsc //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/make_tests.sh b/scripts/make_tests.sh
new file mode 100755
index 000000000..79426756d
--- /dev/null
+++ b/scripts/make_tests.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+top_level=$(git rev-parse --show-toplevel 2>/dev/null)
+[[ $? -eq 0 ]] && cd "${top_level}" || exit 1
+
+make
+make runsc
+make BAZEL_OPTIONS="build //..." bazel
+make bazel-shutdown
diff --git a/scripts/overlay_tests.sh b/scripts/overlay_tests.sh
new file mode 100755
index 000000000..2a1f12c0b
--- /dev/null
+++ b/scripts/overlay_tests.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Install the runtime and perform basic tests.
+install_runsc_for_test overlay --overlay
+test_runsc //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/release.sh b/scripts/release.sh
new file mode 100755
index 000000000..b936bcc77
--- /dev/null
+++ b/scripts/release.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Tag a release only if provided.
+if ! [[ -v KOKORO_RELEASE_COMMIT ]]; then
+ echo "No KOKORO_RELEASE_COMMIT provided." >&2
+ exit 1
+fi
+if ! [[ -v KOKORO_RELEASE_TAG ]]; then
+ echo "No KOKORO_RELEASE_TAG provided." >&2
+ exit 1
+fi
+
+# Unless an explicit releaser is provided, use the bot e-mail.
+declare -r KOKORO_RELEASE_AUTHOR=${KOKORO_RELEASE_AUTHOR:-gvisor-bot}
+declare -r EMAIL=${EMAIL:-${KOKORO_RELEASE_AUTHOR}@google.com}
+
+# Ensure we have an appropriate configuration for the tag.
+git config --get user.name || git config user.name "gVisor-bot"
+git config --get user.email || git config user.email "${EMAIL}"
+
+# Run the release tool, which pushes to the origin repository.
+tools/tag_release.sh "${KOKORO_RELEASE_COMMIT}" "${KOKORO_RELEASE_TAG}"
diff --git a/scripts/root_tests.sh b/scripts/root_tests.sh
new file mode 100755
index 000000000..4e4fcc76b
--- /dev/null
+++ b/scripts/root_tests.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Load the shared helpers; quoted so that paths with whitespace work.
+source "$(dirname "$0")/common.sh"
+
+# Reinstall the latest containerd shim.
+#
+# The temporary paths are assigned first and only then marked read-only, so
+# that the exit status of mktemp is not masked by the exit status of `declare`.
+declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim"
+latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX)
+shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX)
+declare -r latest shim_path
+
+# The "latest" object holds the version suffix of the most recent shim build.
+wget --no-verbose "${base}/latest" -O "${latest}"
+wget --no-verbose "${base}/gvisor-containerd-shim-$(cat "${latest}")" -O "${shim_path}"
+
+# The version marker is no longer needed once the shim has been fetched.
+rm -f -- "${latest}"
+chmod +x "${shim_path}"
+sudo mv "${shim_path}" /usr/local/bin/gvisor-containerd-shim
+
+# Run the tests that require root.
+# NOTE(review): RUNTIME is not set in this script; presumably it is exported by
+# common.sh or install_runsc_for_test -- confirm.
+install_runsc_for_test root
+run_as_root //test/root:root_test --runtime="${RUNTIME}"
+
diff --git a/scripts/simple_tests.sh b/scripts/simple_tests.sh
new file mode 100755
index 000000000..585216aae
--- /dev/null
+++ b/scripts/simple_tests.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Load the shared helpers; quoted so that paths with whitespace work.
+source "$(dirname "$0")/common.sh"
+
+# Run all simple tests (locally).
+# NOTE(review): `test` is presumably a helper function provided by common.sh
+# that shadows the shell builtin of the same name -- confirm in common.sh.
+test //pkg/... //runsc/... //tools/...
diff --git a/scripts/syscall_tests.sh b/scripts/syscall_tests.sh
new file mode 100755
index 000000000..a131b2d50
--- /dev/null
+++ b/scripts/syscall_tests.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Load the shared helpers; quoted so that paths with whitespace work.
+source "$(dirname "$0")/common.sh"
+
+# Run all ptrace-variants of the system call tests.
+# NOTE(review): `test` is presumably a helper function provided by common.sh
+# that shadows the shell builtin of the same name -- confirm in common.sh.
+test --test_tag_filters=runsc_ptrace //test/syscalls/...
diff --git a/test/README.md b/test/README.md
new file mode 100644
index 000000000..09c36b461
--- /dev/null
+++ b/test/README.md
@@ -0,0 +1,18 @@
+# Tests
+
+The tests defined under this path verify functionality beyond what unit
+tests can cover, e.g. integration and end-to-end tests. Due to their nature,
+they may need extra setup on the test machine and extra configuration to run.
+
+- **syscalls**: system call tests use a local runner, and do not require
+ additional configuration in the machine.
+- **integration:** defines integration tests that use `docker run` to test
+ functionality.
+- **image:** basic end to end test for popular images. These require the same
+ setup as integration tests.
+- **root:** tests that must be run as root.
+- **util:** utilities library to support the tests.
+
+For the above noted cases, the relevant runtime must be installed via `runsc
+install` before running. This is handled automatically by the test scripts in
+the `kokoro` directory.
diff --git a/runsc/test/integration/BUILD b/test/e2e/BUILD
index 12065617c..4fe03a220 100644
--- a/runsc/test/integration/BUILD
+++ b/test/e2e/BUILD
@@ -1,9 +1,8 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
-load("//runsc/test:build_defs.bzl", "runtime_test")
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(licenses = ["notice"])
-runtime_test(
+go_test(
name = "integration_test",
size = "large",
srcs = [
@@ -17,14 +16,18 @@ runtime_test(
"manual",
"local",
],
+ visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
- "//runsc/test/testutil",
+ "//pkg/bits",
+ "//runsc/dockerutil",
+ "//runsc/specutils",
+ "//runsc/testutil",
],
)
go_library(
name = "integration",
srcs = ["integration.go"],
- importpath = "gvisor.dev/gvisor/runsc/test/integration",
+ importpath = "gvisor.dev/gvisor/test/integration",
)
diff --git a/test/e2e/exec_test.go b/test/e2e/exec_test.go
new file mode 100644
index 000000000..c962a3159
--- /dev/null
+++ b/test/e2e/exec_test.go
@@ -0,0 +1,275 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package integration provides end-to-end integration tests for runsc. These
+// tests require docker and runsc to be installed on the machine.
+//
+// Each test calls docker commands to start up a container, and tests that it
+// is behaving properly, with various runsc commands. The container is killed
+// and deleted at the end.
+
+package integration
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+ "syscall"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// TestExecCapabilities verifies that a process started through exec reports
+// exactly the same effective capability set as the container's root process.
+func TestExecCapabilities(t *testing.T) {
+	pullErr := dockerutil.Pull("alpine")
+	if pullErr != nil {
+		t.Fatalf("docker pull failed: %v", pullErr)
+	}
+
+	// Launch a container that prints its own status and then idles.
+	d := dockerutil.MakeDocker("exec-capabilities-test")
+	runErr := d.Run("alpine", "sh", "-c", "cat /proc/self/status; sleep 100")
+	if runErr != nil {
+		t.Fatalf("docker run failed: %v", runErr)
+	}
+	defer d.CleanUp()
+
+	// Pull the capability bitmask of the root process out of the output.
+	groups, waitErr := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second)
+	switch {
+	case waitErr != nil:
+		t.Fatalf("WaitForOutputSubmatch() timeout: %v", waitErr)
+	case len(groups) != 2:
+		t.Fatalf("There should be a match for the whole line and the capability bitmask")
+	}
+	expected := fmt.Sprintf("CapEff:\t%s\n", groups[1])
+	t.Log("Root capabilities:", expected)
+
+	// The exec'd process must report the very same line.
+	actual, execErr := d.Exec("grep", "CapEff:", "/proc/self/status")
+	if execErr != nil {
+		t.Fatalf("docker exec failed: %v", execErr)
+	}
+	t.Logf("CapEff: %v", actual)
+	if actual != expected {
+		t.Errorf("wrong capabilities, got: %q, want: %q", actual, expected)
+	}
+}
+
+// TestExecPrivileged verifies that 'exec --privileged' grants every
+// capability apart from CAP_NET_RAW, which is removed from the container when
+// --net-raw=false.
+func TestExecPrivileged(t *testing.T) {
+	pullErr := dockerutil.Pull("alpine")
+	if pullErr != nil {
+		t.Fatalf("docker pull failed: %v", pullErr)
+	}
+	d := dockerutil.MakeDocker("exec-privileged-test")
+
+	// Launch the container with every capability dropped.
+	runErr := d.Run("--cap-drop=all", "alpine", "sh", "-c", "cat /proc/self/status; sleep 100")
+	if runErr != nil {
+		t.Fatalf("docker run failed: %v", runErr)
+	}
+	defer d.CleanUp()
+
+	// Confirm that all capabilities were dropped from the container.
+	groups, waitErr := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second)
+	if waitErr != nil {
+		t.Fatalf("WaitForOutputSubmatch() timeout: %v", waitErr)
+	}
+	if len(groups) != 2 {
+		t.Fatalf("There should be a match for the whole line and the capability bitmask")
+	}
+	rawMask := groups[1]
+	containerCaps, parseErr := strconv.ParseUint(rawMask, 16, 64)
+	if parseErr != nil {
+		t.Fatalf("failed to convert capabilities %q: %v", rawMask, parseErr)
+	}
+	t.Logf("Container capabilities: %#x", containerCaps)
+	if containerCaps != 0 {
+		t.Fatalf("Container should have no capabilities: %x", containerCaps)
+	}
+
+	// A privileged exec must regain everything except CAP_NET_RAW.
+	execFlags := []string{"--privileged"}
+	actual, execErr := d.ExecWithFlags(execFlags, "grep", "CapEff:", "/proc/self/status")
+	if execErr != nil {
+		t.Fatalf("docker exec failed: %v", execErr)
+	}
+	t.Logf("Exec CapEff: %v", actual)
+	allButNetRaw := specutils.AllCapabilitiesUint64() &^ bits.MaskOf64(int(linux.CAP_NET_RAW))
+	expected := fmt.Sprintf("CapEff:\t%016x\n", allButNetRaw)
+	if actual != expected {
+		t.Errorf("wrong capabilities, got: %q, want: %q", actual, expected)
+	}
+}
+
+// TestExecJobControl checks job control for a shell exec'd with an attached
+// pty: a ^C written to the pty must terminate the foreground pipeline while
+// leaving the shell itself alive, which is observed through the shell's exit
+// code.
+func TestExecJobControl(t *testing.T) {
+	if err := dockerutil.Pull("alpine"); err != nil {
+		t.Fatalf("docker pull failed: %v", err)
+	}
+	d := dockerutil.MakeDocker("exec-job-control-test")
+
+	// Start the container.
+	if err := d.Run("alpine", "sleep", "1000"); err != nil {
+		t.Fatalf("docker run failed: %v", err)
+	}
+	defer d.CleanUp()
+
+	// Exec 'sh' with an attached pty.
+	cmd, ptmx, err := d.ExecWithTerminal("sh")
+	if err != nil {
+		t.Fatalf("docker exec failed: %v", err)
+	}
+	defer ptmx.Close()
+
+	// Call "sleep 100 | cat" in the shell. We pipe to cat so that there
+	// will be two processes in the foreground process group.
+	if _, err := ptmx.Write([]byte("sleep 100 | cat\n")); err != nil {
+		t.Fatalf("error writing to pty: %v", err)
+	}
+
+	// Give shell a few seconds to start executing the sleep.
+	time.Sleep(2 * time.Second)
+
+	// Send a ^C to the pty, which should kill sleep and cat, but not the
+	// shell. \x03 is ASCII "end of text", which is the same as ^C.
+	if _, err := ptmx.Write([]byte{'\x03'}); err != nil {
+		t.Fatalf("error writing to pty: %v", err)
+	}
+
+	// The shell should still be alive at this point. Sleep should have
+	// exited with code 2+128=130. We'll exit with 10 plus that number, so
+	// that we can be sure that the shell did not get signalled.
+	if _, err := ptmx.Write([]byte("exit $(expr $? + 10)\n")); err != nil {
+		t.Fatalf("error writing to pty: %v", err)
+	}
+
+	// Exec process should exit with code 10+130=140.
+	ps, err := cmd.Process.Wait()
+	if err != nil {
+		t.Fatalf("error waiting for exec process: %v", err)
+	}
+	ws := ps.Sys().(syscall.WaitStatus)
+	if !ws.Exited() {
+		t.Errorf("ws.Exited got false, want true")
+	}
+	if got, want := ws.ExitStatus(), 140; got != want {
+		// The message names the method that was actually called.
+		t.Errorf("ws.ExitStatus got %d, want %d", got, want)
+	}
+}
+
+// TestExecError verifies that a failed exec surfaces the expected error
+// message.
+func TestExecError(t *testing.T) {
+	pullErr := dockerutil.Pull("alpine")
+	if pullErr != nil {
+		t.Fatalf("docker pull failed: %v", pullErr)
+	}
+
+	// Keep a container running so that there is something to exec into.
+	d := dockerutil.MakeDocker("exec-error-test")
+	runErr := d.Run("alpine", "sleep", "1000")
+	if runErr != nil {
+		t.Fatalf("docker run failed: %v", runErr)
+	}
+	defer d.CleanUp()
+
+	// Exec a binary that does not exist; this has to fail.
+	_, execErr := d.Exec("no_can_find")
+	if execErr == nil {
+		t.Fatalf("docker exec didn't fail")
+	}
+	const want = `error finding executable "no_can_find" in PATH`
+	if msg := execErr.Error(); !strings.Contains(msg, want) {
+		t.Fatalf("docker exec wrong error, got: %s, want: .*%s.*", msg, want)
+	}
+}
+
+// TestExecEnv verifies that an exec'd process inherits the environment that
+// was passed to run.
+func TestExecEnv(t *testing.T) {
+	pullErr := dockerutil.Pull("alpine")
+	if pullErr != nil {
+		t.Fatalf("docker pull failed: %v", pullErr)
+	}
+
+	// Run the container with FOO=BAR in its environment.
+	d := dockerutil.MakeDocker("exec-env-test")
+	runErr := d.Run("-e", "FOO=BAR", "alpine", "sleep", "1000")
+	if runErr != nil {
+		t.Fatalf("docker run failed: %v", runErr)
+	}
+	defer d.CleanUp()
+
+	// Print $FOO from an exec'd shell and compare it, ignoring surrounding
+	// whitespace.
+	out, execErr := d.Exec("/bin/sh", "-c", "echo $FOO")
+	if execErr != nil {
+		t.Fatalf("docker exec failed: %v", execErr)
+	}
+	trimmed := strings.TrimSpace(out)
+	if want := "BAR"; trimmed != want {
+		t.Errorf("bad output from 'docker exec'. Got %q; Want %q.", trimmed, want)
+	}
+}
+
+// TestRunEnvHasHome verifies that HOME is always present in the environment
+// of a container started with run.
+func TestRunEnvHasHome(t *testing.T) {
+	// Base alpine image does not have any environment variables set.
+	pullErr := dockerutil.Pull("alpine")
+	if pullErr != nil {
+		t.Fatalf("docker pull failed: %v", pullErr)
+	}
+	d := dockerutil.MakeDocker("run-env-test")
+
+	// Print $HOME as the 'bin' user, whose home dir is '/bin'.
+	out, runErr := d.RunFg("--user", "bin", "alpine", "/bin/sh", "-c", "echo $HOME")
+	if runErr != nil {
+		t.Fatalf("docker run failed: %v", runErr)
+	}
+	defer d.CleanUp()
+
+	trimmed := strings.TrimSpace(out)
+	if want := "/bin"; trimmed != want {
+		t.Errorf("bad output from 'docker run'. Got %q; Want %q.", trimmed, want)
+	}
+}
+
+// TestExecEnvHasHome tests that exec always has the HOME environment variable
+// set, even when it was not set in run. It is checked both for the default
+// exec user (expected to report "/root") and for a freshly created non-root
+// user with a custom home directory.
+func TestExecEnvHasHome(t *testing.T) {
+ // Base alpine image does not have any environment variables set.
+ if err := dockerutil.Pull("alpine"); err != nil {
+ t.Fatalf("docker pull failed: %v", err)
+ }
+ d := dockerutil.MakeDocker("exec-env-home-test")
+
+ // We will check that HOME is set for root user, and also for a new
+ // non-root user we will create.
+ newUID := 1234
+ newHome := "/foo/bar"
+
+ // Create a new user with a home directory, and then sleep. The sleep keeps
+ // the container alive for the exec calls below.
+ script := fmt.Sprintf(`
+ mkdir -p -m 777 %s && \
+ adduser foo -D -u %d -h %s && \
+ sleep 1000`, newHome, newUID, newHome)
+ if err := d.Run("alpine", "/bin/sh", "-c", script); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ defer d.CleanUp()
+
+ // Exec "echo $HOME", and expect to see "/root".
+ got, err := d.Exec("/bin/sh", "-c", "echo $HOME")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ if want := "/root"; !strings.Contains(got, want) {
+ t.Errorf("wanted exec output to contain %q, got %q", want, got)
+ }
+
+ // Execute the same as the new user (uid newUID, i.e. 1234) and expect
+ // newHome.
+ got, err = d.ExecAsUser(strconv.Itoa(newUID), "/bin/sh", "-c", "echo $HOME")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ if want := newHome; !strings.Contains(got, want) {
+ t.Errorf("wanted exec output to contain %q, got %q", want, got)
+ }
+}
diff --git a/runsc/test/integration/integration.go b/test/e2e/integration.go
index 4cd5f6c24..4cd5f6c24 100644
--- a/runsc/test/integration/integration.go
+++ b/test/e2e/integration.go
diff --git a/runsc/test/integration/integration_test.go b/test/e2e/integration_test.go
index 7cef4b9dd..7cc0de129 100644
--- a/runsc/test/integration/integration_test.go
+++ b/test/e2e/integration_test.go
@@ -18,10 +18,11 @@
// behaving properly, with various runsc commands. The container is killed and
// deleted at the end.
//
-// Setup instruction in runsc/test/README.md.
+// Setup instruction in test/README.md.
package integration
import (
+ "flag"
"fmt"
"net"
"net/http"
@@ -32,7 +33,8 @@ import (
"testing"
"time"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
// httpRequestSucceeds sends a request to a given url and checks that the status is OK.
@@ -51,10 +53,10 @@ func httpRequestSucceeds(client http.Client, server string, port int) error {
// TestLifeCycle tests a basic Create/Start/Stop docker container life cycle.
func TestLifeCycle(t *testing.T) {
- if err := testutil.Pull("nginx"); err != nil {
+ if err := dockerutil.Pull("nginx"); err != nil {
t.Fatal("docker pull failed:", err)
}
- d := testutil.MakeDocker("lifecycle-test")
+ d := dockerutil.MakeDocker("lifecycle-test")
if err := d.Create("-p", "80", "nginx"); err != nil {
t.Fatal("docker create failed:", err)
}
@@ -87,15 +89,15 @@ func TestLifeCycle(t *testing.T) {
func TestPauseResume(t *testing.T) {
const img = "gcr.io/gvisor-presubmit/python-hello"
- if !testutil.IsPauseResumeSupported() {
- t.Log("Pause/resume is not supported, skipping test.")
+ if !testutil.IsCheckpointSupported() {
+ t.Log("Checkpoint is not supported, skipping test.")
return
}
- if err := testutil.Pull(img); err != nil {
+ if err := dockerutil.Pull(img); err != nil {
t.Fatal("docker pull failed:", err)
}
- d := testutil.MakeDocker("pause-resume-test")
+ d := dockerutil.MakeDocker("pause-resume-test")
if err := d.Run("-p", "8080", img); err != nil {
t.Fatalf("docker run failed: %v", err)
}
@@ -151,14 +153,15 @@ func TestPauseResume(t *testing.T) {
func TestCheckpointRestore(t *testing.T) {
const img = "gcr.io/gvisor-presubmit/python-hello"
- if !testutil.IsPauseResumeSupported() {
+ if !testutil.IsCheckpointSupported() {
t.Log("Pause/resume is not supported, skipping test.")
return
}
- if err := testutil.Pull(img); err != nil {
+
+ if err := dockerutil.Pull(img); err != nil {
t.Fatal("docker pull failed:", err)
}
- d := testutil.MakeDocker("save-restore-test")
+ d := dockerutil.MakeDocker("save-restore-test")
if err := d.Run("-p", "8080", img); err != nil {
t.Fatalf("docker run failed: %v", err)
}
@@ -196,7 +199,7 @@ func TestCheckpointRestore(t *testing.T) {
// Create client and server that talk to each other using the local IP.
func TestConnectToSelf(t *testing.T) {
- d := testutil.MakeDocker("connect-to-self-test")
+ d := dockerutil.MakeDocker("connect-to-self-test")
// Creates server that replies "server" and exists. Sleeps at the end because
// 'docker exec' gets killed if the init process exists before it can finish.
@@ -228,10 +231,10 @@ func TestConnectToSelf(t *testing.T) {
}
func TestMemLimit(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
+ if err := dockerutil.Pull("alpine"); err != nil {
t.Fatal("docker pull failed:", err)
}
- d := testutil.MakeDocker("cgroup-test")
+ d := dockerutil.MakeDocker("cgroup-test")
cmd := "cat /proc/meminfo | grep MemTotal: | awk '{print $2}'"
out, err := d.RunFg("--memory=500MB", "alpine", "sh", "-c", cmd)
if err != nil {
@@ -258,10 +261,10 @@ func TestMemLimit(t *testing.T) {
}
func TestNumCPU(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
+ if err := dockerutil.Pull("alpine"); err != nil {
t.Fatal("docker pull failed:", err)
}
- d := testutil.MakeDocker("cgroup-test")
+ d := dockerutil.MakeDocker("cgroup-test")
cmd := "cat /proc/cpuinfo | grep 'processor.*:' | wc -l"
out, err := d.RunFg("--cpuset-cpus=0", "alpine", "sh", "-c", cmd)
if err != nil {
@@ -280,10 +283,10 @@ func TestNumCPU(t *testing.T) {
// TestJobControl tests that job control characters are handled properly.
func TestJobControl(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
+ if err := dockerutil.Pull("alpine"); err != nil {
t.Fatalf("docker pull failed: %v", err)
}
- d := testutil.MakeDocker("job-control-test")
+ d := dockerutil.MakeDocker("job-control-test")
// Start the container with an attached PTY.
_, ptmx, err := d.RunWithPty("alpine", "sh")
@@ -328,10 +331,10 @@ func TestJobControl(t *testing.T) {
// TestTmpFile checks that files inside '/tmp' are not overridden. In addition,
// it checks that working dir is created if it doesn't exit.
func TestTmpFile(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
+ if err := dockerutil.Pull("alpine"); err != nil {
t.Fatal("docker pull failed:", err)
}
- d := testutil.MakeDocker("tmp-file-test")
+ d := dockerutil.MakeDocker("tmp-file-test")
if err := d.Run("-w=/tmp/foo/bar", "--read-only", "alpine", "touch", "/tmp/foo/bar/file"); err != nil {
t.Fatal("docker run failed:", err)
}
@@ -339,6 +342,7 @@ func TestTmpFile(t *testing.T) {
}
func TestMain(m *testing.M) {
- testutil.EnsureSupportedDockerVersion()
+ dockerutil.EnsureSupportedDockerVersion()
+ flag.Parse()
os.Exit(m.Run())
}
diff --git a/runsc/test/integration/regression_test.go b/test/e2e/regression_test.go
index fb68dda99..2488be383 100644
--- a/runsc/test/integration/regression_test.go
+++ b/test/e2e/regression_test.go
@@ -18,7 +18,7 @@ import (
"strings"
"testing"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/dockerutil"
)
// Test that UDS can be created using overlay when parent directory is in lower
@@ -27,10 +27,10 @@ import (
// Prerequisite: the directory where the socket file is created must not have
// been open for write before bind(2) is called.
func TestBindOverlay(t *testing.T) {
- if err := testutil.Pull("ubuntu:trusty"); err != nil {
+ if err := dockerutil.Pull("ubuntu:trusty"); err != nil {
t.Fatal("docker pull failed:", err)
}
- d := testutil.MakeDocker("bind-overlay-test")
+ d := dockerutil.MakeDocker("bind-overlay-test")
cmd := "nc -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -U /var/run/sock && wait $p"
got, err := d.RunFg("ubuntu:trusty", "bash", "-c", cmd)
diff --git a/runsc/test/image/BUILD b/test/image/BUILD
index 58758fde5..09b0a0ad5 100644
--- a/runsc/test/image/BUILD
+++ b/test/image/BUILD
@@ -1,9 +1,8 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
-load("//runsc/test:build_defs.bzl", "runtime_test")
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(licenses = ["notice"])
-runtime_test(
+go_test(
name = "image_test",
size = "large",
srcs = [
@@ -21,11 +20,15 @@ runtime_test(
"manual",
"local",
],
- deps = ["//runsc/test/testutil"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//runsc/dockerutil",
+ "//runsc/testutil",
+ ],
)
go_library(
name = "image",
srcs = ["image.go"],
- importpath = "gvisor.dev/gvisor/runsc/test/image",
+ importpath = "gvisor.dev/gvisor/test/image",
)
diff --git a/runsc/test/image/image.go b/test/image/image.go
index 297f1ab92..297f1ab92 100644
--- a/runsc/test/image/image.go
+++ b/test/image/image.go
diff --git a/runsc/test/image/image_test.go b/test/image/image_test.go
index ddaa2c13b..d0dcb1861 100644
--- a/runsc/test/image/image_test.go
+++ b/test/image/image_test.go
@@ -14,14 +14,15 @@
// Package image provides end-to-end image tests for runsc.
-// Each test calls docker commands to start up a container, and tests that it is
-// behaving properly, like connecting to a port or looking at the output. The
-// container is killed and deleted at the end.
+// Each test calls docker commands to start up a container, and tests that it
+// is behaving properly, like connecting to a port or looking at the output.
+// The container is killed and deleted at the end.
//
-// Setup instruction in runsc/test/README.md.
+// Setup instruction in test/README.md.
package image
import (
+ "flag"
"fmt"
"io/ioutil"
"log"
@@ -32,11 +33,12 @@ import (
"testing"
"time"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
func TestHelloWorld(t *testing.T) {
- d := testutil.MakeDocker("hello-test")
+ d := dockerutil.MakeDocker("hello-test")
if err := d.Run("hello-world"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
@@ -100,18 +102,18 @@ func testHTTPServer(t *testing.T, port int) {
}
func TestHttpd(t *testing.T) {
- if err := testutil.Pull("httpd"); err != nil {
+ if err := dockerutil.Pull("httpd"); err != nil {
t.Fatalf("docker pull failed: %v", err)
}
- d := testutil.MakeDocker("http-test")
+ d := dockerutil.MakeDocker("http-test")
- dir, err := testutil.PrepareFiles("latin10k.txt")
+ dir, err := dockerutil.PrepareFiles("latin10k.txt")
if err != nil {
t.Fatalf("PrepareFiles() failed: %v", err)
}
// Start the container.
- mountArg := testutil.MountArg(dir, "/usr/local/apache2/htdocs", testutil.ReadOnly)
+ mountArg := dockerutil.MountArg(dir, "/usr/local/apache2/htdocs", dockerutil.ReadOnly)
if err := d.Run("-p", "80", mountArg, "httpd"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
@@ -132,18 +134,18 @@ func TestHttpd(t *testing.T) {
}
func TestNginx(t *testing.T) {
- if err := testutil.Pull("nginx"); err != nil {
+ if err := dockerutil.Pull("nginx"); err != nil {
t.Fatalf("docker pull failed: %v", err)
}
- d := testutil.MakeDocker("net-test")
+ d := dockerutil.MakeDocker("net-test")
- dir, err := testutil.PrepareFiles("latin10k.txt")
+ dir, err := dockerutil.PrepareFiles("latin10k.txt")
if err != nil {
t.Fatalf("PrepareFiles() failed: %v", err)
}
// Start the container.
- mountArg := testutil.MountArg(dir, "/usr/share/nginx/html", testutil.ReadOnly)
+ mountArg := dockerutil.MountArg(dir, "/usr/share/nginx/html", dockerutil.ReadOnly)
if err := d.Run("-p", "80", mountArg, "nginx"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
@@ -164,10 +166,10 @@ func TestNginx(t *testing.T) {
}
func TestMysql(t *testing.T) {
- if err := testutil.Pull("mysql"); err != nil {
+ if err := dockerutil.Pull("mysql"); err != nil {
t.Fatalf("docker pull failed: %v", err)
}
- d := testutil.MakeDocker("mysql-test")
+ d := dockerutil.MakeDocker("mysql-test")
// Start the container.
if err := d.Run("-e", "MYSQL_ROOT_PASSWORD=foobar123", "mysql"); err != nil {
@@ -180,8 +182,8 @@ func TestMysql(t *testing.T) {
t.Fatalf("docker.WaitForOutput() timeout: %v", err)
}
- client := testutil.MakeDocker("mysql-client-test")
- dir, err := testutil.PrepareFiles("mysql.sql")
+ client := dockerutil.MakeDocker("mysql-client-test")
+ dir, err := dockerutil.PrepareFiles("mysql.sql")
if err != nil {
t.Fatalf("PrepareFiles() failed: %v", err)
}
@@ -189,8 +191,8 @@ func TestMysql(t *testing.T) {
// Tell mysql client to connect to the server and execute the file in verbose
// mode to verify the output.
args := []string{
- testutil.LinkArg(&d, "mysql"),
- testutil.MountArg(dir, "/sql", testutil.ReadWrite),
+ dockerutil.LinkArg(&d, "mysql"),
+ dockerutil.MountArg(dir, "/sql", dockerutil.ReadWrite),
"mysql",
"mysql", "-hmysql", "-uroot", "-pfoobar123", "-v", "-e", "source /sql/mysql.sql",
}
@@ -212,10 +214,10 @@ func TestPythonHello(t *testing.T) {
// TODO(b/136503277): Once we have more complete python runtime tests,
// we can drop this one.
const img = "gcr.io/gvisor-presubmit/python-hello"
- if err := testutil.Pull(img); err != nil {
+ if err := dockerutil.Pull(img); err != nil {
t.Fatalf("docker pull failed: %v", err)
}
- d := testutil.MakeDocker("python-hello-test")
+ d := dockerutil.MakeDocker("python-hello-test")
if err := d.Run("-p", "8080", img); err != nil {
t.Fatalf("docker run failed: %v", err)
}
@@ -244,10 +246,10 @@ func TestPythonHello(t *testing.T) {
}
func TestTomcat(t *testing.T) {
- if err := testutil.Pull("tomcat:8.0"); err != nil {
+ if err := dockerutil.Pull("tomcat:8.0"); err != nil {
t.Fatalf("docker pull failed: %v", err)
}
- d := testutil.MakeDocker("tomcat-test")
+ d := dockerutil.MakeDocker("tomcat-test")
if err := d.Run("-p", "8080", "tomcat:8.0"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
@@ -276,12 +278,12 @@ func TestTomcat(t *testing.T) {
}
func TestRuby(t *testing.T) {
- if err := testutil.Pull("ruby"); err != nil {
+ if err := dockerutil.Pull("ruby"); err != nil {
t.Fatalf("docker pull failed: %v", err)
}
- d := testutil.MakeDocker("ruby-test")
+ d := dockerutil.MakeDocker("ruby-test")
- dir, err := testutil.PrepareFiles("ruby.rb", "ruby.sh")
+ dir, err := dockerutil.PrepareFiles("ruby.rb", "ruby.sh")
if err != nil {
t.Fatalf("PrepareFiles() failed: %v", err)
}
@@ -289,7 +291,7 @@ func TestRuby(t *testing.T) {
t.Fatalf("os.Chmod(%q, 0333) failed: %v", dir, err)
}
- if err := d.Run("-p", "8080", testutil.MountArg(dir, "/src", testutil.ReadOnly), "ruby", "/src/ruby.sh"); err != nil {
+ if err := d.Run("-p", "8080", dockerutil.MountArg(dir, "/src", dockerutil.ReadOnly), "ruby", "/src/ruby.sh"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
defer d.CleanUp()
@@ -324,10 +326,10 @@ func TestRuby(t *testing.T) {
}
func TestStdio(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
+ if err := dockerutil.Pull("alpine"); err != nil {
t.Fatalf("docker pull failed: %v", err)
}
- d := testutil.MakeDocker("stdio-test")
+ d := dockerutil.MakeDocker("stdio-test")
wantStdout := "hello stdout"
wantStderr := "bonjour stderr"
@@ -345,6 +347,7 @@ func TestStdio(t *testing.T) {
}
func TestMain(m *testing.M) {
- testutil.EnsureSupportedDockerVersion()
+ dockerutil.EnsureSupportedDockerVersion()
+ flag.Parse()
os.Exit(m.Run())
}
diff --git a/runsc/test/image/latin10k.txt b/test/image/latin10k.txt
index 61341e00b..61341e00b 100644
--- a/runsc/test/image/latin10k.txt
+++ b/test/image/latin10k.txt
diff --git a/runsc/test/image/mysql.sql b/test/image/mysql.sql
index 51554b98d..51554b98d 100644
--- a/runsc/test/image/mysql.sql
+++ b/test/image/mysql.sql
diff --git a/runsc/test/image/ruby.rb b/test/image/ruby.rb
index aced49c6d..aced49c6d 100644
--- a/runsc/test/image/ruby.rb
+++ b/test/image/ruby.rb
diff --git a/runsc/test/image/ruby.sh b/test/image/ruby.sh
index ebe8d5b0e..ebe8d5b0e 100644
--- a/runsc/test/image/ruby.sh
+++ b/test/image/ruby.sh
diff --git a/runsc/test/root/BUILD b/test/root/BUILD
index 500ef7b8e..d5dd9bca2 100644
--- a/runsc/test/root/BUILD
+++ b/test/root/BUILD
@@ -5,7 +5,7 @@ package(licenses = ["notice"])
go_library(
name = "root",
srcs = ["root.go"],
- importpath = "gvisor.dev/gvisor/runsc/test/root",
+ importpath = "gvisor.dev/gvisor/test/root",
)
go_test(
@@ -15,6 +15,11 @@ go_test(
"cgroup_test.go",
"chroot_test.go",
"crictl_test.go",
+ "main_test.go",
+ "oom_score_adj_test.go",
+ ],
+ data = [
+ "//runsc",
],
embed = [":root"],
tags = [
@@ -23,11 +28,17 @@ go_test(
"manual",
"local",
],
+ visibility = ["//:sandbox"],
deps = [
+ "//runsc/boot",
"//runsc/cgroup",
+ "//runsc/container",
+ "//runsc/criutil",
+ "//runsc/dockerutil",
"//runsc/specutils",
- "//runsc/test/root/testdata",
- "//runsc/test/testutil",
+ "//runsc/testutil",
+ "//test/root/testdata",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
"@com_github_syndtr_gocapability//capability:go_default_library",
],
)
diff --git a/runsc/test/root/cgroup_test.go b/test/root/cgroup_test.go
index 5392dc6e0..76f1e4f2a 100644
--- a/runsc/test/root/cgroup_test.go
+++ b/test/root/cgroup_test.go
@@ -26,7 +26,8 @@ import (
"testing"
"gvisor.dev/gvisor/runsc/cgroup"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
func verifyPid(pid int, path string) error {
@@ -56,11 +57,17 @@ func verifyPid(pid int, path string) error {
// TestCgroup sets cgroup options and checks that cgroup was properly configured.
func TestCgroup(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
+ if err := dockerutil.Pull("alpine"); err != nil {
t.Fatal("docker pull failed:", err)
}
- d := testutil.MakeDocker("cgroup-test")
+ d := dockerutil.MakeDocker("cgroup-test")
+ // This is not a comprehensive list of attributes.
+ //
+ // Note that we are specifically missing cpusets, which fail if specified.
+ // In any case, it's unclear if cpusets can be reliably tested here: these
+ // are often run on a single core virtual machine, and there is only a single
+ // CPU available in our current set, and every container's set.
attrs := []struct {
arg string
ctrl string
@@ -87,18 +94,6 @@ func TestCgroup(t *testing.T) {
want: "3000",
},
{
- arg: "--cpuset-cpus=0",
- ctrl: "cpuset",
- file: "cpuset.cpus",
- want: "0",
- },
- {
- arg: "--cpuset-mems=0",
- ctrl: "cpuset",
- file: "cpuset.mems",
- want: "0",
- },
- {
arg: "--kernel-memory=100MB",
ctrl: "memory",
file: "memory.kmem.limit_in_bytes",
@@ -197,10 +192,10 @@ func TestCgroup(t *testing.T) {
}
func TestCgroupParent(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
+ if err := dockerutil.Pull("alpine"); err != nil {
t.Fatal("docker pull failed:", err)
}
- d := testutil.MakeDocker("cgroup-test")
+ d := dockerutil.MakeDocker("cgroup-test")
parent := testutil.RandomName("runsc")
if err := d.Run("--cgroup-parent", parent, "alpine", "sleep", "10000"); err != nil {
diff --git a/runsc/test/root/chroot_test.go b/test/root/chroot_test.go
index d0f236580..be0f63d18 100644
--- a/runsc/test/root/chroot_test.go
+++ b/test/root/chroot_test.go
@@ -12,33 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package root is used for tests that requires sysadmin privileges run. First,
-// follow the setup instruction in runsc/test/README.md. To run these tests:
-//
-// bazel build //runsc/test/root:root_test
-// root_test=$(find -L ./bazel-bin/ -executable -type f -name root_test | grep __main__)
-// sudo RUNSC_RUNTIME=runsc-test ${root_test}
+// Package root is used for tests that require sysadmin privileges to run.
package root
import (
"fmt"
"io/ioutil"
- "os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"testing"
- "github.com/syndtr/gocapability/capability"
- "gvisor.dev/gvisor/runsc/specutils"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/dockerutil"
)
// TestChroot verifies that the sandbox is chroot'd and that mounts are cleaned
// up after the sandbox is destroyed.
func TestChroot(t *testing.T) {
- d := testutil.MakeDocker("chroot-test")
+ d := dockerutil.MakeDocker("chroot-test")
if err := d.Run("alpine", "sleep", "10000"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
@@ -84,7 +76,7 @@ func TestChroot(t *testing.T) {
}
func TestChrootGofer(t *testing.T) {
- d := testutil.MakeDocker("chroot-test")
+ d := dockerutil.MakeDocker("chroot-test")
if err := d.Run("alpine", "sleep", "10000"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
@@ -148,14 +140,3 @@ func TestChrootGofer(t *testing.T) {
}
}
}
-
-func TestMain(m *testing.M) {
- testutil.EnsureSupportedDockerVersion()
-
- if !specutils.HasCapabilities(capability.CAP_SYS_ADMIN, capability.CAP_DAC_OVERRIDE) {
- fmt.Println("Test requires sysadmin privileges to run. Try again with sudo.")
- os.Exit(1)
- }
-
- os.Exit(m.Run())
-}
diff --git a/runsc/test/root/crictl_test.go b/test/root/crictl_test.go
index 515ae2df1..3f90c4c6a 100644
--- a/runsc/test/root/crictl_test.go
+++ b/test/root/crictl_test.go
@@ -29,14 +29,17 @@ import (
"testing"
"time"
+ "gvisor.dev/gvisor/runsc/criutil"
+ "gvisor.dev/gvisor/runsc/dockerutil"
"gvisor.dev/gvisor/runsc/specutils"
- "gvisor.dev/gvisor/runsc/test/root/testdata"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
+ "gvisor.dev/gvisor/test/root/testdata"
)
// Tests for crictl have to be run as root (rather than in a user namespace)
// because crictl creates named network namespaces in /var/run/netns/.
+// TestCrictlSanity refers to b/112433158.
func TestCrictlSanity(t *testing.T) {
// Setup containerd and crictl.
crictl, cleanup, err := setup(t)
@@ -60,6 +63,7 @@ func TestCrictlSanity(t *testing.T) {
}
}
+// TestMountPaths refers to b/117635704.
func TestMountPaths(t *testing.T) {
// Setup containerd and crictl.
crictl, cleanup, err := setup(t)
@@ -83,6 +87,7 @@ func TestMountPaths(t *testing.T) {
}
}
+// TestMountOverSymlinks refers to b/118728671.
func TestMountOverSymlinks(t *testing.T) {
// Setup containerd and crictl.
crictl, cleanup, err := setup(t)
@@ -121,11 +126,64 @@ func TestMountOverSymlinks(t *testing.T) {
}
}
+// TestHomeDir tests that the HOME environment variable is set for
+// multi-containers.
+func TestHomeDir(t *testing.T) {
+ // Setup containerd and crictl.
+ crictl, cleanup, err := setup(t)
+ if err != nil {
+ t.Fatalf("failed to setup crictl: %v", err)
+ }
+ defer cleanup()
+ contSpec := testdata.SimpleSpec("root", "k8s.gcr.io/busybox", []string{"sleep", "1000"})
+ podID, contID, err := crictl.StartPodAndContainer("k8s.gcr.io/busybox", testdata.Sandbox, contSpec)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ t.Run("root container", func(t *testing.T) {
+ out, err := crictl.Exec(contID, "sh", "-c", "echo $HOME")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := strings.TrimSpace(string(out)), "/root"; got != want {
+ t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want)
+ }
+ })
+
+ t.Run("sub-container", func(t *testing.T) {
+ // Create a sub container in the same pod.
+ subContSpec := testdata.SimpleSpec("subcontainer", "k8s.gcr.io/busybox", []string{"sleep", "1000"})
+ subContID, err := crictl.StartContainer(podID, "k8s.gcr.io/busybox", testdata.Sandbox, subContSpec)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ out, err := crictl.Exec(subContID, "sh", "-c", "echo $HOME")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := strings.TrimSpace(string(out)), "/root"; got != want {
+ t.Fatalf("Home directory invalid. Got %q, Want: %q", got, want)
+ }
+
+ if err := crictl.StopContainer(subContID); err != nil {
+ t.Fatal(err)
+ }
+ })
+
+ // Stop everything.
+ if err := crictl.StopPodAndContainer(podID, contID); err != nil {
+ t.Fatal(err)
+ }
+
+}
+
// setup sets up before a test. Specifically it:
// * Creates directories and a socket for containerd to utilize.
// * Runs containerd and waits for it to reach a "ready" state for testing.
// * Returns a cleanup function that should be called at the end of the test.
-func setup(t *testing.T) (*testutil.Crictl, func(), error) {
+func setup(t *testing.T) (*criutil.Crictl, func(), error) {
var cleanups []func()
cleanupFunc := func() {
for i := len(cleanups) - 1; i >= 0; i-- {
@@ -149,12 +207,19 @@ func setup(t *testing.T) (*testutil.Crictl, func(), error) {
cleanups = append(cleanups, func() { os.RemoveAll(containerdState) })
sockAddr := filepath.Join(testutil.TmpDir(), "containerd-test.sock")
- // Start containerd.
- config, err := testutil.WriteTmpFile("containerd-config", testdata.ContainerdConfig(getRunsc()))
+ // We rewrite a configuration. This is based on the current docker
+ // configuration for the runtime under test.
+ runtime, err := dockerutil.RuntimePath()
+ if err != nil {
+ t.Fatalf("error discovering runtime path: %v", err)
+ }
+ config, err := testutil.WriteTmpFile("containerd-config", testdata.ContainerdConfig(runtime))
if err != nil {
t.Fatalf("failed to write containerd config")
}
cleanups = append(cleanups, func() { os.RemoveAll(config) })
+
+ // Start containerd.
containerd := exec.Command(getContainerd(),
"--config", config,
"--log-level", "debug",
@@ -191,11 +256,11 @@ func setup(t *testing.T) (*testutil.Crictl, func(), error) {
})
cleanup.Release()
- return testutil.NewCrictl(20*time.Second, sockAddr), cleanupFunc, nil
+ return criutil.NewCrictl(20*time.Second, sockAddr), cleanupFunc, nil
}
// httpGet GETs the contents of a file served from a pod on port 80.
-func httpGet(crictl *testutil.Crictl, podID, filePath string) error {
+func httpGet(crictl *criutil.Crictl, podID, filePath string) error {
// Get the IP of the httpd server.
ip, err := crictl.PodIP(podID)
if err != nil {
@@ -222,21 +287,9 @@ func httpGet(crictl *testutil.Crictl, podID, filePath string) error {
}
func getContainerd() string {
- // Bazel doesn't pass PATH through, assume the location of containerd
- // unless specified by environment variable.
- c := os.Getenv("CONTAINERD_PATH")
- if c == "" {
+ // Use the local path if it exists, otherwise, use the system one.
+ if _, err := os.Stat("/usr/local/bin/containerd"); err == nil {
return "/usr/local/bin/containerd"
}
- return c
-}
-
-func getRunsc() string {
- // Bazel doesn't pass PATH through, assume the location of runsc unless
- // specified by environment variable.
- c := os.Getenv("RUNSC_EXEC")
- if c == "" {
- return "/tmp/runsc-test/runsc"
- }
- return c
+ return "/usr/bin/containerd"
}
diff --git a/test/root/main_test.go b/test/root/main_test.go
new file mode 100644
index 000000000..d74dec85f
--- /dev/null
+++ b/test/root/main_test.go
@@ -0,0 +1,49 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package root
+
+import (
+ "flag"
+ "fmt"
+ "os"
+ "testing"
+
+ "github.com/syndtr/gocapability/capability"
+ "gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// TestMain is the main function for root tests. This function checks the
+// supported docker version, required capabilities, and configures the executable
+// path for runsc.
+func TestMain(m *testing.M) {
+ flag.Parse()
+
+ if !specutils.HasCapabilities(capability.CAP_SYS_ADMIN, capability.CAP_DAC_OVERRIDE) {
+ fmt.Println("Test requires sysadmin privileges to run. Try again with sudo.")
+ os.Exit(1)
+ }
+
+ dockerutil.EnsureSupportedDockerVersion()
+
+ // Configure exe for tests.
+ path, err := dockerutil.RuntimePath()
+ if err != nil {
+ panic(err.Error())
+ }
+ specutils.ExePath = path
+
+ os.Exit(m.Run())
+}
diff --git a/test/root/oom_score_adj_test.go b/test/root/oom_score_adj_test.go
new file mode 100644
index 000000000..6cd378a1b
--- /dev/null
+++ b/test/root/oom_score_adj_test.go
@@ -0,0 +1,376 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package root
+
+import (
+ "fmt"
+ "os"
+ "testing"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/specutils"
+ "gvisor.dev/gvisor/runsc/testutil"
+)
+
+var (
+ maxOOMScoreAdj = 1000
+ highOOMScoreAdj = 500
+ lowOOMScoreAdj = -500
+ minOOMScoreAdj = -1000
+)
+
+// Tests for oom_score_adj have to be run as root (rather than in a user
+// namespace) because we need to adjust oom_score_adj for PIDs other than our
+// own and test values below 0.
+
+// TestOOMScoreAdjSingle tests that oom_score_adj is set properly in a
+// single container sandbox.
+func TestOOMScoreAdjSingle(t *testing.T) {
+ ppid, err := specutils.GetParentPid(os.Getpid())
+ if err != nil {
+ t.Fatalf("getting parent pid: %v", err)
+ }
+ parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(ppid)
+ if err != nil {
+ t.Fatalf("getting parent oom_score_adj: %v", err)
+ }
+
+ testCases := []struct {
+ Name string
+
+ // OOMScoreAdj is the oom_score_adj set to the OCI spec. If nil then
+ // no value is set.
+ OOMScoreAdj *int
+ }{
+ {
+ Name: "max",
+ OOMScoreAdj: &maxOOMScoreAdj,
+ },
+ {
+ Name: "high",
+ OOMScoreAdj: &highOOMScoreAdj,
+ },
+ {
+ Name: "low",
+ OOMScoreAdj: &lowOOMScoreAdj,
+ },
+ {
+ Name: "min",
+ OOMScoreAdj: &minOOMScoreAdj,
+ },
+ {
+ Name: "nil",
+ OOMScoreAdj: &parentOOMScoreAdj,
+ },
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.Name, func(t *testing.T) {
+ id := testutil.UniqueContainerID()
+ s := testutil.NewSpecWithArgs("sleep", "1000")
+ s.Process.OOMScoreAdj = testCase.OOMScoreAdj
+
+ conf := testutil.TestConfig()
+ containers, cleanup, err := startContainers(conf, []*specs.Spec{s}, []string{id})
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ c := containers[0]
+
+ // Verify the gofer's oom_score_adj
+ if testCase.OOMScoreAdj != nil {
+ goferScore, err := specutils.GetOOMScoreAdj(c.GoferPid)
+ if err != nil {
+ t.Fatalf("error reading gofer oom_score_adj: %v", err)
+ }
+ if goferScore != *testCase.OOMScoreAdj {
+ t.Errorf("gofer oom_score_adj got: %d, want: %d", goferScore, *testCase.OOMScoreAdj)
+ }
+
+ // Verify the sandbox's oom_score_adj.
+ //
+ // The sandbox should be the same for all containers so just use
+ // the first one.
+ sandboxPid := c.Sandbox.Pid
+ sandboxScore, err := specutils.GetOOMScoreAdj(sandboxPid)
+ if err != nil {
+ t.Fatalf("error reading sandbox oom_score_adj: %v", err)
+ }
+ if sandboxScore != *testCase.OOMScoreAdj {
+ t.Errorf("sandbox oom_score_adj got: %d, want: %d", sandboxScore, *testCase.OOMScoreAdj)
+ }
+ }
+ })
+ }
+}
+
+// TestOOMScoreAdjMulti tests that oom_score_adj is set properly in a
+// multi-container sandbox.
+func TestOOMScoreAdjMulti(t *testing.T) {
+ ppid, err := specutils.GetParentPid(os.Getpid())
+ if err != nil {
+ t.Fatalf("getting parent pid: %v", err)
+ }
+ parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(ppid)
+ if err != nil {
+ t.Fatalf("getting parent oom_score_adj: %v", err)
+ }
+
+ testCases := []struct {
+ Name string
+
+ // OOMScoreAdj is the oom_score_adj set to the OCI spec. If nil then
+ // no value is set. One value for each container. The first value is the
+ // root container.
+ OOMScoreAdj []*int
+
+ // Expected is the expected oom_score_adj of the sandbox. If nil, then
+ // this value is ignored.
+ Expected *int
+
+ // Remove is a set of container indexes to remove from the sandbox.
+ Remove []int
+
+ // ExpectedAfterRemove is the expected oom_score_adj of the sandbox
+ // after containers are removed. Ignored if nil.
+ ExpectedAfterRemove *int
+ }{
+ // A single container CRI test case. This should not happen in
+ // practice as there should be at least one container besides the pause
+ // container. However, we include a test case to ensure sane behavior.
+ {
+ Name: "single",
+ OOMScoreAdj: []*int{&highOOMScoreAdj},
+ Expected: &parentOOMScoreAdj,
+ },
+ {
+ Name: "multi_no_value",
+ OOMScoreAdj: []*int{nil, nil, nil},
+ Expected: &parentOOMScoreAdj,
+ },
+ {
+ Name: "multi_non_nil_root",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, nil, nil},
+ Expected: &parentOOMScoreAdj,
+ },
+ {
+ Name: "multi_value",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &highOOMScoreAdj, &lowOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &lowOOMScoreAdj,
+ },
+ {
+ Name: "multi_min_value",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &lowOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &lowOOMScoreAdj,
+ },
+ {
+ Name: "multi_max_value",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &maxOOMScoreAdj, &highOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &highOOMScoreAdj,
+ },
+ {
+ Name: "remove_adjusted",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &maxOOMScoreAdj, &highOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &highOOMScoreAdj,
+ // Remove highOOMScoreAdj container.
+ Remove: []int{2},
+ ExpectedAfterRemove: &maxOOMScoreAdj,
+ },
+ {
+ // This test removes all non-root containers with a specified oomScoreAdj.
+ Name: "remove_to_nil",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, nil, &lowOOMScoreAdj},
+ Expected: &lowOOMScoreAdj,
+ // Remove lowOOMScoreAdj container.
+ Remove: []int{2},
+ // The oom_score_adj expected after remove is that of the parent process.
+ ExpectedAfterRemove: &parentOOMScoreAdj,
+ },
+ {
+ Name: "remove_no_effect",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &maxOOMScoreAdj, &highOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &highOOMScoreAdj,
+ // Remove the maxOOMScoreAdj container.
+ Remove: []int{1},
+ ExpectedAfterRemove: &highOOMScoreAdj,
+ },
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.Name, func(t *testing.T) {
+ var cmds [][]string
+ var oomScoreAdj []*int
+ var toRemove []string
+
+ for _, oomScore := range testCase.OOMScoreAdj {
+ oomScoreAdj = append(oomScoreAdj, oomScore)
+ cmds = append(cmds, []string{"sleep", "100"})
+ }
+
+ specs, ids := createSpecs(cmds...)
+ for i, spec := range specs {
+ // Ensure the correct value is set, including no value.
+ spec.Process.OOMScoreAdj = oomScoreAdj[i]
+
+ for _, j := range testCase.Remove {
+ if i == j {
+ toRemove = append(toRemove, ids[i])
+ }
+ }
+ }
+
+ conf := testutil.TestConfig()
+ containers, cleanup, err := startContainers(conf, specs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ for i, c := range containers {
+ if oomScoreAdj[i] != nil {
+ // Verify the gofer's oom_score_adj
+ score, err := specutils.GetOOMScoreAdj(c.GoferPid)
+ if err != nil {
+ t.Fatalf("error reading gofer oom_score_adj: %v", err)
+ }
+ if score != *oomScoreAdj[i] {
+ t.Errorf("gofer oom_score_adj got: %d, want: %d", score, *oomScoreAdj[i])
+ }
+ }
+ }
+
+ // Verify the sandbox's oom_score_adj.
+ //
+ // The sandbox should be the same for all containers so just use
+ // the first one.
+ sandboxPid := containers[0].Sandbox.Pid
+ if testCase.Expected != nil {
+ score, err := specutils.GetOOMScoreAdj(sandboxPid)
+ if err != nil {
+ t.Fatalf("error reading sandbox oom_score_adj: %v", err)
+ }
+ if score != *testCase.Expected {
+ t.Errorf("sandbox oom_score_adj got: %d, want: %d", score, *testCase.Expected)
+ }
+ }
+
+ if len(toRemove) == 0 {
+ return
+ }
+
+ // Remove containers.
+ for _, removeID := range toRemove {
+ for _, c := range containers {
+ if c.ID == removeID {
+ c.Destroy()
+ }
+ }
+ }
+
+ // Check the new adjusted oom_score_adj.
+ if testCase.ExpectedAfterRemove != nil {
+ scoreAfterRemove, err := specutils.GetOOMScoreAdj(sandboxPid)
+ if err != nil {
+ t.Fatalf("error reading sandbox oom_score_adj: %v", err)
+ }
+ if scoreAfterRemove != *testCase.ExpectedAfterRemove {
+ t.Errorf("sandbox oom_score_adj got: %d, want: %d", scoreAfterRemove, *testCase.ExpectedAfterRemove)
+ }
+ }
+ })
+ }
+}
+
+func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) {
+ var specs []*specs.Spec
+ var ids []string
+ rootID := testutil.UniqueContainerID()
+
+ for i, cmd := range cmds {
+ spec := testutil.NewSpecWithArgs(cmd...)
+ if i == 0 {
+ spec.Annotations = map[string]string{
+ specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeSandbox,
+ }
+ ids = append(ids, rootID)
+ } else {
+ spec.Annotations = map[string]string{
+ specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeContainer,
+ specutils.ContainerdSandboxIDAnnotation: rootID,
+ }
+ ids = append(ids, testutil.UniqueContainerID())
+ }
+ specs = append(specs, spec)
+ }
+ return specs, ids
+}
+
+func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*container.Container, func(), error) {
+ // Setup root dir if one hasn't been provided.
+ if len(conf.RootDir) == 0 {
+ rootDir, err := testutil.SetupRootDir()
+ if err != nil {
+ return nil, nil, fmt.Errorf("error creating root dir: %v", err)
+ }
+ conf.RootDir = rootDir
+ }
+
+ var containers []*container.Container
+ var bundles []string
+ cleanup := func() {
+ for _, c := range containers {
+ c.Destroy()
+ }
+ for _, b := range bundles {
+ os.RemoveAll(b)
+ }
+ os.RemoveAll(conf.RootDir)
+ }
+ for i, spec := range specs {
+ bundleDir, err := testutil.SetupBundleDir(spec)
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("error setting up container: %v", err)
+ }
+ bundles = append(bundles, bundleDir)
+
+ args := container.Args{
+ ID: ids[i],
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := container.New(conf, args)
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("error creating container: %v", err)
+ }
+ containers = append(containers, cont)
+
+ if err := cont.Start(conf); err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("error starting container: %v", err)
+ }
+ }
+ return containers, cleanup, nil
+}
diff --git a/test/root/root.go b/test/root/root.go
new file mode 100644
index 000000000..0f1d29faf
--- /dev/null
+++ b/test/root/root.go
@@ -0,0 +1,21 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package root is used for tests that require sysadmin privileges to run. First,
+// follow the setup instructions in runsc/test/README.md. You should also have
+// docker, containerd, and crictl installed. To run these tests from the
+// project root directory:
+//
+// ./scripts/root_tests.sh
+package root
diff --git a/runsc/test/root/testdata/BUILD b/test/root/testdata/BUILD
index 80dc5f214..125633680 100644
--- a/runsc/test/root/testdata/BUILD
+++ b/test/root/testdata/BUILD
@@ -10,8 +10,9 @@ go_library(
"httpd.go",
"httpd_mount_paths.go",
"sandbox.go",
+ "simple.go",
],
- importpath = "gvisor.dev/gvisor/runsc/test/root/testdata",
+ importpath = "gvisor.dev/gvisor/test/root/testdata",
visibility = [
"//visibility:public",
],
diff --git a/runsc/test/root/testdata/busybox.go b/test/root/testdata/busybox.go
index e4dbd2843..e4dbd2843 100644
--- a/runsc/test/root/testdata/busybox.go
+++ b/test/root/testdata/busybox.go
diff --git a/runsc/test/root/testdata/containerd_config.go b/test/root/testdata/containerd_config.go
index e12f1ec88..e12f1ec88 100644
--- a/runsc/test/root/testdata/containerd_config.go
+++ b/test/root/testdata/containerd_config.go
diff --git a/runsc/test/root/testdata/httpd.go b/test/root/testdata/httpd.go
index 45d5e33d4..45d5e33d4 100644
--- a/runsc/test/root/testdata/httpd.go
+++ b/test/root/testdata/httpd.go
diff --git a/runsc/test/root/testdata/httpd_mount_paths.go b/test/root/testdata/httpd_mount_paths.go
index ac3f4446a..ac3f4446a 100644
--- a/runsc/test/root/testdata/httpd_mount_paths.go
+++ b/test/root/testdata/httpd_mount_paths.go
diff --git a/runsc/test/root/testdata/sandbox.go b/test/root/testdata/sandbox.go
index 0db210370..0db210370 100644
--- a/runsc/test/root/testdata/sandbox.go
+++ b/test/root/testdata/sandbox.go
diff --git a/test/root/testdata/simple.go b/test/root/testdata/simple.go
new file mode 100644
index 000000000..1cca53f0c
--- /dev/null
+++ b/test/root/testdata/simple.go
@@ -0,0 +1,41 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testdata
+
+import (
+ "encoding/json"
+ "fmt"
+)
+
+// SimpleSpec returns a JSON config for a simple container that runs the
+// specified command in the specified image.
+func SimpleSpec(name, image string, cmd []string) string {
+ cmds, err := json.Marshal(cmd)
+ if err != nil {
+ // This shouldn't happen.
+ panic(err)
+ }
+ return fmt.Sprintf(`
+{
+ "metadata": {
+ "name": %q
+ },
+ "image": {
+ "image": %q
+ },
+ "command": %s
+ }
+`, name, image, cmds)
+}
diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD
index e85804a83..2e125525b 100644
--- a/test/runtimes/BUILD
+++ b/test/runtimes/BUILD
@@ -1,25 +1,53 @@
# These packages are used to run language runtime tests inside gVisor sandboxes.
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
-load("//runsc/test:build_defs.bzl", "runtime_test")
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test")
+load("//test/runtimes:build_defs.bzl", "runtime_test")
package(licenses = ["notice"])
-go_library(
- name = "runtimes",
- srcs = ["runtimes.go"],
- importpath = "gvisor.dev/gvisor/test/runtimes",
+go_binary(
+ name = "runner",
+ testonly = 1,
+ srcs = ["runner.go"],
+ deps = [
+ "//runsc/dockerutil",
+ "//runsc/testutil",
+ ],
+)
+
+runtime_test(
+ blacklist_file = "blacklist_go1.12.csv",
+ image = "gcr.io/gvisor-presubmit/go1.12",
+ lang = "go",
+)
+
+runtime_test(
+ blacklist_file = "blacklist_java11.csv",
+ image = "gcr.io/gvisor-presubmit/java11",
+ lang = "java",
+)
+
+runtime_test(
+ blacklist_file = "blacklist_nodejs12.4.0.csv",
+ image = "gcr.io/gvisor-presubmit/nodejs12.4.0",
+ lang = "nodejs",
+)
+
+runtime_test(
+ blacklist_file = "blacklist_php7.3.6.csv",
+ image = "gcr.io/gvisor-presubmit/php7.3.6",
+ lang = "php",
)
runtime_test(
- name = "runtimes_test",
+ blacklist_file = "blacklist_python3.7.3.csv",
+ image = "gcr.io/gvisor-presubmit/python3.7.3",
+ lang = "python",
+)
+
+go_test(
+ name = "blacklist_test",
size = "small",
- srcs = ["runtimes_test.go"],
- embed = [":runtimes"],
- tags = [
- # Requires docker and runsc to be configured before the test runs.
- "manual",
- "local",
- ],
- deps = ["//runsc/test/testutil"],
+ srcs = ["blacklist_test.go"],
+ embed = [":runner"],
)
diff --git a/test/runtimes/README.md b/test/runtimes/README.md
index 34d3507be..e41e78f77 100644
--- a/test/runtimes/README.md
+++ b/test/runtimes/README.md
@@ -16,10 +16,11 @@ The following runtimes are currently supported:
1) [Install and configure Docker](https://docs.docker.com/install/)
-2) Build each Docker container from the runtimes directory:
+2) Build each Docker image from the runtimes/images directory:
```bash
-$ docker build -f $LANG/Dockerfile [-t $NAME] .
+$ cd images
+$ docker build -f Dockerfile_$LANG [-t $NAME] .
```
### Testing:
diff --git a/test/runtimes/blacklist_go1.12.csv b/test/runtimes/blacklist_go1.12.csv
new file mode 100644
index 000000000..8c8ae0c5d
--- /dev/null
+++ b/test/runtimes/blacklist_go1.12.csv
@@ -0,0 +1,16 @@
+test name,bug id,comment
+cgo_errors,,FLAKY
+cgo_test,,FLAKY
+go_test:cmd/go,,FLAKY
+go_test:cmd/vendor/golang.org/x/sys/unix,b/118783622,/dev devices missing
+go_test:net,b/118784196,socket: invalid argument. Works as intended: see bug.
+go_test:os,b/118780122,we have a pollable filesystem but that's a surprise
+go_test:os/signal,b/118780860,/dev/pts not properly supported
+go_test:runtime,b/118782341,sigtrap not reported or caught or something
+go_test:syscall,b/118781998,bad bytes -- bad mem addr
+race,b/118782931,thread sanitizer. Works as intended: b/62219744.
+runtime:cpu124,b/118778254,segmentation fault
+test:0_1,,FLAKY
+testasan,,
+testcarchive,b/118782924,no sigpipe
+testshared,,FLAKY
diff --git a/test/runtimes/blacklist_java11.csv b/test/runtimes/blacklist_java11.csv
new file mode 100644
index 000000000..c012e5a56
--- /dev/null
+++ b/test/runtimes/blacklist_java11.csv
@@ -0,0 +1,126 @@
+test name,bug id,comment
+com/sun/crypto/provider/Cipher/PBE/PKCS12Cipher.java,,Fails in Docker
+com/sun/jdi/NashornPopFrameTest.java,,
+com/sun/jdi/ProcessAttachTest.java,,
+com/sun/management/HotSpotDiagnosticMXBean/CheckOrigin.java,,Fails in Docker
+com/sun/management/OperatingSystemMXBean/GetCommittedVirtualMemorySize.java,,
+com/sun/management/UnixOperatingSystemMXBean/GetMaxFileDescriptorCount.sh,,
+com/sun/tools/attach/AttachSelf.java,,
+com/sun/tools/attach/BasicTests.java,,
+com/sun/tools/attach/PermissionTest.java,,
+com/sun/tools/attach/StartManagementAgent.java,,
+com/sun/tools/attach/TempDirTest.java,,
+com/sun/tools/attach/modules/Driver.java,,
+java/lang/Character/CheckScript.java,,Fails in Docker
+java/lang/Character/CheckUnicode.java,,Fails in Docker
+java/lang/Class/GetPackageBootLoaderChildLayer.java,,
+java/lang/ClassLoader/nativeLibrary/NativeLibraryTest.java,,Fails in Docker
+java/lang/String/nativeEncoding/StringPlatformChars.java,,
+java/net/DatagramSocket/ReuseAddressTest.java,,
+java/net/DatagramSocket/SendDatagramToBadAddress.java,b/78473345,
+java/net/Inet4Address/PingThis.java,,
+java/net/InterfaceAddress/NetworkPrefixLength.java,b/78507103,
+java/net/MulticastSocket/MulticastTTL.java,,
+java/net/MulticastSocket/Promiscuous.java,,
+java/net/MulticastSocket/SetLoopbackMode.java,,
+java/net/MulticastSocket/SetTTLAndGetTTL.java,,
+java/net/MulticastSocket/Test.java,,
+java/net/MulticastSocket/TestDefaults.java,,
+java/net/MulticastSocket/TimeToLive.java,,
+java/net/NetworkInterface/NetworkInterfaceStreamTest.java,,
+java/net/Socket/SetSoLinger.java,b/78527327,SO_LINGER is not yet supported
+java/net/Socket/TrafficClass.java,b/78527818,Not supported on gVisor
+java/net/Socket/UrgentDataTest.java,b/111515323,
+java/net/Socket/setReuseAddress/Basic.java,b/78519214,SO_REUSEADDR enabled by default
+java/net/SocketOption/OptionsTest.java,,Fails in Docker
+java/net/SocketOption/TcpKeepAliveTest.java,,
+java/net/SocketPermission/SocketPermissionTest.java,,
+java/net/URLConnection/6212146/TestDriver.java,,Fails in Docker
+java/net/httpclient/RequestBuilderTest.java,,Fails in Docker
+java/net/httpclient/ShortResponseBody.java,,
+java/net/httpclient/ShortResponseBodyWithRetry.java,,
+java/nio/channels/AsyncCloseAndInterrupt.java,,
+java/nio/channels/AsynchronousServerSocketChannel/Basic.java,,
+java/nio/channels/AsynchronousSocketChannel/Basic.java,b/77921528,SO_KEEPALIVE is not settable
+java/nio/channels/DatagramChannel/BasicMulticastTests.java,,
+java/nio/channels/DatagramChannel/SocketOptionTests.java,,Fails in Docker
+java/nio/channels/DatagramChannel/UseDGWithIPv6.java,,
+java/nio/channels/FileChannel/directio/DirectIOTest.java,,Fails in Docker
+java/nio/channels/Selector/OutOfBand.java,,
+java/nio/channels/Selector/SelectWithConsumer.java,,Flaky
+java/nio/channels/ServerSocketChannel/SocketOptionTests.java,,
+java/nio/channels/SocketChannel/LingerOnClose.java,,
+java/nio/channels/SocketChannel/SocketOptionTests.java,b/77965901,
+java/nio/channels/spi/SelectorProvider/inheritedChannel/InheritedChannelTest.java,,Fails in Docker
+java/rmi/activation/Activatable/extLoadedImpl/ext.sh,,
+java/rmi/transport/checkLeaseInfoLeak/CheckLeaseLeak.java,,
+java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker
+java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker
+java/util/Calendar/JapaneseEraNameTest.java,,
+java/util/Currency/CurrencyTest.java,,Fails in Docker
+java/util/Currency/ValidateISO4217.java,,Fails in Docker
+java/util/Locale/LSRDataTest.java,,
+java/util/concurrent/locks/Lock/TimedAcquireLeak.java,,
+java/util/jar/JarFile/mrjar/MultiReleaseJarAPI.java,,Fails in Docker
+java/util/logging/LogManager/Configuration/updateConfiguration/SimpleUpdateConfigWithInputStreamTest.java,,
+java/util/logging/TestLoggerWeakRefLeak.java,,
+javax/imageio/AppletResourceTest.java,,
+javax/management/security/HashedPasswordFileTest.java,,
+javax/net/ssl/SSLSession/JSSERenegotiate.java,,Fails in Docker
+javax/sound/sampled/AudioInputStream/FrameLengthAfterConversion.java,,
+jdk/jfr/event/runtime/TestNetworkUtilizationEvent.java,,
+jdk/jfr/event/runtime/TestThreadParkEvent.java,,
+jdk/jfr/event/sampling/TestNative.java,,
+jdk/jfr/jcmd/TestJcmdChangeLogLevel.java,,
+jdk/jfr/jcmd/TestJcmdConfigure.java,,
+jdk/jfr/jcmd/TestJcmdDump.java,,
+jdk/jfr/jcmd/TestJcmdDumpGeneratedFilename.java,,
+jdk/jfr/jcmd/TestJcmdDumpLimited.java,,
+jdk/jfr/jcmd/TestJcmdDumpPathToGCRoots.java,,
+jdk/jfr/jcmd/TestJcmdLegacy.java,,
+jdk/jfr/jcmd/TestJcmdSaveToFile.java,,
+jdk/jfr/jcmd/TestJcmdStartDirNotExist.java,,
+jdk/jfr/jcmd/TestJcmdStartInvaldFile.java,,
+jdk/jfr/jcmd/TestJcmdStartPathToGCRoots.java,,
+jdk/jfr/jcmd/TestJcmdStartStopDefault.java,,
+jdk/jfr/jcmd/TestJcmdStartWithOptions.java,,
+jdk/jfr/jcmd/TestJcmdStartWithSettings.java,,
+jdk/jfr/jcmd/TestJcmdStopInvalidFile.java,,
+jdk/jfr/jvm/TestJfrJavaBase.java,,
+jdk/jfr/startupargs/TestStartRecording.java,,
+jdk/modules/incubator/ImageModules.java,,
+jdk/net/Sockets/ExtOptionTest.java,,
+jdk/net/Sockets/QuickAckTest.java,,
+lib/security/cacerts/VerifyCACerts.java,,
+sun/management/jmxremote/bootstrap/CustomLauncherTest.java,,
+sun/management/jmxremote/bootstrap/JvmstatCountersTest.java,,
+sun/management/jmxremote/bootstrap/LocalManagementTest.java,,
+sun/management/jmxremote/bootstrap/RmiRegistrySslTest.java,,
+sun/management/jmxremote/bootstrap/RmiSslBootstrapTest.sh,,
+sun/management/jmxremote/startstop/JMXStartStopTest.java,,
+sun/management/jmxremote/startstop/JMXStatusPerfCountersTest.java,,
+sun/management/jmxremote/startstop/JMXStatusTest.java,,
+sun/text/resources/LocaleDataTest.java,,
+sun/tools/jcmd/TestJcmdSanity.java,,
+sun/tools/jhsdb/AlternateHashingTest.java,,
+sun/tools/jhsdb/BasicLauncherTest.java,,
+sun/tools/jhsdb/HeapDumpTest.java,,
+sun/tools/jhsdb/heapconfig/JMapHeapConfigTest.java,,
+sun/tools/jinfo/BasicJInfoTest.java,,
+sun/tools/jinfo/JInfoTest.java,,
+sun/tools/jmap/BasicJMapTest.java,,
+sun/tools/jstack/BasicJStackTest.java,,
+sun/tools/jstack/DeadlockDetectionTest.java,,
+sun/tools/jstatd/TestJstatdExternalRegistry.java,,
+sun/tools/jstatd/TestJstatdPort.java,,Flaky
+sun/tools/jstatd/TestJstatdPortAndServer.java,,Flaky
+sun/util/calendar/zi/TestZoneInfo310.java,,
+tools/jar/modularJar/Basic.java,,
+tools/jar/multiRelease/Basic.java,,
+tools/jimage/JImageExtractTest.java,,
+tools/jimage/JImageTest.java,,
+tools/jlink/JLinkTest.java,,
+tools/jlink/plugins/IncludeLocalesPluginTest.java,,
+tools/jmod/hashes/HashesTest.java,,
+tools/launcher/BigJar.java,b/111611473,
+tools/launcher/modules/patch/systemmodules/PatchSystemModules.java,,
diff --git a/test/runtimes/blacklist_nodejs12.4.0.csv b/test/runtimes/blacklist_nodejs12.4.0.csv
new file mode 100644
index 000000000..4ab4e2927
--- /dev/null
+++ b/test/runtimes/blacklist_nodejs12.4.0.csv
@@ -0,0 +1,47 @@
+test name,bug id,comment
+benchmark/test-benchmark-fs.js,,
+benchmark/test-benchmark-module.js,,
+benchmark/test-benchmark-napi.js,,
+doctool/test-make-doc.js,b/68848110,Expected to fail.
+fixtures/test-error-first-line-offset.js,,
+fixtures/test-fs-readfile-error.js,,
+fixtures/test-fs-stat-sync-overflow.js,,
+internet/test-dgram-broadcast-multi-process.js,,
+internet/test-dgram-multicast-multi-process.js,,
+internet/test-dgram-multicast-set-interface-lo.js,,
+parallel/test-cluster-dgram-reuse.js,b/64024294,
+parallel/test-dgram-bind-fd.js,b/132447356,
+parallel/test-dgram-create-socket-handle-fd.js,b/132447238,
+parallel/test-dgram-createSocket-type.js,b/68847739,
+parallel/test-dgram-socket-buffer-size.js,b/68847921,
+parallel/test-fs-access.js,,
+parallel/test-fs-write-stream-double-close.js,,
+parallel/test-fs-write-stream-throw-type-error.js,b/110226209,
+parallel/test-fs-write-stream.js,,
+parallel/test-http2-respond-file-error-pipe-offset.js,,
+parallel/test-os.js,,
+parallel/test-process-uid-gid.js,,
+pseudo-tty/test-assert-colors.js,,
+pseudo-tty/test-assert-no-color.js,,
+pseudo-tty/test-assert-position-indicator.js,,
+pseudo-tty/test-async-wrap-getasyncid-tty.js,,
+pseudo-tty/test-fatal-error.js,,
+pseudo-tty/test-handle-wrap-isrefed-tty.js,,
+pseudo-tty/test-readable-tty-keepalive.js,,
+pseudo-tty/test-set-raw-mode-reset-process-exit.js,,
+pseudo-tty/test-set-raw-mode-reset-signal.js,,
+pseudo-tty/test-set-raw-mode-reset.js,,
+pseudo-tty/test-stderr-stdout-handle-sigwinch.js,,
+pseudo-tty/test-stdout-read.js,,
+pseudo-tty/test-tty-color-support.js,,
+pseudo-tty/test-tty-isatty.js,,
+pseudo-tty/test-tty-stdin-call-end.js,,
+pseudo-tty/test-tty-stdin-end.js,,
+pseudo-tty/test-stdin-write.js,,
+pseudo-tty/test-tty-stdout-end.js,,
+pseudo-tty/test-tty-stdout-resize.js,,
+pseudo-tty/test-tty-stream-constructors.js,,
+pseudo-tty/test-tty-window-size.js,,
+pseudo-tty/test-tty-wrap.js,,
+pummel/test-net-pingpong.js,,
+pummel/test-vm-memleak.js,,
diff --git a/test/runtimes/blacklist_php7.3.6.csv b/test/runtimes/blacklist_php7.3.6.csv
new file mode 100644
index 000000000..456bf7487
--- /dev/null
+++ b/test/runtimes/blacklist_php7.3.6.csv
@@ -0,0 +1,29 @@
+test name,bug id,comment
+ext/intl/tests/bug77895.phpt,,
+ext/intl/tests/dateformat_bug65683_2.phpt,,
+ext/mbstring/tests/bug76319.phpt,,
+ext/mbstring/tests/bug76958.phpt,,
+ext/mbstring/tests/bug77025.phpt,,
+ext/mbstring/tests/bug77165.phpt,,
+ext/mbstring/tests/bug77454.phpt,,
+ext/mbstring/tests/mb_convert_encoding_leak.phpt,,
+ext/mbstring/tests/mb_strrpos_encoding_3rd_param.phpt,,
+ext/standard/tests/file/filetype_variation.phpt,,
+ext/standard/tests/file/fopen_variation19.phpt,,
+ext/standard/tests/file/php_fd_wrapper_01.phpt,,
+ext/standard/tests/file/php_fd_wrapper_02.phpt,,
+ext/standard/tests/file/php_fd_wrapper_03.phpt,,
+ext/standard/tests/file/php_fd_wrapper_04.phpt,,
+ext/standard/tests/file/realpath_bug77484.phpt,,
+ext/standard/tests/file/rename_variation.phpt,b/68717309,
+ext/standard/tests/file/symlink_link_linkinfo_is_link_variation4.phpt,,
+ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,,
+ext/standard/tests/general_functions/escapeshellarg_bug71270.phpt,,
+ext/standard/tests/general_functions/escapeshellcmd_bug71270.phpt,,
+ext/standard/tests/network/bug20134.phpt,,
+tests/output/stream_isatty_err.phpt,b/68720279,
+tests/output/stream_isatty_in-err.phpt,b/68720282,
+tests/output/stream_isatty_in-out-err.phpt,,
+tests/output/stream_isatty_in-out.phpt,b/68720299,
+tests/output/stream_isatty_out-err.phpt,b/68720311,
+tests/output/stream_isatty_out.phpt,b/68720325,
diff --git a/test/runtimes/blacklist_python3.7.3.csv b/test/runtimes/blacklist_python3.7.3.csv
new file mode 100644
index 000000000..2b9947212
--- /dev/null
+++ b/test/runtimes/blacklist_python3.7.3.csv
@@ -0,0 +1,27 @@
+test name,bug id,comment
+test_asynchat,b/76031995,SO_REUSEADDR
+test_asyncio,,Fails on Docker.
+test_asyncore,b/76031995,SO_REUSEADDR
+test_epoll,,
+test_fcntl,,fcntl invalid argument -- artificial test to make sure something works in 64 bit mode.
+test_ftplib,,Fails in Docker
+test_httplib,b/76031995,SO_REUSEADDR
+test_imaplib,,
+test_logging,,
+test_multiprocessing_fork,,Flaky. Sometimes times out.
+test_multiprocessing_forkserver,,Flaky. Sometimes times out.
+test_multiprocessing_main_handling,,Flaky. Sometimes times out.
+test_multiprocessing_spawn,,Flaky. Sometimes times out.
+test_nntplib,b/76031995,tests should not set SO_REUSEADDR
+test_poplib,,Fails on Docker
+test_posix,b/76174079,posix.sched_get_priority_min not implemented + posix.sched_rr_get_interval not permitted
+test_pty,b/76157709,out of pty devices
+test_readline,b/76157709,out of pty devices
+test_resource,b/76174079,
+test_selectors,b/76116849,OSError not raised with epoll
+test_smtplib,b/76031995,SO_REUSEADDR and unclosed sockets
+test_socket,b/75983380,
+test_ssl,b/76031995,SO_REUSEADDR
+test_subprocess,,
+test_support,b/76031995,SO_REUSEADDR
+test_telnetlib,b/76031995,SO_REUSEADDR
diff --git a/test/runtimes/blacklist_test.go b/test/runtimes/blacklist_test.go
new file mode 100644
index 000000000..52f49b984
--- /dev/null
+++ b/test/runtimes/blacklist_test.go
@@ -0,0 +1,37 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "flag"
+ "os"
+ "testing"
+)
+
+// TestMain parses the command-line flags, runs the tests, and exits with the
+// status code that the test run reports.
+func TestMain(m *testing.M) {
+	flag.Parse()
+	status := m.Run()
+	os.Exit(status)
+}
+
+// TestBlacklists checks that the blacklist parses without error and that,
+// when a blacklist file was named on the command line, the parsed blacklist
+// is not empty.
+func TestBlacklists(t *testing.T) {
+	bl, err := getBlacklist()
+	if err != nil {
+		t.Fatalf("error parsing blacklist: %v", err)
+	}
+	if *blacklistFile != "" && len(bl) == 0 {
+		// blacklistFile is a pointer (it is dereferenced in the condition
+		// above), so it must be dereferenced here as well; handing the pointer
+		// itself to %q would not print the file name.
+		t.Errorf("got empty blacklist for file %q", *blacklistFile)
+	}
+}
diff --git a/test/runtimes/build_defs.bzl b/test/runtimes/build_defs.bzl
new file mode 100644
index 000000000..7c11624b4
--- /dev/null
+++ b/test/runtimes/build_defs.bzl
@@ -0,0 +1,57 @@
+"""Defines a rule for runtime test targets."""
+
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
+# runtime_test is a macro that will create targets to run the given test target
+# with different runtime options.
+#
+# Args:
+#   lang: passed to the runner as --lang; also used to name the sh_test target
+#     (lang + "_test").
+#   image: passed to the runner as --image.
+#   shard_count: forwarded to sh_test (default 50).
+#   size: forwarded to sh_test (default "enormous").
+#   blacklist_file: optional; when non-empty it is added to data and passed to
+#     the runner as --blacklist_file with a "test/runtimes/" prefix.
+def runtime_test(
+ lang,
+ image,
+ shard_count = 50,
+ size = "enormous",
+ blacklist_file = ""):
+ # Flags that are always passed to runner.sh.
+ args = [
+ "--lang",
+ lang,
+ "--image",
+ image,
+ ]
+ # Runtime dependencies of the test.
+ data = [
+ ":runner",
+ ]
+ if blacklist_file != "":
+ args += ["--blacklist_file", "test/runtimes/" + blacklist_file]
+ data += [blacklist_file]
+
+ # Add a test that the blacklist parses correctly.
+ blacklist_test(lang, blacklist_file)
+
+ sh_test(
+ name = lang + "_test",
+ srcs = ["runner.sh"],
+ args = args,
+ data = data,
+ size = size,
+ shard_count = shard_count,
+ tags = [
+ # Requires docker and runsc to be configured before the test runs.
+ "manual",
+ "local",
+ ],
+ )
+
+# blacklist_test declares a go_test named lang + "_blacklist_test" that embeds
+# ":runner", compiles blacklist_test.go, and passes the blacklist file (also
+# listed in data) via --blacklist_file with a "test/runtimes/" prefix.
+def blacklist_test(lang, blacklist_file):
+ """Test that a blacklist parses correctly."""
+ go_test(
+ name = lang + "_blacklist_test",
+ embed = [":runner"],
+ srcs = ["blacklist_test.go"],
+ args = ["--blacklist_file", "test/runtimes/" + blacklist_file],
+ data = [blacklist_file],
+ )
+
+# sh_test forwards every keyword argument unchanged to native.sh_test.
+def sh_test(**kwargs):
+ """Wraps the standard sh_test."""
+ native.sh_test(
+ **kwargs
+ )
diff --git a/test/runtimes/common/BUILD b/test/runtimes/common/BUILD
deleted file mode 100644
index 1b39606b8..000000000
--- a/test/runtimes/common/BUILD
+++ /dev/null
@@ -1,20 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "common",
- srcs = ["common.go"],
- importpath = "gvisor.dev/gvisor/test/runtimes/common",
- visibility = ["//:sandbox"],
-)
-
-go_test(
- name = "common_test",
- size = "small",
- srcs = ["common_test.go"],
- deps = [
- ":common",
- "//runsc/test/testutil",
- ],
-)
diff --git a/test/runtimes/common/common.go b/test/runtimes/common/common.go
deleted file mode 100644
index 0ff87fa8b..000000000
--- a/test/runtimes/common/common.go
+++ /dev/null
@@ -1,114 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package common executes functions for proctor binaries.
-package common
-
-import (
- "flag"
- "fmt"
- "os"
- "path/filepath"
- "regexp"
-)
-
-var (
- list = flag.Bool("list", false, "list all available tests")
- test = flag.String("test", "", "run a single test from the list of available tests")
- version = flag.Bool("v", false, "print out the version of node that is installed")
-)
-
-// TestRunner is an interface to be implemented in each proctor binary.
-type TestRunner interface {
- // ListTests returns a string slice of tests available to run.
- ListTests() ([]string, error)
-
- // RunTest runs a single test.
- RunTest(test string) error
-}
-
-// LaunchFunc parses flags passed by a proctor binary and calls the requested behavior.
-func LaunchFunc(tr TestRunner) error {
- flag.Parse()
-
- if *list && *test != "" {
- flag.PrintDefaults()
- return fmt.Errorf("cannot specify 'list' and 'test' flags simultaneously")
- }
- if *list {
- tests, err := tr.ListTests()
- if err != nil {
- return fmt.Errorf("failed to list tests: %v", err)
- }
- for _, test := range tests {
- fmt.Println(test)
- }
- return nil
- }
- if *version {
- fmt.Println(os.Getenv("LANG_NAME"), "version:", os.Getenv("LANG_VER"), "is installed.")
- return nil
- }
- if *test != "" {
- if err := tr.RunTest(*test); err != nil {
- return fmt.Errorf("test %q failed to run: %v", *test, err)
- }
- return nil
- }
-
- if err := runAllTests(tr); err != nil {
- return fmt.Errorf("error running all tests: %v", err)
- }
- return nil
-}
-
-// Search uses filepath.Walk to perform a search of the disk for test files
-// and returns a string slice of tests.
-func Search(root string, testFilter *regexp.Regexp) ([]string, error) {
- var testSlice []string
-
- err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
- name := filepath.Base(path)
-
- if info.IsDir() || !testFilter.MatchString(name) {
- return nil
- }
-
- relPath, err := filepath.Rel(root, path)
- if err != nil {
- return err
- }
- testSlice = append(testSlice, relPath)
- return nil
- })
-
- if err != nil {
- return nil, fmt.Errorf("walking %q: %v", root, err)
- }
-
- return testSlice, nil
-}
-
-func runAllTests(tr TestRunner) error {
- tests, err := tr.ListTests()
- if err != nil {
- return fmt.Errorf("failed to list tests: %v", err)
- }
- for _, test := range tests {
- if err := tr.RunTest(test); err != nil {
- return fmt.Errorf("test %q failed to run: %v", test, err)
- }
- }
- return nil
-}
diff --git a/test/runtimes/go/BUILD b/test/runtimes/go/BUILD
deleted file mode 100644
index ce971ee9d..000000000
--- a/test/runtimes/go/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "proctor-go",
- srcs = ["proctor-go.go"],
- deps = ["//test/runtimes/common"],
-)
diff --git a/test/runtimes/go/Dockerfile b/test/runtimes/go/Dockerfile
deleted file mode 100644
index 1d5202b70..000000000
--- a/test/runtimes/go/Dockerfile
+++ /dev/null
@@ -1,34 +0,0 @@
-FROM ubuntu:bionic
-ENV LANG_VER=1.12.5
-ENV LANG_NAME=Go
-
-RUN apt-get update && apt-get install -y \
- curl \
- gcc \
- git
-
-WORKDIR /root
-
-# Download Go 1.4 to use as a bootstrap for building Go from the source.
-RUN curl -o go1.4.linux-amd64.tar.gz https://dl.google.com/go/go1.4.linux-amd64.tar.gz
-RUN curl -LJO https://github.com/golang/go/archive/go${LANG_VER}.tar.gz
-RUN mkdir bootstr
-RUN tar -C bootstr -xzf go1.4.linux-amd64.tar.gz
-RUN tar -xzf go-go${LANG_VER}.tar.gz
-RUN mv go-go${LANG_VER} go
-
-ENV GOROOT=/root/go
-ENV GOROOT_BOOTSTRAP=/root/bootstr/go
-ENV LANG_DIR=${GOROOT}
-
-WORKDIR ${LANG_DIR}/src
-RUN ./make.bash
-# Pre-compile the tests for faster execution
-RUN ["/root/go/bin/go", "tool", "dist", "test", "-compile-only"]
-
-WORKDIR ${LANG_DIR}
-
-COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common
-COPY go/proctor-go.go ${LANG_DIR}
-
-ENTRYPOINT ["/root/go/bin/go", "run", "proctor-go.go"]
diff --git a/test/runtimes/images/Dockerfile_go1.12 b/test/runtimes/images/Dockerfile_go1.12
new file mode 100644
index 000000000..ab9d6abf3
--- /dev/null
+++ b/test/runtimes/images/Dockerfile_go1.12
@@ -0,0 +1,10 @@
+# Go is easy, since we already have everything we need to compile the proctor
+# binary and run the tests in the golang Docker image.
+FROM golang:1.12
+ADD ["proctor/", "/go/src/proctor/"]
+RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
+
+# Pre-compile the tests so we don't need to do so in each test run.
+RUN ["go", "tool", "dist", "test", "-compile-only"]
+
+# Start proctor with the Go runtime selected.
+ENTRYPOINT ["/proctor", "--runtime=go"]
diff --git a/test/runtimes/images/Dockerfile_java11 b/test/runtimes/images/Dockerfile_java11
new file mode 100644
index 000000000..9b7c3d5a3
--- /dev/null
+++ b/test/runtimes/images/Dockerfile_java11
@@ -0,0 +1,30 @@
+# Compile the proctor binary.
+FROM golang:1.12 AS golang
+ADD ["proctor/", "/go/src/proctor/"]
+RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
+
+# Runtime image with OpenJDK 11 and build tools installed.
+FROM ubuntu:bionic
+RUN apt-get update && apt-get install -y \
+    autoconf \
+    build-essential \
+    curl \
+    make \
+    openjdk-11-jdk \
+    unzip \
+    zip
+
+# Download the JDK test library.
+WORKDIR /root
+RUN set -ex \
+    && curl -fsSL --retry 10 -o /tmp/jdktests.tar.gz http://hg.openjdk.java.net/jdk/jdk11/archive/76072a077ee1.tar.gz/test \
+    && tar -xzf /tmp/jdktests.tar.gz \
+    && mv jdk11-76072a077ee1/test test \
+    && rm -f /tmp/jdktests.tar.gz
+
+# Install jtreg and add to PATH. Use the same curl flags as the JDK test
+# download above: -f fails the build on an HTTP error instead of saving the
+# error page as the tarball, -L follows redirects, and --retry covers
+# transient failures.
+RUN curl -fsSL --retry 10 -o jtreg.tar.gz https://ci.adoptopenjdk.net/view/Dependencies/job/jtreg/lastSuccessfulBuild/artifact/jtreg-4.2.0-tip.tar.gz
+RUN tar -xzf jtreg.tar.gz
+ENV PATH="/root/jtreg/bin:$PATH"
+
+# Copy the proctor binary from the build stage and start it with the Java
+# runtime selected.
+COPY --from=golang /proctor /proctor
+ENTRYPOINT ["/proctor", "--runtime=java"]
diff --git a/test/runtimes/images/Dockerfile_nodejs12.4.0 b/test/runtimes/images/Dockerfile_nodejs12.4.0
new file mode 100644
index 000000000..26f68b487
--- /dev/null
+++ b/test/runtimes/images/Dockerfile_nodejs12.4.0
@@ -0,0 +1,28 @@
+# Compile the proctor binary.
+FROM golang:1.12 AS golang
+ADD ["proctor/", "/go/src/proctor/"]
+RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
+
+# Runtime image with the packages used to build Node.js from source, plus
+# dumb-init for the entrypoint.
+FROM ubuntu:bionic
+RUN apt-get update && apt-get install -y \
+    curl \
+    dumb-init \
+    g++ \
+    make \
+    python
+
+# Download and unpack the Node.js source. -f makes curl fail on an HTTP error
+# instead of saving the error page as the tarball; -L follows redirects.
+WORKDIR /root
+ARG VERSION=v12.4.0
+RUN curl -fsSL -o node-${VERSION}.tar.gz https://nodejs.org/dist/${VERSION}/node-${VERSION}.tar.gz
+RUN tar -zxf node-${VERSION}.tar.gz
+
+# Configure and build Node.js, then run the test-build target.
+WORKDIR /root/node-${VERSION}
+RUN ./configure
+RUN make
+RUN make test-build
+
+COPY --from=golang /proctor /proctor
+
+# Including dumb-init emulates the Linux "init" process, preventing the failure
+# of tests involving worker processes.
+ENTRYPOINT ["/usr/bin/dumb-init", "/proctor", "--runtime=nodejs"]
diff --git a/test/runtimes/images/Dockerfile_php7.3.6 b/test/runtimes/images/Dockerfile_php7.3.6
new file mode 100644
index 000000000..e6b4c6329
--- /dev/null
+++ b/test/runtimes/images/Dockerfile_php7.3.6
@@ -0,0 +1,27 @@
+# Compile the proctor binary.
+FROM golang:1.12 AS golang
+ADD ["proctor/", "/go/src/proctor/"]
+RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
+
+# Runtime image with the packages used to build PHP from source.
+FROM ubuntu:bionic
+RUN apt-get update && apt-get install -y \
+    autoconf \
+    automake \
+    bison \
+    build-essential \
+    curl \
+    libtool \
+    libxml2-dev \
+    re2c
+
+# Download and unpack the PHP source. -f makes curl fail on an HTTP error
+# instead of saving the error page as the tarball; -L follows redirects.
+WORKDIR /root
+ARG VERSION=7.3.6
+RUN curl -fsSL -o php-${VERSION}.tar.gz https://www.php.net/distributions/php-${VERSION}.tar.gz
+RUN tar -zxf php-${VERSION}.tar.gz
+
+# Configure and build PHP in the unpacked source tree.
+WORKDIR /root/php-${VERSION}
+RUN ./configure
+RUN make
+
+# Copy the proctor binary from the build stage and start it with the PHP
+# runtime selected.
+COPY --from=golang /proctor /proctor
+ENTRYPOINT ["/proctor", "--runtime=php"]
diff --git a/test/runtimes/images/Dockerfile_python3.7.3 b/test/runtimes/images/Dockerfile_python3.7.3
new file mode 100644
index 000000000..905cd22d7
--- /dev/null
+++ b/test/runtimes/images/Dockerfile_python3.7.3
@@ -0,0 +1,30 @@
+# Compile the proctor binary.
+FROM golang:1.12 AS golang
+ADD ["proctor/", "/go/src/proctor/"]
+RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
+
+FROM ubuntu:bionic
+
+# Packages used to build CPython from source.
+RUN apt-get update && apt-get install -y \
+    curl \
+    gcc \
+    libbz2-dev \
+    libffi-dev \
+    liblzma-dev \
+    libreadline-dev \
+    libssl-dev \
+    make \
+    zlib1g-dev
+
+# Use flags -LJO to follow the html redirect and download .tar.gz. -f makes
+# curl fail on an HTTP error instead of saving the error page as the tarball.
+WORKDIR /root
+ARG VERSION=3.7.3
+RUN curl -fLJO https://github.com/python/cpython/archive/v${VERSION}.tar.gz
+RUN tar -zxf cpython-${VERSION}.tar.gz
+
+# Configure (with --with-pydebug) and build CPython in the unpacked source tree.
+WORKDIR /root/cpython-${VERSION}
+RUN ./configure --with-pydebug
+RUN make -s -j2
+
+# Copy the proctor binary from the build stage and start it with the Python
+# runtime selected.
+COPY --from=golang /proctor /proctor
+ENTRYPOINT ["/proctor", "--runtime=python"]
diff --git a/test/runtimes/images/proctor/BUILD b/test/runtimes/images/proctor/BUILD
new file mode 100644
index 000000000..09dc6c42f
--- /dev/null
+++ b/test/runtimes/images/proctor/BUILD
@@ -0,0 +1,26 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "proctor",
+ srcs = [
+ "go.go",
+ "java.go",
+ "nodejs.go",
+ "php.go",
+ "proctor.go",
+ "python.go",
+ ],
+ visibility = ["//test/runtimes/images:__subpackages__"],
+)
+
+go_test(
+ name = "proctor_test",
+ size = "small",
+ srcs = ["proctor_test.go"],
+ embed = [":proctor"],
+ deps = [
+ "//runsc/testutil",
+ ],
+)
diff --git a/test/runtimes/go/proctor-go.go b/test/runtimes/images/proctor/go.go
index 3eb24576e..3e2d5d8db 100644
--- a/test/runtimes/go/proctor-go.go
+++ b/test/runtimes/images/proctor/go.go
@@ -12,50 +12,42 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Binary proctor-go is a utility that facilitates language testing for Go.
-
-// There are two types of Go tests: "Go tool tests" and "Go tests on disk".
-// "Go tool tests" are found and executed using `go tool dist test`.
-// "Go tests on disk" are found in the /test directory and are
-// executed using `go run run.go`.
package main
import (
"fmt"
- "log"
"os"
"os/exec"
- "path/filepath"
"regexp"
"strings"
-
- "gvisor.dev/gvisor/test/runtimes/common"
)
var (
- dir = os.Getenv("LANG_DIR")
- goBin = filepath.Join(dir, "bin/go")
- testDir = filepath.Join(dir, "test")
- testRegEx = regexp.MustCompile(`^.+\.go$`)
+ goTestRegEx = regexp.MustCompile(`^.+\.go$`)
// Directories with .dir contain helper files for tests.
// Exclude benchmarks and stress tests.
- dirFilter = regexp.MustCompile(`^(bench|stress)\/.+$|^.+\.dir.+$`)
+ goDirFilter = regexp.MustCompile(`^(bench|stress)\/.+$|^.+\.dir.+$`)
)
-type goRunner struct {
-}
+// Location of Go tests on disk.
+const goTestDir = "/usr/local/go/test"
-func main() {
- if err := common.LaunchFunc(goRunner{}); err != nil {
- log.Fatalf("Failed to start: %v", err)
- }
-}
+// goRunner implements TestRunner for Go.
+//
+// There are two types of Go tests: "Go tool tests" and "Go tests on disk".
+// "Go tool tests" are found and executed using `go tool dist test`. "Go tests
+// on disk" are found in the /usr/local/go/test directory and are executed
+// using `go run run.go`.
+type goRunner struct{}
+
+var _ TestRunner = goRunner{}
-func (g goRunner) ListTests() ([]string, error) {
+// ListTests implements TestRunner.ListTests.
+func (goRunner) ListTests() ([]string, error) {
// Go tool dist test tests.
args := []string{"tool", "dist", "test", "-list"}
- cmd := exec.Command(filepath.Join(dir, "bin/go"), args...)
+ cmd := exec.Command("go", args...)
cmd.Stderr = os.Stderr
out, err := cmd.Output()
if err != nil {
@@ -67,14 +59,14 @@ func (g goRunner) ListTests() ([]string, error) {
}
// Go tests on disk.
- diskSlice, err := common.Search(testDir, testRegEx)
+ diskSlice, err := search(goTestDir, goTestRegEx)
if err != nil {
return nil, err
}
// Remove items from /bench/, /stress/ and .dir files
diskFiltered := diskSlice[:0]
for _, file := range diskSlice {
- if !dirFilter.MatchString(file) {
+ if !goDirFilter.MatchString(file) {
diskFiltered = append(diskFiltered, file)
}
}
@@ -82,24 +74,17 @@ func (g goRunner) ListTests() ([]string, error) {
return append(toolSlice, diskFiltered...), nil
}
-func (g goRunner) RunTest(test string) error {
+// TestCmd implements TestRunner.TestCmd.
+func (goRunner) TestCmd(test string) *exec.Cmd {
// Check if test exists on disk by searching for file of the same name.
// This will determine whether or not it is a Go test on disk.
if strings.HasSuffix(test, ".go") {
// Test has suffix ".go" which indicates a disk test, run it as such.
- cmd := exec.Command(goBin, "run", "run.go", "-v", "--", test)
- cmd.Dir = testDir
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("failed to run test: %v", err)
- }
- } else {
- // No ".go" suffix, run as a tool test.
- cmd := exec.Command(goBin, "tool", "dist", "test", "-run", test)
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("failed to run test: %v", err)
- }
+ cmd := exec.Command("go", "run", "run.go", "-v", "--", test)
+ cmd.Dir = goTestDir
+ return cmd
}
- return nil
+
+ // No ".go" suffix, run as a tool test.
+ return exec.Command("go", "tool", "dist", "test", "-run", test)
}
diff --git a/test/runtimes/java/proctor-java.go b/test/runtimes/images/proctor/java.go
index 7f6a66f4f..8b362029d 100644
--- a/test/runtimes/java/proctor-java.go
+++ b/test/runtimes/images/proctor/java.go
@@ -12,40 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Binary proctor-java is a utility that facilitates language testing for Java.
package main
import (
"fmt"
- "log"
"os"
"os/exec"
- "path/filepath"
"regexp"
"strings"
-
- "gvisor.dev/gvisor/test/runtimes/common"
)
-var (
- dir = os.Getenv("LANG_DIR")
- hash = os.Getenv("LANG_HASH")
- jtreg = filepath.Join(dir, "jtreg/bin/jtreg")
- exclDirs = regexp.MustCompile(`(^(sun\/security)|(java\/util\/stream)|(java\/time)| )`)
-)
+// Directories to exclude from tests.
+var javaExclDirs = regexp.MustCompile(`(^(sun\/security)|(java\/util\/stream)|(java\/time)| )`)
-type javaRunner struct {
-}
+// Location of java tests.
+const javaTestDir = "/root/test/jdk"
-func main() {
- if err := common.LaunchFunc(javaRunner{}); err != nil {
- log.Fatalf("Failed to start: %v", err)
- }
-}
+// javaRunner implements TestRunner for Java.
+type javaRunner struct{}
+
+var _ TestRunner = javaRunner{}
-func (j javaRunner) ListTests() ([]string, error) {
+// ListTests implements TestRunner.ListTests.
+func (javaRunner) ListTests() ([]string, error) {
args := []string{
- "-dir:/root/jdk11-" + hash + "/test/jdk",
+ "-dir:" + javaTestDir,
"-ignore:quiet",
"-a",
"-listtests",
@@ -54,7 +45,7 @@ func (j javaRunner) ListTests() ([]string, error) {
":jdk_sound",
":jdk_imageio",
}
- cmd := exec.Command(jtreg, args...)
+ cmd := exec.Command("jtreg", args...)
cmd.Stderr = os.Stderr
out, err := cmd.Output()
if err != nil {
@@ -62,19 +53,19 @@ func (j javaRunner) ListTests() ([]string, error) {
}
var testSlice []string
for _, test := range strings.Split(string(out), "\n") {
- if !exclDirs.MatchString(test) {
+ if !javaExclDirs.MatchString(test) {
testSlice = append(testSlice, test)
}
}
return testSlice, nil
}
-func (j javaRunner) RunTest(test string) error {
- args := []string{"-noreport", "-dir:/root/jdk11-" + hash + "/test/jdk", test}
- cmd := exec.Command(jtreg, args...)
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("failed to run: %v", err)
+// TestCmd implements TestRunner.TestCmd.
+func (javaRunner) TestCmd(test string) *exec.Cmd {
+ args := []string{
+ "-noreport",
+ "-dir:" + javaTestDir,
+ test,
}
- return nil
+ return exec.Command("jtreg", args...)
}
diff --git a/test/runtimes/images/proctor/nodejs.go b/test/runtimes/images/proctor/nodejs.go
new file mode 100644
index 000000000..bd57db444
--- /dev/null
+++ b/test/runtimes/images/proctor/nodejs.go
@@ -0,0 +1,46 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "os/exec"
+ "path/filepath"
+ "regexp"
+)
+
+var nodejsTestRegEx = regexp.MustCompile(`^test-[^-].+\.js$`)
+
+// Location of nodejs tests relative to working dir.
+const nodejsTestDir = "test"
+
+// nodejsRunner implements TestRunner for NodeJS.
+type nodejsRunner struct{}
+
+var _ TestRunner = nodejsRunner{}
+
+// ListTests implements TestRunner.ListTests. It returns the files under
+// nodejsTestDir whose base name matches nodejsTestRegEx.
+func (nodejsRunner) ListTests() ([]string, error) {
+	tests, err := search(nodejsTestDir, nodejsTestRegEx)
+	if err == nil {
+		return tests, nil
+	}
+	return nil, err
+}
+
+// TestCmd implements TestRunner.TestCmd. The test is run by handing it to
+// tools/test.py, which is executed with /usr/bin/python.
+func (nodejsRunner) TestCmd(test string) *exec.Cmd {
+	script := filepath.Join("tools", "test.py")
+	return exec.Command("/usr/bin/python", script, test)
+}
diff --git a/test/runtimes/php/proctor-php.go b/test/runtimes/images/proctor/php.go
index e6c5fabdf..9115040e1 100644
--- a/test/runtimes/php/proctor-php.go
+++ b/test/runtimes/images/proctor/php.go
@@ -12,47 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Binary proctor-php is a utility that facilitates language testing for PHP.
package main
import (
- "fmt"
- "log"
- "os"
"os/exec"
"regexp"
-
- "gvisor.dev/gvisor/test/runtimes/common"
)
-var (
- dir = os.Getenv("LANG_DIR")
- testRegEx = regexp.MustCompile(`^.+\.phpt$`)
-)
+var phpTestRegEx = regexp.MustCompile(`^.+\.phpt$`)
-type phpRunner struct {
-}
+// phpRunner implements TestRunner for PHP.
+type phpRunner struct{}
-func main() {
- if err := common.LaunchFunc(phpRunner{}); err != nil {
- log.Fatalf("Failed to start: %v", err)
- }
-}
+var _ TestRunner = phpRunner{}
-func (p phpRunner) ListTests() ([]string, error) {
- testSlice, err := common.Search(dir, testRegEx)
+// ListTests implements TestRunner.ListTests.
+func (phpRunner) ListTests() ([]string, error) {
+ testSlice, err := search(".", phpTestRegEx)
if err != nil {
return nil, err
}
return testSlice, nil
}
-func (p phpRunner) RunTest(test string) error {
+// TestCmd implements TestRunner.TestCmd.
+func (phpRunner) TestCmd(test string) *exec.Cmd {
args := []string{"test", "TESTS=" + test}
- cmd := exec.Command("make", args...)
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("failed to run: %v", err)
- }
- return nil
+ return exec.Command("make", args...)
}
diff --git a/test/runtimes/images/proctor/proctor.go b/test/runtimes/images/proctor/proctor.go
new file mode 100644
index 000000000..e6178e82b
--- /dev/null
+++ b/test/runtimes/images/proctor/proctor.go
@@ -0,0 +1,154 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary proctor runs the test for a particular runtime. It is meant to be
+// included in Docker images for all runtime tests.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "log"
+ "os"
+ "os/exec"
+ "os/signal"
+ "path/filepath"
+ "regexp"
+ "syscall"
+)
+
+// TestRunner is an interface that must be implemented for each runtime
+// integrated with proctor.
+type TestRunner interface {
+ // ListTests returns a string slice of tests available to run.
+ // main prints each returned entry on its own line when --list is given.
+ ListTests() ([]string, error)
+
+ // TestCmd returns an *exec.Cmd that will run the given test.
+ // main attaches its own stdout/stderr to the command and runs it.
+ TestCmd(test string) *exec.Cmd
+}
+
+// Command-line flags accepted by proctor.
+var (
+ runtime = flag.String("runtime", "", "name of runtime")
+ list = flag.Bool("list", false, "list all available tests")
+ test = flag.String("test", "", "run a single test from the list of available tests")
+ pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children")
+)
+
+// main parses the flags and then does exactly one of three things: pauses and
+// reaps children (--pause), lists the runtime's tests (--list), or runs the
+// single test named by --test. --pause is checked before --runtime is
+// validated, and --list takes precedence over --test.
+func main() {
+ flag.Parse()
+
+ // Pause mode does not use a runtime, so it is handled before the runtime
+ // flag is checked.
+ if *pause {
+ pauseAndReap()
+ panic("pauseAndReap should never return")
+ }
+
+ if *runtime == "" {
+ log.Fatalf("runtime flag must be provided")
+ }
+
+ tr, err := testRunnerForRuntime(*runtime)
+ if err != nil {
+ log.Fatalf("%v", err)
+ }
+
+ // List tests.
+ if *list {
+ tests, err := tr.ListTests()
+ if err != nil {
+ log.Fatalf("failed to list tests: %v", err)
+ }
+ for _, test := range tests {
+ fmt.Println(test)
+ }
+ return
+ }
+
+ // Run a single test.
+ if *test == "" {
+ log.Fatalf("test flag must be provided")
+ }
+ cmd := tr.TestCmd(*test)
+ // Forward the test's output to this process's stdout and stderr.
+ cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
+ // Any error returned by Run is reported as a test failure.
+ if err := cmd.Run(); err != nil {
+ log.Fatalf("FAIL: %v", err)
+ }
+}
+
+// testRunnerForRuntime maps a runtime name ("go", "java", "nodejs", "php" or
+// "python") to its TestRunner implementation. Any other name yields an error.
+func testRunnerForRuntime(runtime string) (TestRunner, error) {
+	runners := map[string]TestRunner{
+		"go":     goRunner{},
+		"java":   javaRunner{},
+		"nodejs": nodejsRunner{},
+		"php":    phpRunner{},
+		"python": pythonRunner{},
+	}
+	if tr, known := runners[runtime]; known {
+		return tr, nil
+	}
+	return nil, fmt.Errorf("invalid runtime %q", runtime)
+}
+
+// pauseAndReap is like init. It runs forever and reaps any children.
+//
+// It blocks on a SIGCHLD notification channel and, each time a signal arrives,
+// calls Wait4 in non-blocking mode until no further child is reported.
+func pauseAndReap() {
+ // Get notified of any new children.
+ // The channel is buffered with capacity one.
+ ch := make(chan os.Signal, 1)
+ signal.Notify(ch, syscall.SIGCHLD)
+
+ for {
+ if _, ok := <-ch; !ok {
+ // Channel closed. This should not happen.
+ panic("signal channel closed")
+ }
+
+ // Reap the child.
+ // Wait4 is called with pid -1 and WNOHANG in a loop, so more than one
+ // child can be reaped per signal; the loop stops once the returned pid
+ // is below 1. The error returned by Wait4 is discarded.
+ for {
+ if cpid, _ := syscall.Wait4(-1, nil, syscall.WNOHANG, nil); cpid < 1 {
+ break
+ }
+ }
+ }
+}
+
+// search walks the tree rooted at root and collects every non-directory entry
+// whose base name matches testFilter. Results are paths relative to root. Any
+// error reported during the walk aborts it and is returned wrapped with the
+// root path.
+func search(root string, testFilter *regexp.Regexp) ([]string, error) {
+	var matches []string
+
+	visit := func(path string, info os.FileInfo, walkErr error) error {
+		if walkErr != nil {
+			return walkErr
+		}
+
+		// Directories are never tests themselves.
+		if info.IsDir() {
+			return nil
+		}
+		// The filter applies to the base name only, not the full path.
+		if !testFilter.MatchString(filepath.Base(path)) {
+			return nil
+		}
+
+		rel, err := filepath.Rel(root, path)
+		if err != nil {
+			return err
+		}
+		matches = append(matches, rel)
+		return nil
+	}
+
+	if err := filepath.Walk(root, visit); err != nil {
+		return nil, fmt.Errorf("walking %q: %v", root, err)
+	}
+
+	return matches, nil
+}
diff --git a/test/runtimes/common/common_test.go b/test/runtimes/images/proctor/proctor_test.go
index 4fb1e482a..6bb61d142 100644
--- a/test/runtimes/common/common_test.go
+++ b/test/runtimes/images/proctor/proctor_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package common_test
+package main
import (
"io/ioutil"
@@ -23,8 +23,7 @@ import (
"strings"
"testing"
- "gvisor.dev/gvisor/runsc/test/testutil"
- "gvisor.dev/gvisor/test/runtimes/common"
+ "gvisor.dev/gvisor/runsc/testutil"
)
func touch(t *testing.T, name string) {
@@ -48,9 +47,9 @@ func TestSearchEmptyDir(t *testing.T) {
var want []string
testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`)
- got, err := common.Search(td, testFilter)
+ got, err := search(td, testFilter)
if err != nil {
- t.Errorf("Search error: %v", err)
+ t.Errorf("search error: %v", err)
}
if !reflect.DeepEqual(got, want) {
@@ -117,9 +116,9 @@ func TestSearch(t *testing.T) {
}
testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`)
- got, err := common.Search(td, testFilter)
+ got, err := search(td, testFilter)
if err != nil {
- t.Errorf("Search error: %v", err)
+ t.Errorf("search error: %v", err)
}
if !reflect.DeepEqual(got, want) {
diff --git a/test/runtimes/python/proctor-python.go b/test/runtimes/images/proctor/python.go
index 35e28a7df..b9e0fbe6f 100644
--- a/test/runtimes/python/proctor-python.go
+++ b/test/runtimes/images/proctor/python.go
@@ -12,36 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Binary proctor-python is a utility that facilitates language testing for Pyhton.
package main
import (
"fmt"
- "log"
"os"
"os/exec"
- "path/filepath"
"strings"
-
- "gvisor.dev/gvisor/test/runtimes/common"
)
-var (
- dir = os.Getenv("LANG_DIR")
-)
+// pythonRunner implements TestRunner for Python.
+type pythonRunner struct{}
-type pythonRunner struct {
-}
+var _ TestRunner = pythonRunner{}
-func main() {
- if err := common.LaunchFunc(pythonRunner{}); err != nil {
- log.Fatalf("Failed to start: %v", err)
- }
-}
-
-func (p pythonRunner) ListTests() ([]string, error) {
+// ListTests implements TestRunner.ListTests.
+func (pythonRunner) ListTests() ([]string, error) {
args := []string{"-m", "test", "--list-tests"}
- cmd := exec.Command(filepath.Join(dir, "python"), args...)
+ cmd := exec.Command("./python", args...)
cmd.Stderr = os.Stderr
out, err := cmd.Output()
if err != nil {
@@ -54,12 +42,8 @@ func (p pythonRunner) ListTests() ([]string, error) {
return toolSlice, nil
}
-func (p pythonRunner) RunTest(test string) error {
+// TestCmd implements TestRunner.TestCmd.
+func (pythonRunner) TestCmd(test string) *exec.Cmd {
args := []string{"-m", "test", test}
- cmd := exec.Command(filepath.Join(dir, "python"), args...)
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("failed to run: %v", err)
- }
- return nil
+ return exec.Command("./python", args...)
}
diff --git a/test/runtimes/java/BUILD b/test/runtimes/java/BUILD
deleted file mode 100644
index 8c39d39ec..000000000
--- a/test/runtimes/java/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "proctor-java",
- srcs = ["proctor-java.go"],
- deps = ["//test/runtimes/common"],
-)
diff --git a/test/runtimes/java/Dockerfile b/test/runtimes/java/Dockerfile
deleted file mode 100644
index b9132b575..000000000
--- a/test/runtimes/java/Dockerfile
+++ /dev/null
@@ -1,35 +0,0 @@
-FROM ubuntu:bionic
-# This hash is associated with a specific JDK release and needed for ensuring
-# the same version is downloaded every time.
-ENV LANG_HASH=76072a077ee1
-ENV LANG_VER=11
-ENV LANG_NAME=Java
-
-RUN apt-get update && apt-get install -y \
- autoconf \
- build-essential \
- curl \
- make \
- openjdk-${LANG_VER}-jdk \
- unzip \
- zip
-
-WORKDIR /root
-RUN curl -o go.tar.gz https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz
-RUN tar -zxf go.tar.gz
-
-# Download the JDK test library.
-RUN set -ex \
- && curl -fsSL --retry 10 -o /tmp/jdktests.tar.gz http://hg.openjdk.java.net/jdk/jdk${LANG_VER}/archive/${LANG_HASH}.tar.gz/test \
- && tar -xzf /tmp/jdktests.tar.gz -C /root \
- && rm -f /tmp/jdktests.tar.gz
-
-RUN curl -o jtreg.tar.gz https://ci.adoptopenjdk.net/view/Dependencies/job/jtreg/lastSuccessfulBuild/artifact/jtreg-4.2.0-tip.tar.gz
-RUN tar -xzf jtreg.tar.gz
-
-ENV LANG_DIR=/root
-
-COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common
-COPY java/proctor-java.go ${LANG_DIR}
-
-ENTRYPOINT ["/root/go/bin/go", "run", "proctor-java.go"]
diff --git a/test/runtimes/nodejs/BUILD b/test/runtimes/nodejs/BUILD
deleted file mode 100644
index 0594c250b..000000000
--- a/test/runtimes/nodejs/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "proctor-nodejs",
- srcs = ["proctor-nodejs.go"],
- deps = ["//test/runtimes/common"],
-)
diff --git a/test/runtimes/nodejs/Dockerfile b/test/runtimes/nodejs/Dockerfile
deleted file mode 100644
index aba30d2ee..000000000
--- a/test/runtimes/nodejs/Dockerfile
+++ /dev/null
@@ -1,30 +0,0 @@
-FROM ubuntu:bionic
-ENV LANG_VER=12.4.0
-ENV LANG_NAME=Node
-
-RUN apt-get update && apt-get install -y \
- curl \
- dumb-init \
- g++ \
- make \
- python
-
-WORKDIR /root
-RUN curl -o go.tar.gz https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz
-RUN tar -zxf go.tar.gz
-
-RUN curl -o node-v${LANG_VER}.tar.gz https://nodejs.org/dist/v${LANG_VER}/node-v${LANG_VER}.tar.gz
-RUN tar -zxf node-v${LANG_VER}.tar.gz
-ENV LANG_DIR=/root/node-v${LANG_VER}
-
-WORKDIR ${LANG_DIR}
-RUN ./configure
-RUN make
-RUN make test-build
-
-COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common
-COPY nodejs/proctor-nodejs.go ${LANG_DIR}
-
-# Including dumb-init emulates the Linux "init" process, preventing the failure
-# of tests involving worker processes.
-ENTRYPOINT ["/usr/bin/dumb-init", "/root/go/bin/go", "run", "proctor-nodejs.go"]
diff --git a/test/runtimes/nodejs/proctor-nodejs.go b/test/runtimes/nodejs/proctor-nodejs.go
deleted file mode 100644
index 0624f6a0d..000000000
--- a/test/runtimes/nodejs/proctor-nodejs.go
+++ /dev/null
@@ -1,60 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Binary proctor-nodejs is a utility that facilitates language testing for NodeJS.
-package main
-
-import (
- "fmt"
- "log"
- "os"
- "os/exec"
- "path/filepath"
- "regexp"
-
- "gvisor.dev/gvisor/test/runtimes/common"
-)
-
-var (
- dir = os.Getenv("LANG_DIR")
- testDir = filepath.Join(dir, "test")
- testRegEx = regexp.MustCompile(`^test-[^-].+\.js$`)
-)
-
-type nodejsRunner struct {
-}
-
-func main() {
- if err := common.LaunchFunc(nodejsRunner{}); err != nil {
- log.Fatalf("Failed to start: %v", err)
- }
-}
-
-func (n nodejsRunner) ListTests() ([]string, error) {
- testSlice, err := common.Search(testDir, testRegEx)
- if err != nil {
- return nil, err
- }
- return testSlice, nil
-}
-
-func (n nodejsRunner) RunTest(test string) error {
- args := []string{filepath.Join(dir, "tools", "test.py"), test}
- cmd := exec.Command("/usr/bin/python", args...)
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("failed to run: %v", err)
- }
- return nil
-}
diff --git a/test/runtimes/php/BUILD b/test/runtimes/php/BUILD
deleted file mode 100644
index 31799b77a..000000000
--- a/test/runtimes/php/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "proctor-php",
- srcs = ["proctor-php.go"],
- deps = ["//test/runtimes/common"],
-)
diff --git a/test/runtimes/php/Dockerfile b/test/runtimes/php/Dockerfile
deleted file mode 100644
index 491ab902d..000000000
--- a/test/runtimes/php/Dockerfile
+++ /dev/null
@@ -1,30 +0,0 @@
-FROM ubuntu:bionic
-ENV LANG_VER=7.3.6
-ENV LANG_NAME=PHP
-
-RUN apt-get update && apt-get install -y \
- autoconf \
- automake \
- bison \
- build-essential \
- curl \
- libtool \
- libxml2-dev \
- re2c
-
-WORKDIR /root
-RUN curl -o go.tar.gz https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz
-RUN tar -zxf go.tar.gz
-
-RUN curl -o php-${LANG_VER}.tar.gz https://www.php.net/distributions/php-${LANG_VER}.tar.gz
-RUN tar -zxf php-${LANG_VER}.tar.gz
-ENV LANG_DIR=/root/php-${LANG_VER}
-
-WORKDIR ${LANG_DIR}
-RUN ./configure
-RUN make
-
-COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common
-COPY php/proctor-php.go ${LANG_DIR}
-
-ENTRYPOINT ["/root/go/bin/go", "run", "proctor-php.go"]
diff --git a/test/runtimes/python/BUILD b/test/runtimes/python/BUILD
deleted file mode 100644
index 37fd6a0f2..000000000
--- a/test/runtimes/python/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "proctor-python",
- srcs = ["proctor-python.go"],
- deps = ["//test/runtimes/common"],
-)
diff --git a/test/runtimes/python/Dockerfile b/test/runtimes/python/Dockerfile
deleted file mode 100644
index 710daee43..000000000
--- a/test/runtimes/python/Dockerfile
+++ /dev/null
@@ -1,32 +0,0 @@
-FROM ubuntu:bionic
-ENV LANG_VER=3.7.3
-ENV LANG_NAME=Python
-
-RUN apt-get update && apt-get install -y \
- curl \
- gcc \
- libbz2-dev \
- libffi-dev \
- liblzma-dev \
- libreadline-dev \
- libssl-dev \
- make \
- zlib1g-dev
-
-WORKDIR /root
-RUN curl -o go.tar.gz https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz
-RUN tar -zxf go.tar.gz
-
-# Use flags -LJO to follow the html redirect and download .tar.gz.
-RUN curl -LJO https://github.com/python/cpython/archive/v${LANG_VER}.tar.gz
-RUN tar -zxf cpython-${LANG_VER}.tar.gz
-ENV LANG_DIR=/root/cpython-${LANG_VER}
-
-WORKDIR ${LANG_DIR}
-RUN ./configure --with-pydebug
-RUN make -s -j2
-
-COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common
-COPY python/proctor-python.go ${LANG_DIR}
-
-ENTRYPOINT ["/root/go/bin/go", "run", "proctor-python.go"]
diff --git a/test/runtimes/runner.go b/test/runtimes/runner.go
new file mode 100644
index 000000000..bec37c69d
--- /dev/null
+++ b/test/runtimes/runner.go
@@ -0,0 +1,199 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary runner runs the runtime tests in a Docker container.
+package main
+
+import (
+ "encoding/csv"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "os"
+ "sort"
+ "strings"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/runsc/testutil"
+)
+
+var (
+ lang = flag.String("lang", "", "language runtime to test")
+ image = flag.String("image", "", "docker image with runtime tests")
+ blacklistFile = flag.String("blacklist_file", "", "file containing blacklist of tests to exclude, in CSV format with fields: test name, bug id, comment")
+)
+
+// Wait time for each test to run.
+const timeout = 5 * time.Minute
+
+// main parses the command-line flags, rejects an empty --lang or --image,
+// and exits with the status returned by runTests.
+func main() {
+ flag.Parse()
+ if missing := *lang == "" || *image == ""; missing {
+ fmt.Fprintf(os.Stderr, "lang and image flags must not be empty\n")
+ os.Exit(1)
+ }
+
+ code := runTests()
+ os.Exit(code)
+}
+
+// runTests is a helper that is called by main. It exists so that we can run
+// deferred functions before exiting. It returns an exit code that should be
+// passed to os.Exit: 1 if the blacklist or the test list cannot be obtained,
+// otherwise whatever the testing framework's Run reports.
+func runTests() int {
+ // Get tests to blacklist.
+ blacklist, err := getBlacklist()
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error getting blacklist: %s\n", err.Error())
+ return 1
+ }
+
+ // Create a single docker container that will be used for all tests.
+ d := dockerutil.MakeDocker("gvisor-" + *lang)
+ defer d.CleanUp()
+
+ // Get a slice of tests to run. This will also start a single Docker
+ // container that will be used to run each test. Cleaning that container
+ // up is left to the deferred d.CleanUp() above.
+ tests, err := getTests(d, blacklist)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "%s\n", err.Error())
+ return 1
+ }
+
+ // Hand the generated tests to the testing package. The two nil arguments
+ // are the (empty) benchmark and example lists.
+ m := testing.MainStart(testDeps{}, tests, nil, nil)
+ return m.Run()
+}
+
+// getTests returns a slice of tests to run, subject to the shard size and
+// index.
+//
+// As a side effect it pulls *image and starts a container from it on d,
+// running proctor with --pause. Every returned test executes inside that
+// container through d.Exec. A test whose name appears in blacklist is
+// reported as skipped when it runs. Each test is given up to timeout to
+// finish.
+func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.InternalTest, error) {
+ // Pull the image.
+ if err := dockerutil.Pull(*image); err != nil {
+ return nil, fmt.Errorf("docker pull %q failed: %v", *image, err)
+ }
+
+ // Run proctor with --pause flag to keep container alive forever.
+ if err := d.Run(*image, "--pause"); err != nil {
+ return nil, fmt.Errorf("docker run failed: %v", err)
+ }
+
+ // Get a list of all tests in the image.
+ list, err := d.Exec("/proctor", "--runtime", *lang, "--list")
+ if err != nil {
+ return nil, fmt.Errorf("docker exec failed: %v", err)
+ }
+
+ // Calculate a subset of tests to run corresponding to the current
+ // shard. The list is sorted first so that every shard slices the same
+ // ordering.
+ tests := strings.Fields(list)
+ sort.Strings(tests)
+ begin, end, err := testutil.TestBoundsForShard(len(tests))
+ if err != nil {
+ return nil, fmt.Errorf("TestBoundsForShard() failed: %v", err)
+ }
+ log.Printf("Got bounds [%d:%d) for shard out of %d total tests", begin, end, len(tests))
+ tests = tests[begin:end]
+
+ var itests []testing.InternalTest
+ for _, tc := range tests {
+ // Capture tc in this scope.
+ tc := tc
+ itests = append(itests, testing.InternalTest{
+ Name: tc,
+ F: func(t *testing.T) {
+ // Is the test blacklisted?
+ if _, ok := blacklist[tc]; ok {
+ // Skipf rather than Skip: the message contains a
+ // format verb that must be expanded.
+ t.Skipf("SKIP: blacklisted test %q", tc)
+ }
+
+ var (
+ now = time.Now()
+ done = make(chan struct{})
+ output string
+ err error
+ )
+
+ // output and err are written only by this goroutine and
+ // are published to the reader by close(done).
+ go func() {
+ fmt.Printf("RUNNING %s...\n", tc)
+ output, err = d.Exec("/proctor", "--runtime", *lang, "--test", tc)
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ if err == nil {
+ fmt.Printf("PASS: %s (%v)\n\n", tc, time.Since(now))
+ return
+ }
+ t.Errorf("FAIL: %s (%v):\n%s\n", tc, time.Since(now), output)
+ case <-time.After(timeout):
+ // The exec goroutine may still be running and could write
+ // output at any moment, so it must not be read here.
+ t.Errorf("TIMEOUT: %s (%v)\n", tc, time.Since(now))
+ }
+ },
+ })
+ }
+ return itests, nil
+}
+
+// getBlacklist loads the CSV file named by --blacklist_file and returns the
+// set of test names it lists (the first field of every record after the
+// header line). With no file configured the returned set is empty.
+func getBlacklist() (map[string]struct{}, error) {
+ excluded := make(map[string]struct{})
+ if *blacklistFile == "" {
+ return excluded, nil
+ }
+
+ path, err := testutil.FindFile(*blacklistFile)
+ if err != nil {
+ return nil, err
+ }
+ f, err := os.Open(path)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+
+ reader := csv.NewReader(f)
+
+ // Consume and discard the header record.
+ if _, err := reader.Read(); err != nil {
+ return nil, err
+ }
+
+ // Everything that follows is a data record; only the test name is kept.
+ rows, err := reader.ReadAll()
+ if err != nil {
+ return nil, err
+ }
+ for _, row := range rows {
+ excluded[row[0]] = struct{}{}
+ }
+ return excluded, nil
+}
+
+// testDeps implements testing.testDeps (an unexported interface), and is
+// required to use testing.MainStart.
+//
+// Only MatchString does real work; the profiling and test-log hooks below
+// are no-ops.
+type testDeps struct{}
+
+// MatchString reports whether a and b are identical. It is an exact string
+// comparison, not a pattern match.
+func (f testDeps) MatchString(a, b string) (bool, error) { return a == b, nil }
+func (f testDeps) StartCPUProfile(io.Writer) error { return nil }
+func (f testDeps) StopCPUProfile() {}
+func (f testDeps) WriteProfileTo(string, io.Writer, int) error { return nil }
+func (f testDeps) ImportPath() string { return "" }
+func (f testDeps) StartTestLog(io.Writer) {}
+func (f testDeps) StopTestLog() error { return nil }
diff --git a/test/runtimes/runner.sh b/test/runtimes/runner.sh
new file mode 100755
index 000000000..a8d9a3460
--- /dev/null
+++ b/test/runtimes/runner.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -euf -x -o pipefail
+
+echo -- "$@"
+
+# Create outputs dir if it does not exist. The ":-" default keeps "set -u"
+# from aborting the script when the variable is not set at all.
+if [[ -n "${TEST_UNDECLARED_OUTPUTS_DIR:-}" ]]; then
+ mkdir -p "${TEST_UNDECLARED_OUTPUTS_DIR}"
+ chmod a+rwx "${TEST_UNDECLARED_OUTPUTS_DIR}"
+fi
+
+# Update the timestamp on the shard status file. Bazel looks for this. The
+# variable may be unset, so only touch the file when a path was provided.
+if [[ -n "${TEST_SHARD_STATUS_FILE:-}" ]]; then
+ touch "${TEST_SHARD_STATUS_FILE}"
+fi
+
+# Get location of runner binary. Assignment is kept separate from "readonly"
+# so that a failing find is not masked and "set -e" can stop the script.
+runner=$(find "${TEST_SRCDIR}" -name runner)
+readonly runner
+if [[ -z "${runner}" ]]; then
+ echo "runner binary not found under ${TEST_SRCDIR}" >&2
+ exit 1
+fi
+
+# Pass the arguments of this script directly to the runner.
+exec "${runner}" "$@"
+
diff --git a/test/runtimes/runtimes_test.go b/test/runtimes/runtimes_test.go
deleted file mode 100644
index 9421021a1..000000000
--- a/test/runtimes/runtimes_test.go
+++ /dev/null
@@ -1,93 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package runtimes
-
-import (
- "strings"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/runsc/test/testutil"
-)
-
-// Wait time for each test to run.
-const timeout = 180 * time.Second
-
-// Helper function to execute the docker container associated with the
-// language passed. Captures the output of the list function and executes
-// each test individually, supplying any errors recieved.
-func testLang(t *testing.T, lang string) {
- t.Helper()
-
- img := "gcr.io/gvisor-presubmit/" + lang
- if err := testutil.Pull(img); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
-
- c := testutil.MakeDocker("gvisor-list")
-
- list, err := c.RunFg(img, "--list")
- if err != nil {
- t.Fatalf("docker run failed: %v", err)
- }
- c.CleanUp()
-
- tests := strings.Fields(list)
-
- for _, tc := range tests {
- tc := tc
- t.Run(tc, func(t *testing.T) {
- d := testutil.MakeDocker("gvisor-test")
- if err := d.Run(img, "--test", tc); err != nil {
- t.Fatalf("docker test %q failed to run: %v", tc, err)
- }
- defer d.CleanUp()
-
- status, err := d.Wait(timeout)
- if err != nil {
- t.Fatalf("docker test %q failed to wait: %v", tc, err)
- }
- if status == 0 {
- t.Logf("test %q passed", tc)
- return
- }
- logs, err := d.Logs()
- if err != nil {
- t.Fatalf("docker test %q failed to supply logs: %v", tc, err)
- }
- t.Errorf("test %q failed: %v", tc, logs)
- })
- }
-}
-
-func TestGo(t *testing.T) {
- testLang(t, "go")
-}
-
-func TestJava(t *testing.T) {
- testLang(t, "java")
-}
-
-func TestNodejs(t *testing.T) {
- testLang(t, "nodejs")
-}
-
-func TestPhp(t *testing.T) {
- testLang(t, "php")
-}
-
-func TestPython(t *testing.T) {
- testLang(t, "python")
-}
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index ccae4925f..87ef87e07 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -185,7 +185,7 @@ syscall_test(
)
syscall_test(
- size = "medium",
+ size = "large",
shard_count = 5,
test = "//test/syscalls/linux:itimer_test",
)
@@ -305,6 +305,8 @@ syscall_test(
syscall_test(test = "//test/syscalls/linux:proc_pid_uid_gid_map_test")
+syscall_test(test = "//test/syscalls/linux:proc_net_test")
+
syscall_test(
size = "medium",
test = "//test/syscalls/linux:pselect_test",
@@ -344,6 +346,11 @@ syscall_test(
)
syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:readahead_test",
+)
+
+syscall_test(
size = "medium",
shard_count = 5,
test = "//test/syscalls/linux:readv_socket_test",
@@ -494,6 +501,8 @@ syscall_test(
test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_test",
)
+syscall_test(test = "//test/syscalls/linux:socket_ip_unbound_test")
+
syscall_test(test = "//test/syscalls/linux:socket_netdevice_test")
syscall_test(test = "//test/syscalls/linux:socket_netlink_route_test")
@@ -691,8 +700,13 @@ syscall_test(
syscall_test(test = "//test/syscalls/linux:proc_net_unix_test")
+syscall_test(test = "//test/syscalls/linux:proc_net_tcp_test")
+
+syscall_test(test = "//test/syscalls/linux:proc_net_udp_test")
+
go_binary(
name = "syscall_test_runner",
+ testonly = 1,
srcs = ["syscall_test_runner.go"],
data = [
"//runsc",
@@ -700,7 +714,7 @@ go_binary(
deps = [
"//pkg/log",
"//runsc/specutils",
- "//runsc/test/testutil",
+ "//runsc/testutil",
"//test/syscalls/gtest",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/test/syscalls/build_defs.bzl b/test/syscalls/build_defs.bzl
index a63eda81d..e94ef5602 100644
--- a/test/syscalls/build_defs.bzl
+++ b/test/syscalls/build_defs.bzl
@@ -4,7 +4,7 @@
# on the host (native) and runsc.
def syscall_test(
test,
- shard_count = 1,
+ shard_count = 5,
size = "small",
use_tmpfs = False,
add_overlay = False,
@@ -94,6 +94,7 @@ def _syscall_test(
# more stable.
if platform == "kvm":
tags += ["manual"]
+ tags += ["requires-kvm"]
args = [
# Arguments are passed directly to syscall_test_runner binary.
@@ -122,3 +123,6 @@ def sh_test(**kwargs):
native.sh_test(
**kwargs
)
+
+# select_for_linux returns for_linux unconditionally. for_others is accepted
+# but not used by this implementation.
+def select_for_linux(for_linux, for_others = []):
+ return for_linux
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index ca4344139..84a8eb76c 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -1,3 +1,6 @@
+load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library")
+load("//test/syscalls:build_defs.bzl", "select_for_linux")
+
package(
default_visibility = ["//:sandbox"],
licenses = ["notice"],
@@ -108,20 +111,27 @@ cc_library(
cc_library(
name = "socket_test_util",
testonly = 1,
- srcs = ["socket_test_util.cc"],
+ srcs = [
+ "socket_test_util.cc",
+ ] + select_for_linux(
+ [
+ "socket_test_util_impl.cc",
+ ],
+ ),
hdrs = ["socket_test_util.h"],
deps = [
+ "@com_google_googletest//:gtest",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/time",
"//test/util:file_descriptor",
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
- ],
+ ] + select_for_linux([
+ ]),
)
cc_library(
@@ -256,12 +266,15 @@ cc_binary(
],
linkstatic = 1,
deps = [
- # The heap check doesn't handle mremap properly.
+ # The heapchecker doesn't recognize that io_destroy munmaps.
"@com_google_googletest//:gtest",
"@com_google_absl//absl/strings",
"//test/util:cleanup",
"//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "//test/util:memory_util",
"//test/util:posix_error",
+ "//test/util:proc_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
@@ -320,6 +333,7 @@ cc_binary(
linkstatic = 1,
deps = [
":socket_test_util",
+ "//test/util:file_descriptor",
"//test/util:test_main",
"//test/util:test_util",
"@com_google_googletest//:gtest",
@@ -381,6 +395,7 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/synchronization",
"@com_google_googletest//:gtest",
],
@@ -399,6 +414,7 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_googletest//:gtest",
],
)
@@ -715,6 +731,7 @@ cc_binary(
"//test/util:test_util",
"//test/util:timer_util",
"@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
@@ -963,6 +980,7 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
@@ -983,6 +1001,7 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
],
@@ -1204,6 +1223,7 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
],
@@ -1411,6 +1431,7 @@ cc_binary(
"//test/util:posix_error",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_googletest//:gtest",
],
)
@@ -1427,6 +1448,7 @@ cc_binary(
"//test/util:posix_error",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_googletest//:gtest",
],
)
@@ -1530,6 +1552,7 @@ cc_binary(
srcs = ["proc_net.cc"],
linkstatic = 1,
deps = [
+ "//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
"//test/util:test_main",
@@ -1609,6 +1632,7 @@ cc_binary(
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:time_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
@@ -1713,6 +1737,20 @@ cc_binary(
)
cc_binary(
+ name = "readahead_test",
+ testonly = 1,
+ srcs = ["readahead.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
name = "readv_test",
testonly = 1,
srcs = [
@@ -1867,11 +1905,14 @@ cc_binary(
srcs = ["sendfile.cc"],
linkstatic = 1,
deps = [
+ "//test/util:eventfd_util",
"//test/util:file_descriptor",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
+ "//test/util:thread_util",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
)
@@ -1905,6 +1946,7 @@ cc_binary(
"//test/util:test_util",
"//test/util:thread_util",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
)
@@ -1957,6 +1999,24 @@ cc_binary(
)
cc_binary(
+ name = "signalfd_test",
+ testonly = 1,
+ srcs = ["signalfd.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "//test/util:logging",
+ "//test/util:posix_error",
+ "//test/util:signal_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
name = "sigprocmask_test",
testonly = 1,
srcs = ["sigprocmask.cc"],
@@ -1979,6 +2039,7 @@ cc_binary(
"//test/util:posix_error",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
@@ -2405,6 +2466,63 @@ cc_binary(
)
cc_binary(
+ name = "socket_bind_to_device_test",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_bind_to_device_util",
+ ":socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
+ name = "socket_bind_to_device_sequence_test",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device_sequence.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_bind_to_device_util",
+ ":socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
+ name = "socket_bind_to_device_distribution_test",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device_distribution.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_bind_to_device_util",
+ ":socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
name = "socket_ip_udp_loopback_non_blocking_test",
testonly = 1,
srcs = [
@@ -2437,6 +2555,22 @@ cc_binary(
)
cc_binary(
+ name = "socket_ip_unbound_test",
+ testonly = 1,
+ srcs = [
+ "socket_ip_unbound.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
name = "socket_domain_test",
testonly = 1,
srcs = [
@@ -2681,6 +2815,23 @@ cc_library(
alwayslink = 1,
)
+cc_library(
+ name = "socket_bind_to_device_util",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device_util.cc",
+ ],
+ hdrs = [
+ "socket_bind_to_device_util.h",
+ ],
+ deps = [
+ "//test/util:test_util",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ],
+ alwayslink = 1,
+)
+
cc_binary(
name = "socket_stream_local_test",
testonly = 1,
@@ -3104,8 +3255,6 @@ cc_binary(
testonly = 1,
srcs = ["timers.cc"],
linkstatic = 1,
- # FIXME(b/136599201)
- tags = ["flaky"],
deps = [
"//test/util:cleanup",
"//test/util:logging",
@@ -3114,6 +3263,7 @@ cc_binary(
"//test/util:signal_util",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
@@ -3193,6 +3343,8 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "//test/util:uid_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
],
@@ -3284,6 +3436,7 @@ cc_binary(
"//test/util:multiprocess_util",
"//test/util:test_util",
"//test/util:time_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
@@ -3353,6 +3506,7 @@ cc_binary(
"//test/util:test_util",
"//test/util:thread_util",
"@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
@@ -3463,3 +3617,18 @@ cc_binary(
"@com_google_googletest//:gtest",
],
)
+
+cc_binary(
+ name = "proc_net_udp_test",
+ testonly = 1,
+ srcs = ["proc_net_udp.cc"],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ "//test/util:file_descriptor",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/test/syscalls/linux/affinity.cc b/test/syscalls/linux/affinity.cc
index f2d8375b6..128364c34 100644
--- a/test/syscalls/linux/affinity.cc
+++ b/test/syscalls/linux/affinity.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <sched.h>
+#include <sys/syscall.h>
#include <sys/types.h>
#include <unistd.h>
diff --git a/test/syscalls/linux/aio.cc b/test/syscalls/linux/aio.cc
index 68dc05417..b27d4e10a 100644
--- a/test/syscalls/linux/aio.cc
+++ b/test/syscalls/linux/aio.cc
@@ -14,31 +14,57 @@
#include <fcntl.h>
#include <linux/aio_abi.h>
-#include <string.h>
#include <sys/mman.h>
#include <sys/syscall.h>
#include <sys/types.h>
#include <unistd.h>
+#include <algorithm>
+#include <string>
+
#include "gtest/gtest.h"
#include "test/syscalls/linux/file_base.h"
#include "test/util/cleanup.h"
#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/memory_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/proc_util.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
+using ::testing::_;
+
namespace gvisor {
namespace testing {
namespace {
+// Returns the size (end - start) of the /proc/self/maps entry containing addr.
+PosixErrorOr<size_t> VmaSizeAt(uintptr_t addr) {
+ ASSIGN_OR_RETURN_ERRNO(std::string proc_self_maps,
+ GetContents("/proc/self/maps"));
+ ASSIGN_OR_RETURN_ERRNO(auto entries, ParseProcMaps(proc_self_maps));
+ // Binary search for the first entry whose end is strictly greater than addr.
+ ProcMapsEntry target = {};
+ target.end = addr;
+ auto it =
+ std::upper_bound(entries.begin(), entries.end(), target,
+ [](const ProcMapsEntry& x, const ProcMapsEntry& y) {
+ return x.end < y.end;
+ });
+ // It contains addr only if it also starts at or below addr; else ENOENT.
+ if (it == entries.end() || addr < it->start) {
+ return PosixError(ENOENT, absl::StrCat("no VMA contains address ", addr));
+ }
+ return it->end - it->start;
+}
+
constexpr char kData[] = "hello world!";
int SubmitCtx(aio_context_t ctx, long nr, struct iocb** iocbpp) {
return syscall(__NR_io_submit, ctx, nr, iocbpp);
}
-} // namespace
-
class AIOTest : public FileTest {
public:
AIOTest() : ctx_(0) {}
@@ -124,10 +150,10 @@ TEST_F(AIOTest, BasicWrite) {
EXPECT_EQ(events[0].res, strlen(kData));
// Verify that the file contains the contents.
- char verify_buf[32] = {};
- ASSERT_THAT(read(test_file_fd_.get(), &verify_buf[0], strlen(kData)),
- SyscallSucceeds());
- EXPECT_EQ(strcmp(kData, &verify_buf[0]), 0);
+ char verify_buf[sizeof(kData)] = {};
+ ASSERT_THAT(read(test_file_fd_.get(), verify_buf, sizeof(kData)),
+ SyscallSucceedsWithValue(strlen(kData)));
+ EXPECT_STREQ(verify_buf, kData);
}
TEST_F(AIOTest, BadWrite) {
@@ -220,38 +246,25 @@ TEST_F(AIOTest, CloneVm) {
TEST_F(AIOTest, Mremap) {
// Setup a context that is 128 entries deep.
ASSERT_THAT(SetupContext(128), SyscallSucceeds());
+ const size_t ctx_size =
+ ASSERT_NO_ERRNO_AND_VALUE(VmaSizeAt(reinterpret_cast<uintptr_t>(ctx_)));
struct iocb cb = CreateCallback();
struct iocb* cbs[1] = {&cb};
// Reserve address space for the mremap target so we have something safe to
// map over.
- //
- // N.B. We reserve 2 pages because we'll attempt to remap to 2 pages below.
- // That should fail with EFAULT, but will fail with EINVAL if this mmap
- // returns the page immediately below ctx_, as
- // [new_address, new_address+2*kPageSize) overlaps [ctx_, ctx_+kPageSize).
- void* new_address = mmap(nullptr, 2 * kPageSize, PROT_READ,
- MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
- ASSERT_THAT(reinterpret_cast<intptr_t>(new_address), SyscallSucceeds());
- auto mmap_cleanup = Cleanup([new_address] {
- EXPECT_THAT(munmap(new_address, 2 * kPageSize), SyscallSucceeds());
- });
-
- // Test that remapping to a larger address fails.
- void* res = mremap(reinterpret_cast<void*>(ctx_), kPageSize, 2 * kPageSize,
- MREMAP_FIXED | MREMAP_MAYMOVE, new_address);
- ASSERT_THAT(reinterpret_cast<intptr_t>(res), SyscallFailsWithErrno(EFAULT));
+ Mapping dst =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(ctx_size, PROT_READ, MAP_PRIVATE));
// Remap context 'handle' to a different address.
- res = mremap(reinterpret_cast<void*>(ctx_), kPageSize, kPageSize,
- MREMAP_FIXED | MREMAP_MAYMOVE, new_address);
- ASSERT_THAT(
- reinterpret_cast<intptr_t>(res),
- SyscallSucceedsWithValue(reinterpret_cast<intptr_t>(new_address)));
- mmap_cleanup.Release();
+ ASSERT_THAT(Mremap(reinterpret_cast<void*>(ctx_), ctx_size, dst.len(),
+ MREMAP_FIXED | MREMAP_MAYMOVE, dst.ptr()),
+ IsPosixErrorOkAndHolds(dst.ptr()));
aio_context_t old_ctx = ctx_;
- ctx_ = reinterpret_cast<aio_context_t>(new_address);
+ ctx_ = reinterpret_cast<aio_context_t>(dst.addr());
+ // io_destroy() will unmap dst now.
+ dst.release();
// Check that submitting the request with the old 'ctx_' fails.
ASSERT_THAT(SubmitCtx(old_ctx, 1, cbs), SyscallFailsWithErrno(EINVAL));
@@ -260,18 +273,12 @@ TEST_F(AIOTest, Mremap) {
ASSERT_THAT(Submit(1, cbs), SyscallSucceedsWithValue(1));
// Remap again.
- new_address =
- mmap(nullptr, kPageSize, PROT_READ, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
- ASSERT_THAT(reinterpret_cast<int64_t>(new_address), SyscallSucceeds());
- auto mmap_cleanup2 = Cleanup([new_address] {
- EXPECT_THAT(munmap(new_address, kPageSize), SyscallSucceeds());
- });
- res = mremap(reinterpret_cast<void*>(ctx_), kPageSize, kPageSize,
- MREMAP_FIXED | MREMAP_MAYMOVE, new_address);
- ASSERT_THAT(reinterpret_cast<int64_t>(res),
- SyscallSucceedsWithValue(reinterpret_cast<int64_t>(new_address)));
- mmap_cleanup2.Release();
- ctx_ = reinterpret_cast<aio_context_t>(new_address);
+ dst = ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(ctx_size, PROT_READ, MAP_PRIVATE));
+ ASSERT_THAT(Mremap(reinterpret_cast<void*>(ctx_), ctx_size, dst.len(),
+ MREMAP_FIXED | MREMAP_MAYMOVE, dst.ptr()),
+ IsPosixErrorOkAndHolds(dst.ptr()));
+ ctx_ = reinterpret_cast<aio_context_t>(dst.addr());
+ dst.release();
// Get the reply with yet another 'ctx_' and verify it.
struct io_event events[1];
@@ -281,51 +288,33 @@ TEST_F(AIOTest, Mremap) {
EXPECT_EQ(events[0].res, strlen(kData));
// Verify that the file contains the contents.
- char verify_buf[32] = {};
- ASSERT_THAT(read(test_file_fd_.get(), &verify_buf[0], strlen(kData)),
- SyscallSucceeds());
- EXPECT_EQ(strcmp(kData, &verify_buf[0]), 0);
+ char verify_buf[sizeof(kData)] = {};
+ ASSERT_THAT(read(test_file_fd_.get(), verify_buf, sizeof(kData)),
+ SyscallSucceedsWithValue(strlen(kData)));
+ EXPECT_STREQ(verify_buf, kData);
}
-// Tests that AIO context can be replaced with a different mapping at the same
-// address and continue working. Don't ask why, but Linux allows it.
-TEST_F(AIOTest, MremapOver) {
+// Tests that AIO context cannot be expanded with mremap.
+TEST_F(AIOTest, MremapExpansion) {
// Setup a context that is 128 entries deep.
ASSERT_THAT(SetupContext(128), SyscallSucceeds());
+ const size_t ctx_size =
+ ASSERT_NO_ERRNO_AND_VALUE(VmaSizeAt(reinterpret_cast<uintptr_t>(ctx_)));
- struct iocb cb = CreateCallback();
- struct iocb* cbs[1] = {&cb};
-
- ASSERT_THAT(Submit(1, cbs), SyscallSucceedsWithValue(1));
-
- // Allocate a new VMA, copy 'ctx_' content over, and remap it on top
- // of 'ctx_'.
- void* new_address = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE,
- MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
- ASSERT_THAT(reinterpret_cast<int64_t>(new_address), SyscallSucceeds());
- auto mmap_cleanup = Cleanup([new_address] {
- EXPECT_THAT(munmap(new_address, kPageSize), SyscallSucceeds());
- });
-
- memcpy(new_address, reinterpret_cast<void*>(ctx_), kPageSize);
- void* res =
- mremap(new_address, kPageSize, kPageSize, MREMAP_FIXED | MREMAP_MAYMOVE,
- reinterpret_cast<void*>(ctx_));
- ASSERT_THAT(reinterpret_cast<int64_t>(res), SyscallSucceedsWithValue(ctx_));
- mmap_cleanup.Release();
-
- // Everything continues to work just fine.
- struct io_event events[1];
- ASSERT_THAT(GetEvents(1, 1, events, nullptr), SyscallSucceedsWithValue(1));
- EXPECT_EQ(events[0].data, 0x123);
- EXPECT_EQ(events[0].obj, reinterpret_cast<long>(&cb));
- EXPECT_EQ(events[0].res, strlen(kData));
-
- // Verify that the file contains the contents.
- char verify_buf[32] = {};
- ASSERT_THAT(read(test_file_fd_.get(), &verify_buf[0], strlen(kData)),
- SyscallSucceeds());
- EXPECT_EQ(strcmp(kData, &verify_buf[0]), 0);
+ // Reserve address space for the mremap target so we have something safe to
+ // map over.
+ Mapping dst = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(ctx_size + kPageSize, PROT_NONE, MAP_PRIVATE));
+
+ // Test that remapping to a larger address range fails.
+ ASSERT_THAT(Mremap(reinterpret_cast<void*>(ctx_), ctx_size, dst.len(),
+ MREMAP_FIXED | MREMAP_MAYMOVE, dst.ptr()),
+ PosixErrorIs(EFAULT, _));
+
+ // mm/mremap.c:sys_mremap() => mremap_to() does do_munmap() of the destination
+ // before it hits the VM_DONTEXPAND check in vma_to_resize(), so we should no
+ // longer munmap it (another thread may have created a mapping there).
+ dst.release();
}
// Tests that AIO calls fail if context's address is inaccessible.
@@ -429,5 +418,7 @@ TEST_P(AIOVectorizedParamTest, BadIOVecs) {
INSTANTIATE_TEST_SUITE_P(BadIOVecs, AIOVectorizedParamTest,
::testing::Values(IOCB_CMD_PREADV, IOCB_CMD_PWRITEV));
+} // namespace
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/base_poll_test.h b/test/syscalls/linux/base_poll_test.h
index 088831f9f..0d4a6701e 100644
--- a/test/syscalls/linux/base_poll_test.h
+++ b/test/syscalls/linux/base_poll_test.h
@@ -56,7 +56,7 @@ class TimerThread {
private:
mutable absl::Mutex mu_;
- bool cancel_ GUARDED_BY(mu_) = false;
+ bool cancel_ ABSL_GUARDED_BY(mu_) = false;
// Must be last to ensure that the destructor for the thread is run before
// any other member of the object is destroyed.
diff --git a/test/syscalls/linux/chown.cc b/test/syscalls/linux/chown.cc
index 2e82f0b3a..7a28b674d 100644
--- a/test/syscalls/linux/chown.cc
+++ b/test/syscalls/linux/chown.cc
@@ -16,10 +16,12 @@
#include <grp.h>
#include <sys/types.h>
#include <unistd.h>
+
#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/synchronization/notification.h"
#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
@@ -29,9 +31,9 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_int32(scratch_uid1, 65534, "first scratch UID");
-DEFINE_int32(scratch_uid2, 65533, "second scratch UID");
-DEFINE_int32(scratch_gid, 65534, "first scratch GID");
+ABSL_FLAG(int32_t, scratch_uid1, 65534, "first scratch UID");
+ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID");
+ABSL_FLAG(int32_t, scratch_gid, 65534, "first scratch GID");
namespace gvisor {
namespace testing {
@@ -100,10 +102,12 @@ TEST_P(ChownParamTest, ChownFilePermissionDenied) {
// Change EUID and EGID.
//
// See note about POSIX below.
- EXPECT_THAT(syscall(SYS_setresgid, -1, FLAGS_scratch_gid, -1),
- SyscallSucceeds());
- EXPECT_THAT(syscall(SYS_setresuid, -1, FLAGS_scratch_uid1, -1),
- SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1),
+ SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid1), -1),
+ SyscallSucceeds());
EXPECT_THAT(GetParam()(file.path(), geteuid(), getegid()),
PosixErrorIs(EPERM, ::testing::ContainsRegex("chown")));
@@ -125,8 +129,9 @@ TEST_P(ChownParamTest, ChownFileSucceedsAsRoot) {
// setresuid syscall. However, we want this thread to have its own set of
// credentials different from the parent process, so we use the raw
// syscall.
- EXPECT_THAT(syscall(SYS_setresuid, -1, FLAGS_scratch_uid2, -1),
- SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid2), -1),
+ SyscallSucceeds());
// Create file and immediately close it.
FileDescriptor fd =
@@ -143,12 +148,13 @@ TEST_P(ChownParamTest, ChownFileSucceedsAsRoot) {
fileCreated.WaitForNotification();
// Set file's owners to someone different.
- EXPECT_NO_ERRNO(GetParam()(filename, FLAGS_scratch_uid1, FLAGS_scratch_gid));
+ EXPECT_NO_ERRNO(GetParam()(filename, absl::GetFlag(FLAGS_scratch_uid1),
+ absl::GetFlag(FLAGS_scratch_gid)));
struct stat s;
EXPECT_THAT(stat(filename.c_str(), &s), SyscallSucceeds());
- EXPECT_EQ(s.st_uid, FLAGS_scratch_uid1);
- EXPECT_EQ(s.st_gid, FLAGS_scratch_gid);
+ EXPECT_EQ(s.st_uid, absl::GetFlag(FLAGS_scratch_uid1));
+ EXPECT_EQ(s.st_gid, absl::GetFlag(FLAGS_scratch_gid));
fileChowned.Notify();
}
diff --git a/test/syscalls/linux/clock_nanosleep.cc b/test/syscalls/linux/clock_nanosleep.cc
index 52a69d230..b55cddc52 100644
--- a/test/syscalls/linux/clock_nanosleep.cc
+++ b/test/syscalls/linux/clock_nanosleep.cc
@@ -43,7 +43,7 @@ int sys_clock_nanosleep(clockid_t clkid, int flags,
PosixErrorOr<absl::Time> GetTime(clockid_t clk) {
struct timespec ts = {};
- int rc = clock_gettime(clk, &ts);
+ const int rc = clock_gettime(clk, &ts);
MaybeSave();
if (rc < 0) {
return PosixError(errno, "clock_gettime");
@@ -67,31 +67,32 @@ TEST_P(WallClockNanosleepTest, InvalidValues) {
}
TEST_P(WallClockNanosleepTest, SleepOneSecond) {
- absl::Duration const duration = absl::Seconds(1);
- struct timespec dur = absl::ToTimespec(duration);
+ constexpr absl::Duration kSleepDuration = absl::Seconds(1);
+ struct timespec duration = absl::ToTimespec(kSleepDuration);
- absl::Time const before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
- EXPECT_THAT(RetryEINTR(sys_clock_nanosleep)(GetParam(), 0, &dur, &dur),
- SyscallSucceeds());
- absl::Time const after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ EXPECT_THAT(
+ RetryEINTR(sys_clock_nanosleep)(GetParam(), 0, &duration, &duration),
+ SyscallSucceeds());
+ const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
- EXPECT_GE(after - before, duration);
+ EXPECT_GE(after - before, kSleepDuration);
}
TEST_P(WallClockNanosleepTest, InterruptedNanosleep) {
- absl::Duration const duration = absl::Seconds(60);
- struct timespec dur = absl::ToTimespec(duration);
+ constexpr absl::Duration kSleepDuration = absl::Seconds(60);
+ struct timespec duration = absl::ToTimespec(kSleepDuration);
// Install no-op signal handler for SIGALRM.
struct sigaction sa = {};
sigfillset(&sa.sa_mask);
sa.sa_handler = +[](int signo) {};
- auto const cleanup_sa =
+ const auto cleanup_sa =
ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa));
// Measure time since setting the alarm, since the alarm will interrupt the
// sleep and hence determine how long we sleep.
- absl::Time const before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
// Set an alarm to go off while sleeping.
struct itimerval timer = {};
@@ -99,26 +100,51 @@ TEST_P(WallClockNanosleepTest, InterruptedNanosleep) {
timer.it_value.tv_usec = 0;
timer.it_interval.tv_sec = 1;
timer.it_interval.tv_usec = 0;
- auto const cleanup =
+ const auto cleanup =
ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_REAL, timer));
- EXPECT_THAT(sys_clock_nanosleep(GetParam(), 0, &dur, &dur),
+ EXPECT_THAT(sys_clock_nanosleep(GetParam(), 0, &duration, &duration),
SyscallFailsWithErrno(EINTR));
- absl::Time const after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
- absl::Duration const remaining = absl::DurationFromTimespec(dur);
- EXPECT_GE(after - before + remaining, duration);
+ // Remaining time updated.
+ const absl::Duration remaining = absl::DurationFromTimespec(duration);
+ EXPECT_GE(after - before + remaining, kSleepDuration);
+}
+
+// Remaining time is *not* updated if nanosleep completes uninterrupted.
+TEST_P(WallClockNanosleepTest, UninterruptedNanosleep) {
+ constexpr absl::Duration kSleepDuration = absl::Milliseconds(10);
+ const struct timespec duration = absl::ToTimespec(kSleepDuration);
+
+ while (true) {
+ constexpr int kRemainingMagic = 42;
+ struct timespec remaining;
+ remaining.tv_sec = kRemainingMagic;
+ remaining.tv_nsec = kRemainingMagic;
+
+ int ret = sys_clock_nanosleep(GetParam(), 0, &duration, &remaining);
+ if (ret == -1 && errno == EINTR) {
+ // Interrupted (-1 with errno EINTR): retry; we want one uninterrupted call.
+ continue;
+ }
+
+ EXPECT_THAT(ret, SyscallSucceeds());
+ EXPECT_EQ(remaining.tv_sec, kRemainingMagic);
+ EXPECT_EQ(remaining.tv_nsec, kRemainingMagic);
+ break;
+ }
+}
TEST_P(WallClockNanosleepTest, SleepUntil) {
- absl::Time const now = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
- absl::Time const until = now + absl::Seconds(2);
- struct timespec ts = absl::ToTimespec(until);
+ const absl::Time now = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time until = now + absl::Seconds(2);
+ const struct timespec ts = absl::ToTimespec(until);
EXPECT_THAT(
RetryEINTR(sys_clock_nanosleep)(GetParam(), TIMER_ABSTIME, &ts, nullptr),
SyscallSucceeds());
- absl::Time const after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
EXPECT_GE(after, until);
}
@@ -127,8 +153,8 @@ INSTANTIATE_TEST_SUITE_P(Sleepers, WallClockNanosleepTest,
::testing::Values(CLOCK_REALTIME, CLOCK_MONOTONIC));
TEST(ClockNanosleepProcessTest, SleepFiveSeconds) {
- absl::Duration const kDuration = absl::Seconds(5);
- struct timespec dur = absl::ToTimespec(kDuration);
+ const absl::Duration kSleepDuration = absl::Seconds(5);
+ struct timespec duration = absl::ToTimespec(kSleepDuration);
// Ensure that CLOCK_PROCESS_CPUTIME_ID advances.
std::atomic<bool> done(false);
@@ -136,16 +162,16 @@ TEST(ClockNanosleepProcessTest, SleepFiveSeconds) {
while (!done.load()) {
}
});
- auto const cleanup_done = Cleanup([&] { done.store(true); });
+ const auto cleanup_done = Cleanup([&] { done.store(true); });
- absl::Time const before =
+ const absl::Time before =
ASSERT_NO_ERRNO_AND_VALUE(GetTime(CLOCK_PROCESS_CPUTIME_ID));
- EXPECT_THAT(
- RetryEINTR(sys_clock_nanosleep)(CLOCK_PROCESS_CPUTIME_ID, 0, &dur, &dur),
- SyscallSucceeds());
- absl::Time const after =
+ EXPECT_THAT(RetryEINTR(sys_clock_nanosleep)(CLOCK_PROCESS_CPUTIME_ID, 0,
+ &duration, &duration),
+ SyscallSucceeds());
+ const absl::Time after =
ASSERT_NO_ERRNO_AND_VALUE(GetTime(CLOCK_PROCESS_CPUTIME_ID));
- EXPECT_GE(after - before, kDuration);
+ EXPECT_GE(after - before, kSleepDuration);
}
} // namespace
diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc
index 4c7c95321..4947271ba 100644
--- a/test/syscalls/linux/exec.cc
+++ b/test/syscalls/linux/exec.cc
@@ -140,57 +140,57 @@ void CheckOutput(const std::string& filename, const ExecveArray& argv,
EXPECT_TRUE(absl::StrContains(output, expect_stderr)) << output;
}
-TEST(ExecDeathTest, EmptyPath) {
+TEST(ExecTest, EmptyPath) {
int execve_errno;
ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec("", {}, {}, nullptr, &execve_errno));
EXPECT_EQ(execve_errno, ENOENT);
}
-TEST(ExecDeathTest, Basic) {
+TEST(ExecTest, Basic) {
CheckOutput(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, {},
ArgEnvExitStatus(0, 0),
absl::StrCat(WorkloadPath(kBasicWorkload), "\n"));
}
-TEST(ExecDeathTest, OneArg) {
+TEST(ExecTest, OneArg) {
CheckOutput(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload), "1"},
{}, ArgEnvExitStatus(1, 0),
absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n"));
}
-TEST(ExecDeathTest, FiveArg) {
+TEST(ExecTest, FiveArg) {
CheckOutput(WorkloadPath(kBasicWorkload),
{WorkloadPath(kBasicWorkload), "1", "2", "3", "4", "5"}, {},
ArgEnvExitStatus(5, 0),
absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n2\n3\n4\n5\n"));
}
-TEST(ExecDeathTest, OneEnv) {
+TEST(ExecTest, OneEnv) {
CheckOutput(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)},
{"1"}, ArgEnvExitStatus(0, 1),
absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n"));
}
-TEST(ExecDeathTest, FiveEnv) {
+TEST(ExecTest, FiveEnv) {
CheckOutput(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)},
{"1", "2", "3", "4", "5"}, ArgEnvExitStatus(0, 5),
absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n2\n3\n4\n5\n"));
}
-TEST(ExecDeathTest, OneArgOneEnv) {
+TEST(ExecTest, OneArgOneEnv) {
CheckOutput(WorkloadPath(kBasicWorkload),
{WorkloadPath(kBasicWorkload), "arg"}, {"env"},
ArgEnvExitStatus(1, 1),
absl::StrCat(WorkloadPath(kBasicWorkload), "\narg\nenv\n"));
}
-TEST(ExecDeathTest, InterpreterScript) {
+TEST(ExecTest, InterpreterScript) {
CheckOutput(WorkloadPath(kExitScript), {WorkloadPath(kExitScript), "25"}, {},
ArgEnvExitStatus(25, 0), "");
}
// Everything after the path in the interpreter script is a single argument.
-TEST(ExecDeathTest, InterpreterScriptArgSplit) {
+TEST(ExecTest, InterpreterScriptArgSplit) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
@@ -204,7 +204,7 @@ TEST(ExecDeathTest, InterpreterScriptArgSplit) {
}
// Original argv[0] is replaced with the script path.
-TEST(ExecDeathTest, InterpreterScriptArgvZero) {
+TEST(ExecTest, InterpreterScriptArgvZero) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
@@ -218,7 +218,7 @@ TEST(ExecDeathTest, InterpreterScriptArgvZero) {
// Original argv[0] is replaced with the script path, exactly as passed to
// execve.
-TEST(ExecDeathTest, InterpreterScriptArgvZeroRelative) {
+TEST(ExecTest, InterpreterScriptArgvZeroRelative) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
@@ -235,7 +235,7 @@ TEST(ExecDeathTest, InterpreterScriptArgvZeroRelative) {
}
// argv[0] is added as the script path, even if there was none.
-TEST(ExecDeathTest, InterpreterScriptArgvZeroAdded) {
+TEST(ExecTest, InterpreterScriptArgvZeroAdded) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
@@ -248,7 +248,7 @@ TEST(ExecDeathTest, InterpreterScriptArgvZeroAdded) {
}
// A NUL byte in the script line ends parsing.
-TEST(ExecDeathTest, InterpreterScriptArgNUL) {
+TEST(ExecTest, InterpreterScriptArgNUL) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
@@ -263,7 +263,7 @@ TEST(ExecDeathTest, InterpreterScriptArgNUL) {
}
// Trailing whitespace following interpreter path is ignored.
-TEST(ExecDeathTest, InterpreterScriptTrailingWhitespace) {
+TEST(ExecTest, InterpreterScriptTrailingWhitespace) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
@@ -276,7 +276,7 @@ TEST(ExecDeathTest, InterpreterScriptTrailingWhitespace) {
}
// Multiple whitespace characters between interpreter and arg allowed.
-TEST(ExecDeathTest, InterpreterScriptArgWhitespace) {
+TEST(ExecTest, InterpreterScriptArgWhitespace) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
@@ -288,7 +288,7 @@ TEST(ExecDeathTest, InterpreterScriptArgWhitespace) {
absl::StrCat(link.path(), "\nfoo\n", script.path(), "\n"));
}
-TEST(ExecDeathTest, InterpreterScriptNoPath) {
+TEST(ExecTest, InterpreterScriptNoPath) {
TempPath script = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateFileWith(GetAbsoluteTestTmpdir(), "#!", 0755));
@@ -299,7 +299,7 @@ TEST(ExecDeathTest, InterpreterScriptNoPath) {
}
// AT_EXECFN is the path passed to execve.
-TEST(ExecDeathTest, ExecFn) {
+TEST(ExecTest, ExecFn) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kStateWorkload)));
@@ -318,14 +318,14 @@ TEST(ExecDeathTest, ExecFn) {
absl::StrCat(script_relative, "\n"));
}
-TEST(ExecDeathTest, ExecName) {
+TEST(ExecTest, ExecName) {
std::string path = WorkloadPath(kStateWorkload);
CheckOutput(path, {path, "PrintExecName"}, {}, ArgEnvExitStatus(0, 0),
absl::StrCat(Basename(path).substr(0, 15), "\n"));
}
-TEST(ExecDeathTest, ExecNameScript) {
+TEST(ExecTest, ExecNameScript) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kStateWorkload)));
@@ -341,14 +341,14 @@ TEST(ExecDeathTest, ExecNameScript) {
}
// execve may be called by a multithreaded process.
-TEST(ExecDeathTest, WithSiblingThread) {
+TEST(ExecTest, WithSiblingThread) {
CheckOutput("/proc/self/exe", {"/proc/self/exe", kExecWithThread}, {},
W_EXITCODE(42, 0), "");
}
// execve may be called from a thread other than the leader of a multithreaded
// process.
-TEST(ExecDeathTest, FromSiblingThread) {
+TEST(ExecTest, FromSiblingThread) {
CheckOutput("/proc/self/exe", {"/proc/self/exe", kExecFromThread}, {},
W_EXITCODE(42, 0), "");
}
@@ -376,7 +376,7 @@ void SignalHandler(int signo) {
// Signal handlers are reset on execve(2), unless they have default or ignored
// disposition.
-TEST(ExecStateDeathTest, HandlerReset) {
+TEST(ExecStateTest, HandlerReset) {
struct sigaction sa;
sa.sa_handler = SignalHandler;
ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds());
@@ -392,7 +392,7 @@ TEST(ExecStateDeathTest, HandlerReset) {
}
// Ignored signal dispositions are not reset.
-TEST(ExecStateDeathTest, IgnorePreserved) {
+TEST(ExecStateTest, IgnorePreserved) {
struct sigaction sa;
sa.sa_handler = SIG_IGN;
ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds());
@@ -408,7 +408,7 @@ TEST(ExecStateDeathTest, IgnorePreserved) {
}
// Signal masks are not reset on exec
-TEST(ExecStateDeathTest, SignalMask) {
+TEST(ExecStateTest, SignalMask) {
sigset_t s;
sigemptyset(&s);
sigaddset(&s, SIGUSR1);
@@ -425,7 +425,7 @@ TEST(ExecStateDeathTest, SignalMask) {
// itimers persist across execve.
// N.B. Timers created with timer_create(2) should not be preserved!
-TEST(ExecStateDeathTest, ItimerPreserved) {
+TEST(ExecStateTest, ItimerPreserved) {
// The fork in ForkAndExec clears itimers, so only set them up after fork.
auto setup_itimer = [] {
// Ignore SIGALRM, as we don't actually care about timer
diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc
index 91b55015c..0a3931e5a 100644
--- a/test/syscalls/linux/exec_binary.cc
+++ b/test/syscalls/linux/exec_binary.cc
@@ -401,12 +401,17 @@ TEST(ElfTest, DataSegment) {
})));
}
-// Additonal pages beyond filesz are always RW.
+// Additional pages beyond filesz honor (only) execute protections.
//
-// N.B. Linux uses set_brk -> vm_brk to additional pages beyond filesz (even
-// though start_brk itself will always be beyond memsz). As a result, the
-// segment permissions don't apply; the mapping is always RW.
+// N.B. Linux changed this in 4.11 (16e72e9b30986 "powerpc: do not make the
+// entire heap executable"). Previously, extra pages were always RW.
TEST(ElfTest, ExtraMemPages) {
+ // gVisor has the newer behavior.
+ if (!IsRunningOnGvisor()) {
+ auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion());
+ SKIP_IF(version.major < 4 || (version.major == 4 && version.minor < 11));
+ }
+
ElfBinary<64> elf = StandardElf();
// Create a standard ELF, but extend to 1.5 pages. The second page will be the
@@ -415,7 +420,7 @@ TEST(ElfTest, ExtraMemPages) {
decltype(elf)::ElfPhdr phdr = {};
phdr.p_type = PT_LOAD;
- // RWX segment. The extra anon page will be RW anyways.
+ // RWX segment. The extra anon page will also be RWX.
//
// N.B. Linux uses clear_user to clear the end of the file-mapped page, which
// respects the mapping protections. Thus if we map this RO with memsz >
@@ -454,7 +459,7 @@ TEST(ElfTest, ExtraMemPages) {
{0x41000, 0x42000, true, true, true, true, kPageSize, 0, 0, 0,
file.path().c_str()},
// extra page from anon.
- {0x42000, 0x43000, true, true, false, true, 0, 0, 0, 0, ""},
+ {0x42000, 0x43000, true, true, true, true, 0, 0, 0, 0, ""},
})));
}
@@ -469,7 +474,7 @@ TEST(ElfTest, AnonOnlySegment) {
phdr.p_offset = 0;
phdr.p_vaddr = 0x41000;
phdr.p_filesz = 0;
- phdr.p_memsz = kPageSize - 0xe8;
+ phdr.p_memsz = kPageSize;
elf.phdrs.push_back(phdr);
elf.UpdateOffsets();
@@ -854,6 +859,11 @@ TEST(ElfTest, ELFInterpreter) {
// The first segment really needs to start at 0 for a normal PIE binary, and
// thus includes the headers.
uint64_t const offset = interpreter.phdrs[1].p_offset;
+ // N.B. Since Linux 4.10 (0036d1f7eb95b "binfmt_elf: fix calculations for bss
+ // padding"), Linux unconditionally zeroes the remainder of the highest mapped
+ // page in an interpreter, failing if the protections don't allow write. Thus
+ // we must mark this writeable.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
interpreter.phdrs[1].p_offset = 0x0;
interpreter.phdrs[1].p_vaddr = 0x0;
interpreter.phdrs[1].p_filesz += offset;
@@ -903,15 +913,15 @@ TEST(ElfTest, ELFInterpreter) {
const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1);
- EXPECT_THAT(child,
- ContainsMappings(std::vector<ProcMapsEntry>({
- // Main binary
- {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
- binary_file.path().c_str()},
- // Interpreter
- {interp_load_addr, interp_load_addr + 0x1000, true, false,
- true, true, 0, 0, 0, 0, interpreter_file.path().c_str()},
- })));
+ EXPECT_THAT(
+ child, ContainsMappings(std::vector<ProcMapsEntry>({
+ // Main binary
+ {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
+ binary_file.path().c_str()},
+ // Interpreter
+ {interp_load_addr, interp_load_addr + 0x1000, true, true, true,
+ true, 0, 0, 0, 0, interpreter_file.path().c_str()},
+ })));
}
// Test parameter to ElfInterpterStaticTest cases. The first item is a suffix to
@@ -928,6 +938,8 @@ TEST_P(ElfInterpreterStaticTest, Test) {
const int expected_errno = std::get<1>(GetParam());
ElfBinary<64> interpreter = StandardElf();
+ // See comment in ElfTest.ELFInterpreter.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
interpreter.UpdateOffsets();
TempPath interpreter_file =
ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(interpreter));
@@ -957,7 +969,7 @@ TEST_P(ElfInterpreterStaticTest, Test) {
EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({
// Interpreter.
- {0x40000, 0x41000, true, false, true, true, 0, 0, 0,
+ {0x40000, 0x41000, true, true, true, true, 0, 0, 0,
0, interpreter_file.path().c_str()},
})));
}
@@ -1035,6 +1047,8 @@ TEST(ElfTest, ELFInterpreterRelative) {
// The first segment really needs to start at 0 for a normal PIE binary, and
// thus includes the headers.
uint64_t const offset = interpreter.phdrs[1].p_offset;
+ // See comment in ElfTest.ELFInterpreter.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
interpreter.phdrs[1].p_offset = 0x0;
interpreter.phdrs[1].p_vaddr = 0x0;
interpreter.phdrs[1].p_filesz += offset;
@@ -1073,15 +1087,15 @@ TEST(ElfTest, ELFInterpreterRelative) {
const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1);
- EXPECT_THAT(child,
- ContainsMappings(std::vector<ProcMapsEntry>({
- // Main binary
- {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
- binary_file.path().c_str()},
- // Interpreter
- {interp_load_addr, interp_load_addr + 0x1000, true, false,
- true, true, 0, 0, 0, 0, interpreter_file.path().c_str()},
- })));
+ EXPECT_THAT(
+ child, ContainsMappings(std::vector<ProcMapsEntry>({
+ // Main binary
+ {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
+ binary_file.path().c_str()},
+ // Interpreter
+ {interp_load_addr, interp_load_addr + 0x1000, true, true, true,
+ true, 0, 0, 0, 0, interpreter_file.path().c_str()},
+ })));
}
// ELF interpreter architecture doesn't match the binary.
@@ -1095,6 +1109,8 @@ TEST(ElfTest, ELFInterpreterWrongArch) {
// The first segment really needs to start at 0 for a normal PIE binary, and
// thus includes the headers.
uint64_t const offset = interpreter.phdrs[1].p_offset;
+ // See comment in ElfTest.ELFInterpreter.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
interpreter.phdrs[1].p_offset = 0x0;
interpreter.phdrs[1].p_vaddr = 0x0;
interpreter.phdrs[1].p_filesz += offset;
@@ -1174,6 +1190,8 @@ TEST(ElfTest, ElfInterpreterNoExecute) {
// The first segment really needs to start at 0 for a normal PIE binary, and
// thus includes the headers.
uint64_t const offset = interpreter.phdrs[1].p_offset;
+ // See comment in ElfTest.ELFInterpreter.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
interpreter.phdrs[1].p_offset = 0x0;
interpreter.phdrs[1].p_vaddr = 0x0;
interpreter.phdrs[1].p_filesz += offset;
diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc
index 2f8e7c9dd..8a45be12a 100644
--- a/test/syscalls/linux/fcntl.cc
+++ b/test/syscalls/linux/fcntl.cc
@@ -17,9 +17,12 @@
#include <syscall.h>
#include <unistd.h>
+#include <string>
+
#include "gtest/gtest.h"
#include "absl/base/macros.h"
#include "absl/base/port.h"
+#include "absl/flags/flag.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/time/clock.h"
@@ -33,18 +36,19 @@
#include "test/util/test_util.h"
#include "test/util/timer_util.h"
-DEFINE_string(child_setlock_on, "",
- "Contains the path to try to set a file lock on.");
-DEFINE_bool(child_setlock_write, false,
- "Whether to set a writable lock (otherwise readable)");
-DEFINE_bool(blocking, false,
- "Whether to set a blocking lock (otherwise non-blocking).");
-DEFINE_bool(retry_eintr, false, "Whether to retry in the subprocess on EINTR.");
-DEFINE_uint64(child_setlock_start, 0, "The value of struct flock start");
-DEFINE_uint64(child_setlock_len, 0, "The value of struct flock len");
-DEFINE_int32(socket_fd, -1,
- "A socket to use for communicating more state back "
- "to the parent.");
+ABSL_FLAG(std::string, child_setlock_on, "",
+ "Contains the path to try to set a file lock on.");
+ABSL_FLAG(bool, child_setlock_write, false,
+ "Whether to set a writable lock (otherwise readable)");
+ABSL_FLAG(bool, blocking, false,
+ "Whether to set a blocking lock (otherwise non-blocking).");
+ABSL_FLAG(bool, retry_eintr, false,
+ "Whether to retry in the subprocess on EINTR.");
+ABSL_FLAG(uint64_t, child_setlock_start, 0, "The value of struct flock start");
+ABSL_FLAG(uint64_t, child_setlock_len, 0, "The value of struct flock len");
+ABSL_FLAG(int32_t, socket_fd, -1,
+ "A socket to use for communicating more state back "
+ "to the parent.");
namespace gvisor {
namespace testing {
@@ -918,25 +922,26 @@ TEST(FcntlTest, GetOwn) {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (!FLAGS_child_setlock_on.empty()) {
- int socket_fd = FLAGS_socket_fd;
- int fd = open(FLAGS_child_setlock_on.c_str(), O_RDWR, 0666);
+ const std::string setlock_on = absl::GetFlag(FLAGS_child_setlock_on);
+ if (!setlock_on.empty()) {
+ int socket_fd = absl::GetFlag(FLAGS_socket_fd);
+ int fd = open(setlock_on.c_str(), O_RDWR, 0666);
if (fd == -1 && errno != 0) {
int err = errno;
- std::cerr << "CHILD open " << FLAGS_child_setlock_on << " failed " << err
+ std::cerr << "CHILD open " << setlock_on << " failed " << err
<< std::endl;
exit(err);
}
struct flock fl;
- if (FLAGS_child_setlock_write) {
+ if (absl::GetFlag(FLAGS_child_setlock_write)) {
fl.l_type = F_WRLCK;
} else {
fl.l_type = F_RDLCK;
}
fl.l_whence = SEEK_SET;
- fl.l_start = FLAGS_child_setlock_start;
- fl.l_len = FLAGS_child_setlock_len;
+ fl.l_start = absl::GetFlag(FLAGS_child_setlock_start);
+ fl.l_len = absl::GetFlag(FLAGS_child_setlock_len);
// Test the fcntl, no need to log, the error is unambiguously
// from fcntl at this point.
@@ -946,8 +951,8 @@ int main(int argc, char** argv) {
gvisor::testing::MonotonicTimer timer;
timer.Start();
do {
- ret = fcntl(fd, FLAGS_blocking ? F_SETLKW : F_SETLK, &fl);
- } while (FLAGS_retry_eintr && ret == -1 && errno == EINTR);
+ ret = fcntl(fd, absl::GetFlag(FLAGS_blocking) ? F_SETLKW : F_SETLK, &fl);
+ } while (absl::GetFlag(FLAGS_retry_eintr) && ret == -1 && errno == EINTR);
auto usec = absl::ToInt64Microseconds(timer.Duration());
if (ret == -1 && errno != 0) {
diff --git a/test/syscalls/linux/futex.cc b/test/syscalls/linux/futex.cc
index d2cbbdb49..d3e3f998c 100644
--- a/test/syscalls/linux/futex.cc
+++ b/test/syscalls/linux/futex.cc
@@ -125,6 +125,10 @@ int futex_lock_pi(bool priv, std::atomic<int>* uaddr) {
if (priv) {
op |= FUTEX_PRIVATE_FLAG;
}
+ int zero = 0;
+ if (uaddr->compare_exchange_strong(zero, gettid())) {
+ return 0;
+ }
return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr);
}
@@ -133,6 +137,10 @@ int futex_trylock_pi(bool priv, std::atomic<int>* uaddr) {
if (priv) {
op |= FUTEX_PRIVATE_FLAG;
}
+ int zero = 0;
+ if (uaddr->compare_exchange_strong(zero, gettid())) {
+ return 0;
+ }
return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr);
}
@@ -141,6 +149,10 @@ int futex_unlock_pi(bool priv, std::atomic<int>* uaddr) {
if (priv) {
op |= FUTEX_PRIVATE_FLAG;
}
+ int tid = gettid();
+ if (uaddr->compare_exchange_strong(tid, 0)) {
+ return 0;
+ }
return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr);
}
@@ -692,8 +704,8 @@ TEST_P(PrivateAndSharedFutexTest, PITryLockConcurrency_NoRandomSave) {
std::unique_ptr<ScopedThread> threads[10];
for (size_t i = 0; i < ABSL_ARRAYSIZE(threads); ++i) {
threads[i] = absl::make_unique<ScopedThread>([is_priv, &a] {
- for (size_t j = 0; j < 100;) {
- if (futex_trylock_pi(is_priv, &a) >= 0) {
+ for (size_t j = 0; j < 10;) {
+ if (futex_trylock_pi(is_priv, &a) == 0) {
++j;
EXPECT_EQ(a.load() & FUTEX_TID_MASK, gettid());
SleepSafe(absl::Milliseconds(5));
diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc
index c73262e72..57e99596f 100644
--- a/test/syscalls/linux/ip_socket_test_util.cc
+++ b/test/syscalls/linux/ip_socket_test_util.cc
@@ -23,6 +23,16 @@
namespace gvisor {
namespace testing {
+uint32_t IPFromInetSockaddr(const struct sockaddr* addr) {
+ auto* in_addr = reinterpret_cast<const struct sockaddr_in*>(addr);
+ return in_addr->sin_addr.s_addr;
+}
+
+uint16_t PortFromInetSockaddr(const struct sockaddr* addr) {
+ auto* in_addr = reinterpret_cast<const struct sockaddr_in*>(addr);
+ return ntohs(in_addr->sin_port);
+}
+
PosixErrorOr<int> InterfaceIndex(std::string name) {
// TODO(igudger): Consider using netlink.
ifreq req = {};
@@ -112,6 +122,14 @@ SocketKind IPv4UDPUnboundSocket(int type) {
UnboundSocketCreator(AF_INET, type | SOCK_DGRAM, IPPROTO_UDP)};
}
+SocketKind IPv6UDPUnboundSocket(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "IPv6 UDP socket");
+ return SocketKind{
+ description, AF_INET6, type | SOCK_DGRAM, IPPROTO_UDP,
+ UnboundSocketCreator(AF_INET6, type | SOCK_DGRAM, IPPROTO_UDP)};
+}
+
SocketKind IPv4TCPUnboundSocket(int type) {
std::string description =
absl::StrCat(DescribeSocketType(type), "IPv4 TCP socket");
@@ -120,6 +138,14 @@ SocketKind IPv4TCPUnboundSocket(int type) {
UnboundSocketCreator(AF_INET, type | SOCK_STREAM, IPPROTO_TCP)};
}
+SocketKind IPv6TCPUnboundSocket(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "IPv6 TCP socket");
+ return SocketKind{
+ description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP,
+ UnboundSocketCreator(AF_INET6, type | SOCK_STREAM, IPPROTO_TCP)};
+}
+
PosixError IfAddrHelper::Load() {
Release();
RETURN_ERROR_IF_SYSCALL_FAIL(getifaddrs(&ifaddr_));
diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h
index b498a053d..072230d85 100644
--- a/test/syscalls/linux/ip_socket_test_util.h
+++ b/test/syscalls/linux/ip_socket_test_util.h
@@ -26,6 +26,31 @@
namespace gvisor {
namespace testing {
+// Possible values of the "st" field in a /proc/net/{tcp,udp} entry. Source:
+// Linux kernel, include/net/tcp_states.h.
+enum {
+ TCP_ESTABLISHED = 1,
+ TCP_SYN_SENT,
+ TCP_SYN_RECV,
+ TCP_FIN_WAIT1,
+ TCP_FIN_WAIT2,
+ TCP_TIME_WAIT,
+ TCP_CLOSE,
+ TCP_CLOSE_WAIT,
+ TCP_LAST_ACK,
+ TCP_LISTEN,
+ TCP_CLOSING,
+ TCP_NEW_SYN_RECV,
+
+ TCP_MAX_STATES
+};
+
+// Extracts the IP address from an inet sockaddr in network byte order.
+uint32_t IPFromInetSockaddr(const struct sockaddr* addr);
+
+// Extracts the port from an inet sockaddr in host byte order.
+uint16_t PortFromInetSockaddr(const struct sockaddr* addr);
+
// InterfaceIndex returns the index of the named interface.
PosixErrorOr<int> InterfaceIndex(std::string name);
@@ -67,10 +92,18 @@ SocketPairKind IPv4UDPUnboundSocketPair(int type);
// a SimpleSocket created with AF_INET, SOCK_DGRAM, and the given type.
SocketKind IPv4UDPUnboundSocket(int type);
+// IPv6UDPUnboundSocket returns a SocketKind that represents
+// a SimpleSocket created with AF_INET6, SOCK_DGRAM, and the given type.
+SocketKind IPv6UDPUnboundSocket(int type);
+
// IPv4TCPUnboundSocketPair returns a SocketKind that represents
// a SimpleSocket created with AF_INET, SOCK_STREAM and the given type.
SocketKind IPv4TCPUnboundSocket(int type);
+// IPv6TCPUnboundSocket returns a SocketKind that represents
+// a SimpleSocket created with AF_INET6, SOCK_STREAM and the given type.
+SocketKind IPv6TCPUnboundSocket(int type);
+
// IfAddrHelper is a helper class that determines the local interfaces present
// and provides functions to obtain their names, index numbers, and IP address.
class IfAddrHelper {
diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc
index 51ce323b9..930d2b940 100644
--- a/test/syscalls/linux/itimer.cc
+++ b/test/syscalls/linux/itimer.cc
@@ -336,7 +336,9 @@ int main(int argc, char** argv) {
}
if (arg == gvisor::testing::kSIGPROFFairnessIdle) {
MaskSIGPIPE();
- return gvisor::testing::TestSIGPROFFairness(absl::Milliseconds(10));
+ // Sleep time > ClockTick (10ms) exercises sleeping gVisor's
+ // kernel.cpuClockTicker.
+ return gvisor::testing::TestSIGPROFFairness(absl::Milliseconds(25));
}
}
diff --git a/test/syscalls/linux/kill.cc b/test/syscalls/linux/kill.cc
index 18ad923b8..db29bd59c 100644
--- a/test/syscalls/linux/kill.cc
+++ b/test/syscalls/linux/kill.cc
@@ -21,6 +21,7 @@
#include <csignal>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
@@ -31,8 +32,8 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_int32(scratch_uid, 65534, "scratch UID");
-DEFINE_int32(scratch_gid, 65534, "scratch GID");
+ABSL_FLAG(int32_t, scratch_uid, 65534, "scratch UID");
+ABSL_FLAG(int32_t, scratch_gid, 65534, "scratch GID");
using ::testing::Ge;
@@ -255,8 +256,8 @@ TEST(KillTest, ProcessGroups) {
TEST(KillTest, ChildDropsPrivsCannotKill) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID)));
- int uid = FLAGS_scratch_uid;
- int gid = FLAGS_scratch_gid;
+ const int uid = absl::GetFlag(FLAGS_scratch_uid);
+ const int gid = absl::GetFlag(FLAGS_scratch_gid);
// Create the child that drops privileges and tries to kill the parent.
pid_t pid = fork();
@@ -331,8 +332,8 @@ TEST(KillTest, CanSIGCONTSameSession) {
EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP)
<< "status " << status;
- int uid = FLAGS_scratch_uid;
- int gid = FLAGS_scratch_gid;
+ const int uid = absl::GetFlag(FLAGS_scratch_uid);
+ const int gid = absl::GetFlag(FLAGS_scratch_gid);
// Drop privileges only in child process, or else this parent process won't be
// able to open some log files after the test ends.
diff --git a/test/syscalls/linux/link.cc b/test/syscalls/linux/link.cc
index a91703070..dd5352954 100644
--- a/test/syscalls/linux/link.cc
+++ b/test/syscalls/linux/link.cc
@@ -22,6 +22,7 @@
#include <string>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/strings/str_cat.h"
#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
@@ -31,7 +32,7 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_int32(scratch_uid, 65534, "scratch UID");
+ABSL_FLAG(int32_t, scratch_uid, 65534, "scratch UID");
namespace gvisor {
namespace testing {
@@ -92,7 +93,8 @@ TEST(LinkTest, PermissionDenied) {
// threads have the same UIDs, so using the setuid wrapper sets all threads'
// real UID.
// Also drops capabilities.
- EXPECT_THAT(syscall(SYS_setuid, FLAGS_scratch_uid), SyscallSucceeds());
+ EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)),
+ SyscallSucceeds());
EXPECT_THAT(link(oldfile.path().c_str(), newname.c_str()),
SyscallFailsWithErrno(EPERM));
diff --git a/test/syscalls/linux/mlock.cc b/test/syscalls/linux/mlock.cc
index aee4f7d1a..283c21ed3 100644
--- a/test/syscalls/linux/mlock.cc
+++ b/test/syscalls/linux/mlock.cc
@@ -169,26 +169,24 @@ TEST(MlockallTest, Future) {
// Run this test in a separate (single-threaded) subprocess to ensure that a
// background thread doesn't try to mmap a large amount of memory, fail due
// to hitting RLIMIT_MEMLOCK, and explode the process violently.
- EXPECT_THAT(InForkedProcess([] {
- auto const mapping =
- MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)
- .ValueOrDie();
- TEST_CHECK(!IsPageMlocked(mapping.addr()));
- TEST_PCHECK(mlockall(MCL_FUTURE) == 0);
- // Ensure that mlockall(MCL_FUTURE) is turned off before the end
- // of the test, as otherwise mmaps may fail unexpectedly.
- Cleanup do_munlockall([] { TEST_PCHECK(munlockall() == 0); });
- auto const mapping2 = ASSERT_NO_ERRNO_AND_VALUE(
- MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
- TEST_CHECK(IsPageMlocked(mapping2.addr()));
- // Fire munlockall() and check that it disables
- // mlockall(MCL_FUTURE).
- do_munlockall.Release()();
- auto const mapping3 = ASSERT_NO_ERRNO_AND_VALUE(
- MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
- TEST_CHECK(!IsPageMlocked(mapping2.addr()));
- }),
- IsPosixErrorOkAndHolds(0));
+ auto const do_test = [] {
+ auto const mapping =
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE).ValueOrDie();
+ TEST_CHECK(!IsPageMlocked(mapping.addr()));
+ TEST_PCHECK(mlockall(MCL_FUTURE) == 0);
+ // Ensure that mlockall(MCL_FUTURE) is turned off before the end of the
+ // test, as otherwise mmaps may fail unexpectedly.
+ Cleanup do_munlockall([] { TEST_PCHECK(munlockall() == 0); });
+ auto const mapping2 =
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE).ValueOrDie();
+ TEST_CHECK(IsPageMlocked(mapping2.addr()));
+ // Fire munlockall() and check that it disables mlockall(MCL_FUTURE).
+ do_munlockall.Release()();
+ auto const mapping3 =
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE).ValueOrDie();
+ TEST_CHECK(!IsPageMlocked(mapping2.addr()));
+ };
+ EXPECT_THAT(InForkedProcess(do_test), IsPosixErrorOkAndHolds(0));
}
TEST(MunlockallTest, Basic) {
diff --git a/test/syscalls/linux/mremap.cc b/test/syscalls/linux/mremap.cc
index 64e435cb7..f0e5f7d82 100644
--- a/test/syscalls/linux/mremap.cc
+++ b/test/syscalls/linux/mremap.cc
@@ -35,17 +35,6 @@ namespace testing {
namespace {
-// Wrapper for mremap that returns a PosixErrorOr<>, since the return type of
-// void* isn't directly compatible with SyscallSucceeds.
-PosixErrorOr<void*> Mremap(void* old_address, size_t old_size, size_t new_size,
- int flags, void* new_address) {
- void* rv = mremap(old_address, old_size, new_size, flags, new_address);
- if (rv == MAP_FAILED) {
- return PosixError(errno, "mremap failed");
- }
- return rv;
-}
-
// Fixture for mremap tests parameterized by mmap flags.
using MremapParamTest = ::testing::TestWithParam<int>;
diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc
index e0525f386..2b1df52ce 100644
--- a/test/syscalls/linux/open.cc
+++ b/test/syscalls/linux/open.cc
@@ -21,6 +21,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
#include "test/syscalls/linux/file_base.h"
#include "test/util/capability_util.h"
#include "test/util/cleanup.h"
diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc
index 7a3379b9e..37b4e6575 100644
--- a/test/syscalls/linux/packet_socket.cc
+++ b/test/syscalls/linux/packet_socket.cc
@@ -83,9 +83,15 @@ void SendUDPMessage(int sock) {
// Send an IP packet and make sure ETH_P_<something else> doesn't pick it up.
TEST(BasicCookedPacketTest, WrongType) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // TODO(b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
FileDescriptor sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP));
@@ -118,18 +124,27 @@ class CookedPacketTest : public ::testing::TestWithParam<int> {
};
void CookedPacketTest::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // TODO(b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())),
SyscallSucceeds());
}
void CookedPacketTest::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // TODO(b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
- EXPECT_THAT(close(socket_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(socket_), SyscallSucceeds());
+ }
}
int CookedPacketTest::GetLoopbackIndex() {
@@ -142,9 +157,6 @@ int CookedPacketTest::GetLoopbackIndex() {
// Receive via a packet socket.
TEST_P(CookedPacketTest, Receive) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
- SKIP_IF(IsRunningOnGvisor());
-
// Let's use a simple IP payload: a UDP datagram.
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
@@ -201,9 +213,6 @@ TEST_P(CookedPacketTest, Receive) {
// Send via a packet socket.
TEST_P(CookedPacketTest, Send) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
- SKIP_IF(IsRunningOnGvisor());
-
// Let's send a UDP packet and receive it using a regular UDP socket.
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc
index 9e96460ee..6491453b6 100644
--- a/test/syscalls/linux/packet_socket_raw.cc
+++ b/test/syscalls/linux/packet_socket_raw.cc
@@ -97,9 +97,15 @@ class RawPacketTest : public ::testing::TestWithParam<int> {
};
void RawPacketTest::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // TODO(b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_PACKET, SOCK_RAW, htons(GetParam())),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
if (!IsRunningOnGvisor()) {
FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE(
Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDONLY));
@@ -119,10 +125,13 @@ void RawPacketTest::SetUp() {
}
void RawPacketTest::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // TODO(b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
- EXPECT_THAT(close(socket_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(socket_), SyscallSucceeds());
+ }
}
int RawPacketTest::GetLoopbackIndex() {
@@ -135,9 +144,6 @@ int RawPacketTest::GetLoopbackIndex() {
// Receive via a packet socket.
TEST_P(RawPacketTest, Receive) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
- SKIP_IF(IsRunningOnGvisor());
-
// Let's use a simple IP payload: a UDP datagram.
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
@@ -208,9 +214,6 @@ TEST_P(RawPacketTest, Receive) {
// Send via a packet socket.
TEST_P(RawPacketTest, Send) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
- SKIP_IF(IsRunningOnGvisor());
-
// Let's send a UDP packet and receive it using a regular UDP socket.
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc
index 65afb90f3..10e2a6dfc 100644
--- a/test/syscalls/linux/pipe.cc
+++ b/test/syscalls/linux/pipe.cc
@@ -168,6 +168,20 @@ TEST_P(PipeTest, Write) {
EXPECT_EQ(wbuf, rbuf);
}
+TEST_P(PipeTest, WritePage) {
+ SKIP_IF(!CreateBlocking());
+
+ std::vector<char> wbuf(kPageSize);
+ RandomizeBuffer(wbuf.data(), wbuf.size());
+ std::vector<char> rbuf(wbuf.size());
+
+ ASSERT_THAT(write(wfd_.get(), wbuf.data(), wbuf.size()),
+ SyscallSucceedsWithValue(wbuf.size()));
+ ASSERT_THAT(read(rfd_.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(rbuf.size()));
+ EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), wbuf.size()), 0);
+}
+
TEST_P(PipeTest, NonBlocking) {
SKIP_IF(!CreateNonBlocking());
diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc
index bd1779557..d07571a5f 100644
--- a/test/syscalls/linux/prctl.cc
+++ b/test/syscalls/linux/prctl.cc
@@ -21,6 +21,7 @@
#include <string>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "test/util/capability_util.h"
#include "test/util/cleanup.h"
#include "test/util/multiprocess_util.h"
@@ -28,9 +29,9 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_bool(prctl_no_new_privs_test_child, false,
- "If true, exit with the return value of prctl(PR_GET_NO_NEW_PRIVS) "
- "plus an offset (see test source).");
+ABSL_FLAG(bool, prctl_no_new_privs_test_child, false,
+ "If true, exit with the return value of prctl(PR_GET_NO_NEW_PRIVS) "
+ "plus an offset (see test source).");
namespace gvisor {
namespace testing {
@@ -220,7 +221,7 @@ TEST(PrctlTest, RootDumpability) {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (FLAGS_prctl_no_new_privs_test_child) {
+ if (absl::GetFlag(FLAGS_prctl_no_new_privs_test_child)) {
exit(gvisor::testing::kPrctlNoNewPrivsTestChildExitBase +
prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0));
}
diff --git a/test/syscalls/linux/prctl_setuid.cc b/test/syscalls/linux/prctl_setuid.cc
index 00dd6523e..30f0d75b3 100644
--- a/test/syscalls/linux/prctl_setuid.cc
+++ b/test/syscalls/linux/prctl_setuid.cc
@@ -14,9 +14,11 @@
#include <sched.h>
#include <sys/prctl.h>
+
#include <string>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "test/util/capability_util.h"
#include "test/util/logging.h"
#include "test/util/multiprocess_util.h"
@@ -24,12 +26,12 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_int32(scratch_uid, 65534, "scratch UID");
+ABSL_FLAG(int32_t, scratch_uid, 65534, "scratch UID");
// This flag is used to verify that after an exec PR_GET_KEEPCAPS
// returns 0, the return code will be offset by kPrGetKeepCapsExitBase.
-DEFINE_bool(prctl_pr_get_keepcaps, false,
- "If true the test will verify that prctl with pr_get_keepcaps"
- "returns 0. The test will exit with the result of that check.");
+ABSL_FLAG(bool, prctl_pr_get_keepcaps, false,
+ "If true the test will verify that prctl with pr_get_keepcaps"
+ "returns 0. The test will exit with the result of that check.");
// These tests exist seperately from prctl because we need to start
// them as root. Setuid() has the behavior that permissions are fully
@@ -113,10 +115,12 @@ TEST_F(PrctlKeepCapsSetuidTest, SetUidNoKeepCaps) {
// call to only apply to this task. POSIX threads, however, require that
// all threads have the same UIDs, so using the setuid wrapper sets all
// threads' real UID.
- EXPECT_THAT(syscall(SYS_setuid, FLAGS_scratch_uid), SyscallSucceeds());
+ EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)),
+ SyscallSucceeds());
// Verify that we changed uid.
- EXPECT_THAT(getuid(), SyscallSucceedsWithValue(FLAGS_scratch_uid));
+ EXPECT_THAT(getuid(),
+ SyscallSucceedsWithValue(absl::GetFlag(FLAGS_scratch_uid)));
// Verify we lost the capability in the effective set, this always happens.
TEST_CHECK(!HaveCapability(CAP_SYS_ADMIN).ValueOrDie());
@@ -157,10 +161,12 @@ TEST_F(PrctlKeepCapsSetuidTest, SetUidKeepCaps) {
// call to only apply to this task. POSIX threads, however, require that
// all threads have the same UIDs, so using the setuid wrapper sets all
// threads' real UID.
- EXPECT_THAT(syscall(SYS_setuid, FLAGS_scratch_uid), SyscallSucceeds());
+ EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)),
+ SyscallSucceeds());
// Verify that we changed uid.
- EXPECT_THAT(getuid(), SyscallSucceedsWithValue(FLAGS_scratch_uid));
+ EXPECT_THAT(getuid(),
+ SyscallSucceedsWithValue(absl::GetFlag(FLAGS_scratch_uid)));
// Verify we lost the capability in the effective set, this always happens.
TEST_CHECK(!HaveCapability(CAP_SYS_ADMIN).ValueOrDie());
@@ -253,7 +259,7 @@ TEST_F(PrctlKeepCapsSetuidTest, PrGetKeepCaps) {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (FLAGS_prctl_pr_get_keepcaps) {
+ if (absl::GetFlag(FLAGS_prctl_pr_get_keepcaps)) {
return gvisor::testing::kPrGetKeepCapsExitBase +
prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0);
}
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index b440ba0df..e4c030bbb 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -440,6 +440,11 @@ TEST(ProcSelfAuxv, EntryPresence) {
EXPECT_EQ(auxv_entries.count(AT_PHENT), 1);
EXPECT_EQ(auxv_entries.count(AT_PHNUM), 1);
EXPECT_EQ(auxv_entries.count(AT_BASE), 1);
+ EXPECT_EQ(auxv_entries.count(AT_UID), 1);
+ EXPECT_EQ(auxv_entries.count(AT_EUID), 1);
+ EXPECT_EQ(auxv_entries.count(AT_GID), 1);
+ EXPECT_EQ(auxv_entries.count(AT_EGID), 1);
+ EXPECT_EQ(auxv_entries.count(AT_SECURE), 1);
EXPECT_EQ(auxv_entries.count(AT_CLKTCK), 1);
EXPECT_EQ(auxv_entries.count(AT_RANDOM), 1);
EXPECT_EQ(auxv_entries.count(AT_EXECFN), 1);
@@ -1602,9 +1607,9 @@ class BlockingChild {
}
mutable absl::Mutex mu_;
- bool stop_ GUARDED_BY(mu_) = false;
+ bool stop_ ABSL_GUARDED_BY(mu_) = false;
pid_t tid_;
- bool tid_ready_ GUARDED_BY(mu_) = false;
+ bool tid_ready_ ABSL_GUARDED_BY(mu_) = false;
// Must be last to ensure that the destructor for the thread is run before
// any other member of the object is destroyed.
@@ -1882,7 +1887,9 @@ void CheckDuplicatesRecursively(std::string path) {
errno = 0;
DIR* dir = opendir(path.c_str());
if (dir == nullptr) {
- ASSERT_THAT(errno, ::testing::AnyOf(EPERM, EACCES)) << path;
+ // Ignore any directories we can't read or missing directories as the
+ // directory could have been deleted/mutated from the time the parent
+ // directory contents were read.
return;
}
auto dir_closer = Cleanup([&dir]() { closedir(dir); });
diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc
index 03d0665eb..efdaf202b 100644
--- a/test/syscalls/linux/proc_net.cc
+++ b/test/syscalls/linux/proc_net.cc
@@ -14,6 +14,7 @@
#include "gtest/gtest.h"
#include "gtest/gtest.h"
+#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
#include "test/util/test_util.h"
@@ -27,7 +28,7 @@ TEST(ProcNetIfInet6, Format) {
EXPECT_THAT(ifinet6,
::testing::MatchesRegex(
// Ex: "00000000000000000000000000000001 01 80 10 80 lo\n"
- "^([a-f\\d]{32}( [a-f\\d]{2}){4} +[a-z][a-z\\d]*\\n)+$"));
+ "^([a-f0-9]{32}( [a-f0-9]{2}){4} +[a-z][a-z0-9]*\n)+$"));
}
TEST(ProcSysNetIpv4Sack, Exists) {
@@ -35,6 +36,8 @@ TEST(ProcSysNetIpv4Sack, Exists) {
}
TEST(ProcSysNetIpv4Sack, CanReadAndWrite) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE))));
+
auto const fd =
ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/sys/net/ipv4/tcp_sack", O_RDWR));
diff --git a/test/syscalls/linux/proc_net_tcp.cc b/test/syscalls/linux/proc_net_tcp.cc
index 498f62d9c..f61795592 100644
--- a/test/syscalls/linux/proc_net_tcp.cc
+++ b/test/syscalls/linux/proc_net_tcp.cc
@@ -38,25 +38,6 @@ constexpr char kProcNetTCPHeader[] =
"retrnsmt uid timeout inode "
" ";
-// Possible values of the "st" field in a /proc/net/tcp entry. Source: Linux
-// kernel, include/net/tcp_states.h.
-enum {
- TCP_ESTABLISHED = 1,
- TCP_SYN_SENT,
- TCP_SYN_RECV,
- TCP_FIN_WAIT1,
- TCP_FIN_WAIT2,
- TCP_TIME_WAIT,
- TCP_CLOSE,
- TCP_CLOSE_WAIT,
- TCP_LAST_ACK,
- TCP_LISTEN,
- TCP_CLOSING,
- TCP_NEW_SYN_RECV,
-
- TCP_MAX_STATES
-};
-
// TCPEntry represents a single entry from /proc/net/tcp.
struct TCPEntry {
uint32_t local_addr;
@@ -70,42 +51,35 @@ struct TCPEntry {
uint64_t inode;
};
-uint32_t IP(const struct sockaddr* addr) {
- auto* in_addr = reinterpret_cast<const struct sockaddr_in*>(addr);
- return in_addr->sin_addr.s_addr;
-}
-
-uint16_t Port(const struct sockaddr* addr) {
- auto* in_addr = reinterpret_cast<const struct sockaddr_in*>(addr);
- return ntohs(in_addr->sin_port);
-}
-
// Finds the first entry in 'entries' for which 'predicate' returns true.
-// Returns true on match, and sets 'match' to point to the matching entry.
-bool FindBy(std::vector<TCPEntry> entries, TCPEntry* match,
+// Returns true on match, and sets 'match' to a copy of the matching entry. If
+// 'match' is null, it's ignored.
+bool FindBy(const std::vector<TCPEntry>& entries, TCPEntry* match,
std::function<bool(const TCPEntry&)> predicate) {
- for (int i = 0; i < entries.size(); ++i) {
- if (predicate(entries[i])) {
- *match = entries[i];
+ for (const TCPEntry& entry : entries) {
+ if (predicate(entry)) {
+ if (match != nullptr) {
+ *match = entry;
+ }
return true;
}
}
return false;
}
-bool FindByLocalAddr(std::vector<TCPEntry> entries, TCPEntry* match,
+bool FindByLocalAddr(const std::vector<TCPEntry>& entries, TCPEntry* match,
const struct sockaddr* addr) {
- uint32_t host = IP(addr);
- uint16_t port = Port(addr);
+ uint32_t host = IPFromInetSockaddr(addr);
+ uint16_t port = PortFromInetSockaddr(addr);
return FindBy(entries, match, [host, port](const TCPEntry& e) {
return (e.local_addr == host && e.local_port == port);
});
}
-bool FindByRemoteAddr(std::vector<TCPEntry> entries, TCPEntry* match,
+bool FindByRemoteAddr(const std::vector<TCPEntry>& entries, TCPEntry* match,
const struct sockaddr* addr) {
- uint32_t host = IP(addr);
- uint16_t port = Port(addr);
+ uint32_t host = IPFromInetSockaddr(addr);
+ uint16_t port = PortFromInetSockaddr(addr);
return FindBy(entries, match, [host, port](const TCPEntry& e) {
return (e.remote_addr == host && e.remote_port == port);
});
@@ -120,7 +94,7 @@ PosixErrorOr<std::vector<TCPEntry>> ProcNetTCPEntries() {
std::vector<TCPEntry> entries;
std::vector<std::string> lines = StrSplit(content, '\n');
std::cerr << "<contents of /proc/net/tcp>" << std::endl;
- for (std::string line : lines) {
+ for (const std::string& line : lines) {
std::cerr << line << std::endl;
if (!found_header) {
@@ -204,9 +178,8 @@ TEST(ProcNetTCP, BindAcceptConnect) {
EXPECT_EQ(entries.size(), 2);
}
- TCPEntry e;
- EXPECT_TRUE(FindByLocalAddr(entries, &e, sockets->first_addr()));
- EXPECT_TRUE(FindByRemoteAddr(entries, &e, sockets->first_addr()));
+ EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->first_addr()));
+ EXPECT_TRUE(FindByRemoteAddr(entries, nullptr, sockets->first_addr()));
}
TEST(ProcNetTCP, InodeReasonable) {
@@ -261,8 +234,8 @@ TEST(ProcNetTCP, State) {
FileDescriptor accepted =
ASSERT_NO_ERRNO_AND_VALUE(Accept(server->get(), nullptr, nullptr));
- const uint32_t accepted_local_host = IP(&addr);
- const uint16_t accepted_local_port = Port(&addr);
+ const uint32_t accepted_local_host = IPFromInetSockaddr(&addr);
+ const uint16_t accepted_local_port = PortFromInetSockaddr(&addr);
entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCPEntries());
TCPEntry accepted_entry;
@@ -277,6 +250,247 @@ TEST(ProcNetTCP, State) {
EXPECT_EQ(accepted_entry.state, TCP_ESTABLISHED);
}
+constexpr char kProcNetTCP6Header[] =
+ " sl local_address remote_address"
+ " st tx_queue rx_queue tr tm->when retrnsmt"
+ " uid timeout inode";
+
+// TCP6Entry represents a single entry from /proc/net/tcp6.
+struct TCP6Entry {
+ struct in6_addr local_addr;
+ uint16_t local_port;
+
+ struct in6_addr remote_addr;
+ uint16_t remote_port;
+
+ uint64_t state;
+ uint64_t uid;
+ uint64_t inode;
+};
+
+bool IPv6AddrEqual(const struct in6_addr* a1, const struct in6_addr* a2) {
+ return memcmp(a1, a2, sizeof(struct in6_addr)) == 0;
+}
+
+// Scans 'entries' in order and stops at the first one accepted by
+// 'predicate'. On a hit, the entry is copied into '*match' (the copy is
+// skipped when 'match' is null) and true is returned. Returns false when no
+// entry satisfies 'predicate'.
+bool FindBy6(const std::vector<TCP6Entry>& entries, TCP6Entry* match,
+             std::function<bool(const TCP6Entry&)> predicate) {
+  for (auto it = entries.begin(); it != entries.end(); ++it) {
+    if (!predicate(*it)) {
+      continue;
+    }
+    if (match != nullptr) {
+      *match = *it;
+    }
+    return true;
+  }
+  return false;
+}
+
+const struct in6_addr* IP6FromInetSockaddr(const struct sockaddr* addr) {
+ auto* addr6 = reinterpret_cast<const struct sockaddr_in6*>(addr);
+ return &addr6->sin6_addr;
+}
+
+// Finds the first entry whose local IPv6 address and local port both equal
+// the address and port carried by 'addr'. See FindBy6 for how 'match' and the
+// return value are used.
+bool FindByLocalAddr6(const std::vector<TCP6Entry>& entries, TCP6Entry* match,
+                      const struct sockaddr* addr) {
+  const struct in6_addr* const want_ip = IP6FromInetSockaddr(addr);
+  const uint16_t want_port = PortFromInetSockaddr(addr);
+  auto same_local_endpoint = [want_ip, want_port](const TCP6Entry& candidate) {
+    if (!IPv6AddrEqual(&candidate.local_addr, want_ip)) {
+      return false;
+    }
+    return candidate.local_port == want_port;
+  };
+  return FindBy6(entries, match, same_local_endpoint);
+}
+
+// Finds the first entry whose remote IPv6 address and remote port both equal
+// the address and port carried by 'addr'. See FindBy6 for how 'match' and the
+// return value are used.
+bool FindByRemoteAddr6(const std::vector<TCP6Entry>& entries, TCP6Entry* match,
+                       const struct sockaddr* addr) {
+  const struct in6_addr* const want_ip = IP6FromInetSockaddr(addr);
+  const uint16_t want_port = PortFromInetSockaddr(addr);
+  auto same_remote_endpoint = [want_ip, want_port](const TCP6Entry& candidate) {
+    if (!IPv6AddrEqual(&candidate.remote_addr, want_ip)) {
+      return false;
+    }
+    return candidate.remote_port == want_port;
+  };
+  return FindBy6(entries, match, same_remote_endpoint);
+}
+
+// Parses the hex-encoded IPv6 address field of a /proc/net/tcp6 entry into
+// 'addr'.
+//
+// 's' is read as four consecutive 8-digit hex words. Each word's in-memory
+// representation is then copied, in order, into the next four bytes of
+// addr->s6_addr.
+//
+// If fewer than four words can be parsed, a non-fatal test failure is recorded
+// and every word that was not parsed is stored as zero.
+void ReadIPv6Address(const std::string& s, struct in6_addr* addr) {
+  // Zero-initialized so a short or malformed field never leaves indeterminate
+  // bytes in 'addr'.
+  uint32_t words[4] = {0, 0, 0, 0};
+  const char* fmt = "%08X%08X%08X%08X";
+  EXPECT_EQ(sscanf(s.c_str(), fmt, &words[0], &words[1], &words[2], &words[3]),
+            4);
+
+  // Copy with memcpy rather than storing through a casted uint32_t*: s6_addr
+  // is a byte array, so the casted store is an aliasing/alignment hazard.
+  uint8_t* bytes = addr->s6_addr;
+  for (size_t i = 0; i < 4; ++i) {
+    memcpy(bytes + i * sizeof(uint32_t), &words[i], sizeof(uint32_t));
+  }
+}
+
+// Returns a parsed representation of /proc/net/tcp6 entries.
+//
+// The raw file contents are echoed to stderr, line by line, to aid debugging.
+// The first line is expected to equal kProcNetTCP6Header (checked with a
+// non-fatal expectation); empty lines are skipped; every other line is parsed
+// into a TCP6Entry.
+//
+// Returns an error if the file cannot be read, if a numeric field fails to
+// parse, or (EINVAL) if an entry has fewer fields than the parser reads.
+PosixErrorOr<std::vector<TCP6Entry>> ProcNetTCP6Entries() {
+  std::string content;
+  RETURN_IF_ERRNO(GetContents("/proc/net/tcp6", &content));
+
+  // The parser reads fields[1] through fields[13] (the inode column), so an
+  // entry needs at least this many fields.
+  constexpr size_t kMinFields = 14;
+
+  bool found_header = false;
+  std::vector<TCP6Entry> entries;
+  std::vector<std::string> lines = StrSplit(content, '\n');
+  std::cerr << "<contents of /proc/net/tcp6>" << std::endl;
+  for (const std::string& line : lines) {
+    std::cerr << line << std::endl;
+
+    if (!found_header) {
+      EXPECT_EQ(line, kProcNetTCP6Header);
+      found_header = true;
+      continue;
+    }
+    if (line.empty()) {
+      continue;
+    }
+
+    // Parse a single entry from /proc/net/tcp6.
+    //
+    // Example entries:
+    //
+    // clang-format off
+    //
+    // sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
+    // 0: 00000000000000000000000000000000:1F90 00000000000000000000000000000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 876340 1 ffff8803da9c9380 100 0 0 10 0
+    // 1: 00000000000000000000000000000000:C350 00000000000000000000000000000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 876987 1 ffff8803ec408000 100 0 0 10 0
+    // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
+    // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
+    //
+    // clang-format on
+    //
+    // After splitting on ':' and ' ' the fields used below are:
+    //   1: local address   2: local port
+    //   3: remote address  4: remote port
+    //   5: st (state)     11: uid     13: inode
+
+    TCP6Entry entry;
+    std::vector<std::string> fields =
+        StrSplit(line, absl::ByAnyChar(": "), absl::SkipEmpty());
+
+    // Reject truncated or malformed lines instead of indexing past the end of
+    // 'fields'.
+    if (fields.size() < kMinFields) {
+      return PosixError(
+          EINVAL, StrCat("malformed /proc/net/tcp6 entry: '", line, "'"));
+    }
+
+    ReadIPv6Address(fields[1], &entry.local_addr);
+    ASSIGN_OR_RETURN_ERRNO(entry.local_port, AtoiBase(fields[2], 16));
+    ReadIPv6Address(fields[3], &entry.remote_addr);
+    ASSIGN_OR_RETURN_ERRNO(entry.remote_port, AtoiBase(fields[4], 16));
+    ASSIGN_OR_RETURN_ERRNO(entry.state, AtoiBase(fields[5], 16));
+    ASSIGN_OR_RETURN_ERRNO(entry.uid, Atoi<uint64_t>(fields[11]));
+    ASSIGN_OR_RETURN_ERRNO(entry.inode, Atoi<uint64_t>(fields[13]));
+
+    entries.push_back(entry);
+  }
+  std::cerr << "<end of /proc/net/tcp6>" << std::endl;
+
+  return entries;
+}
+
+TEST(ProcNetTCP6, Exists) {
+ const std::string content =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/tcp6"));
+ const std::string header_line = StrCat(kProcNetTCP6Header, "\n");
+ if (IsRunningOnGvisor()) {
+ // Should be just the header since we don't have any tcp sockets yet.
+ EXPECT_EQ(content, header_line);
+ } else {
+ // On a general linux machine, we could have arbitrary sockets on the system,
+ // so just check the header.
+ EXPECT_THAT(content, ::testing::StartsWith(header_line));
+ }
+}
+
+TEST(ProcNetTCP6, EntryUID) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPAcceptBindSocketPair(0).Create());
+ std::vector<TCP6Entry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries());
+ TCP6Entry e;
+
+ ASSERT_TRUE(FindByLocalAddr6(entries, &e, sockets->first_addr()));
+ EXPECT_EQ(e.uid, geteuid());
+ ASSERT_TRUE(FindByRemoteAddr6(entries, &e, sockets->first_addr()));
+ EXPECT_EQ(e.uid, geteuid());
+}
+
+TEST(ProcNetTCP6, BindAcceptConnect) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPAcceptBindSocketPair(0).Create());
+ std::vector<TCP6Entry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries());
+ // We can only make assertions about the total number of entries if we control
+ // the entire "machine".
+ if (IsRunningOnGvisor()) {
+ EXPECT_EQ(entries.size(), 2);
+ }
+
+ EXPECT_TRUE(FindByLocalAddr6(entries, nullptr, sockets->first_addr()));
+ EXPECT_TRUE(FindByRemoteAddr6(entries, nullptr, sockets->first_addr()));
+}
+
+TEST(ProcNetTCP6, InodeReasonable) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPAcceptBindSocketPair(0).Create());
+ std::vector<TCP6Entry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries());
+
+ TCP6Entry accepted_entry;
+
+ ASSERT_TRUE(
+ FindByLocalAddr6(entries, &accepted_entry, sockets->first_addr()));
+ EXPECT_NE(accepted_entry.inode, 0);
+
+ TCP6Entry client_entry;
+ ASSERT_TRUE(FindByRemoteAddr6(entries, &client_entry, sockets->first_addr()));
+ EXPECT_NE(client_entry.inode, 0);
+ EXPECT_NE(accepted_entry.inode, client_entry.inode);
+}
+
+TEST(ProcNetTCP6, State) {
+ std::unique_ptr<FileDescriptor> server =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPUnboundSocket(0).Create());
+
+ auto test_addr = V6Loopback();
+ ASSERT_THAT(
+ bind(server->get(), reinterpret_cast<struct sockaddr*>(&test_addr.addr),
+ test_addr.addr_len),
+ SyscallSucceeds());
+
+ struct sockaddr_in6 addr6;
+ socklen_t addrlen = sizeof(struct sockaddr_in6);
+ auto* addr = reinterpret_cast<struct sockaddr*>(&addr6);
+ ASSERT_THAT(getsockname(server->get(), addr, &addrlen), SyscallSucceeds());
+ ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6));
+
+ ASSERT_THAT(listen(server->get(), 10), SyscallSucceeds());
+ std::vector<TCP6Entry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries());
+ TCP6Entry listen_entry;
+
+ ASSERT_TRUE(FindByLocalAddr6(entries, &listen_entry, addr));
+ EXPECT_EQ(listen_entry.state, TCP_LISTEN);
+
+ std::unique_ptr<FileDescriptor> client =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPUnboundSocket(0).Create());
+ ASSERT_THAT(RetryEINTR(connect)(client->get(), addr, addrlen),
+ SyscallSucceeds());
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries());
+ ASSERT_TRUE(FindByLocalAddr6(entries, &listen_entry, addr));
+ EXPECT_EQ(listen_entry.state, TCP_LISTEN);
+ TCP6Entry client_entry;
+ ASSERT_TRUE(FindByRemoteAddr6(entries, &client_entry, addr));
+ EXPECT_EQ(client_entry.state, TCP_ESTABLISHED);
+
+ FileDescriptor accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(server->get(), nullptr, nullptr));
+
+ const struct in6_addr* local = IP6FromInetSockaddr(addr);
+ const uint16_t accepted_local_port = PortFromInetSockaddr(addr);
+
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries());
+ TCP6Entry accepted_entry;
+ ASSERT_TRUE(FindBy6(
+ entries, &accepted_entry,
+ [client_entry, local, accepted_local_port](const TCP6Entry& e) {
+ return IPv6AddrEqual(&e.local_addr, local) &&
+ e.local_port == accepted_local_port &&
+ IPv6AddrEqual(&e.remote_addr, &client_entry.local_addr) &&
+ e.remote_port == client_entry.local_port;
+ }));
+ EXPECT_EQ(accepted_entry.state, TCP_ESTABLISHED);
+}
+
} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/proc_net_udp.cc b/test/syscalls/linux/proc_net_udp.cc
new file mode 100644
index 000000000..369df8e0e
--- /dev/null
+++ b/test/syscalls/linux/proc_net_udp.cc
@@ -0,0 +1,309 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <cerrno>
+#include <cstdint>
+#include <functional>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+using absl::StrCat;
+using absl::StrFormat;
+using absl::StrSplit;
+
+constexpr char kProcNetUDPHeader[] =
+ " sl local_address rem_address st tx_queue rx_queue tr tm->when "
+ "retrnsmt uid timeout inode ref pointer drops ";
+
+// UDPEntry represents a single entry from /proc/net/udp.
+struct UDPEntry {
+ uint32_t local_addr;
+ uint16_t local_port;
+
+ uint32_t remote_addr;
+ uint16_t remote_port;
+
+ uint64_t state;
+ uint64_t uid;
+ uint64_t inode;
+};
+
+std::string DescribeFirstInetSocket(const SocketPair& sockets) {
+ const struct sockaddr* addr = sockets.first_addr();
+ return StrFormat("First test socket: fd:%d %8X:%4X", sockets.first_fd(),
+ IPFromInetSockaddr(addr), PortFromInetSockaddr(addr));
+}
+
+std::string DescribeSecondInetSocket(const SocketPair& sockets) {
+ const struct sockaddr* addr = sockets.second_addr();
+ return StrFormat("Second test socket fd:%d %8X:%4X", sockets.second_fd(),
+ IPFromInetSockaddr(addr), PortFromInetSockaddr(addr));
+}
+
+// Scans 'entries' in order and stops at the first one accepted by
+// 'predicate'. On a hit, the entry is copied into '*match' (the copy is
+// skipped when 'match' is null) and true is returned. Returns false when no
+// entry satisfies 'predicate'.
+bool FindBy(const std::vector<UDPEntry>& entries, UDPEntry* match,
+            std::function<bool(const UDPEntry&)> predicate) {
+  for (auto it = entries.begin(); it != entries.end(); ++it) {
+    if (!predicate(*it)) {
+      continue;
+    }
+    if (match != nullptr) {
+      *match = *it;
+    }
+    return true;
+  }
+  return false;
+}
+
+bool FindByLocalAddr(const std::vector<UDPEntry>& entries, UDPEntry* match,
+ const struct sockaddr* addr) {
+ uint32_t host = IPFromInetSockaddr(addr);
+ uint16_t port = PortFromInetSockaddr(addr);
+ return FindBy(entries, match, [host, port](const UDPEntry& e) {
+ return (e.local_addr == host && e.local_port == port);
+ });
+}
+
+bool FindByRemoteAddr(const std::vector<UDPEntry>& entries, UDPEntry* match,
+ const struct sockaddr* addr) {
+ uint32_t host = IPFromInetSockaddr(addr);
+ uint16_t port = PortFromInetSockaddr(addr);
+ return FindBy(entries, match, [host, port](const UDPEntry& e) {
+ return (e.remote_addr == host && e.remote_port == port);
+ });
+}
+
+PosixErrorOr<uint64_t> InodeFromSocketFD(int fd) {
+ ASSIGN_OR_RETURN_ERRNO(struct stat s, Fstat(fd));
+ if (!S_ISSOCK(s.st_mode)) {
+ return PosixError(EINVAL, StrFormat("FD %d is not a socket", fd));
+ }
+ return s.st_ino;
+}
+
+PosixErrorOr<bool> FindByFD(const std::vector<UDPEntry>& entries,
+ UDPEntry* match, int fd) {
+ ASSIGN_OR_RETURN_ERRNO(uint64_t inode, InodeFromSocketFD(fd));
+ return FindBy(entries, match,
+ [inode](const UDPEntry& e) { return (e.inode == inode); });
+}
+
+// Returns a parsed representation of /proc/net/udp entries.
+//
+// The raw file contents are echoed to stderr, line by line, to aid debugging.
+// The first line is expected to equal kProcNetUDPHeader (checked with a
+// non-fatal expectation); empty lines are skipped; every other line is parsed
+// into a UDPEntry.
+//
+// Returns an error if the file cannot be read, if a numeric field fails to
+// parse, or (EINVAL) if an entry has fewer fields than the parser reads.
+PosixErrorOr<std::vector<UDPEntry>> ProcNetUDPEntries() {
+  std::string content;
+  RETURN_IF_ERRNO(GetContents("/proc/net/udp", &content));
+
+  // The parser reads fields[0] through fields[13] (the inode column), so an
+  // entry needs at least this many fields.
+  constexpr size_t kMinFields = 14;
+
+  bool found_header = false;
+  std::vector<UDPEntry> entries;
+  std::vector<std::string> lines = StrSplit(content, '\n');
+  std::cerr << "<contents of /proc/net/udp>" << std::endl;
+  for (const std::string& line : lines) {
+    std::cerr << line << std::endl;
+
+    if (!found_header) {
+      EXPECT_EQ(line, kProcNetUDPHeader);
+      found_header = true;
+      continue;
+    }
+    if (line.empty()) {
+      continue;
+    }
+
+    // Parse a single entry from /proc/net/udp.
+    //
+    // Example entries:
+    //
+    // clang-format off
+    //
+    // sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode ref pointer drops
+    // 3503: 0100007F:0035 00000000:0000 07 00000000:00000000 00:00000000 00000000 0 0 33317 2 0000000000000000 0
+    // 3518: 00000000:0044 00000000:0000 07 00000000:00000000 00:00000000 00000000 0 0 40394 2 0000000000000000 0
+    // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
+    // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
+    //
+    // clang-format on
+    //
+    // After splitting on ':' and ' ' the fields used below are:
+    //   0: sl              1: local address   2: local port
+    //   3: remote address  4: remote port     5: st (state)
+    //   8: tr              9: tm->when       10: retrnsmt
+    //  11: uid            12: timeout        13: inode
+
+    UDPEntry entry;
+    std::vector<std::string> fields =
+        StrSplit(line, absl::ByAnyChar(": "), absl::SkipEmpty());
+
+    // Reject truncated or malformed lines instead of indexing past the end of
+    // 'fields'.
+    if (fields.size() < kMinFields) {
+      return PosixError(
+          EINVAL, StrFormat("malformed /proc/net/udp entry: '%s'", line));
+    }
+
+    ASSIGN_OR_RETURN_ERRNO(entry.local_addr, AtoiBase(fields[1], 16));
+    ASSIGN_OR_RETURN_ERRNO(entry.local_port, AtoiBase(fields[2], 16));
+
+    ASSIGN_OR_RETURN_ERRNO(entry.remote_addr, AtoiBase(fields[3], 16));
+    ASSIGN_OR_RETURN_ERRNO(entry.remote_port, AtoiBase(fields[4], 16));
+
+    ASSIGN_OR_RETURN_ERRNO(entry.state, AtoiBase(fields[5], 16));
+    ASSIGN_OR_RETURN_ERRNO(entry.uid, Atoi<uint64_t>(fields[11]));
+    ASSIGN_OR_RETURN_ERRNO(entry.inode, Atoi<uint64_t>(fields[13]));
+
+    // Linux shares internal data structures between TCP and UDP sockets. The
+    // proc entries for UDP sockets share some fields with TCP sockets, but
+    // these fields should always be zero as they're not meaningful for UDP
+    // sockets.
+    EXPECT_EQ(fields[8], "00") << StrFormat("sl:%s, tr", fields[0]);
+    EXPECT_EQ(fields[9], "00000000") << StrFormat("sl:%s, tm->when", fields[0]);
+    EXPECT_EQ(fields[10], "00000000")
+        << StrFormat("sl:%s, retrnsmt", fields[0]);
+    EXPECT_EQ(fields[12], "0") << StrFormat("sl:%s, timeout", fields[0]);
+
+    entries.push_back(entry);
+  }
+  std::cerr << "<end of /proc/net/udp>" << std::endl;
+
+  return entries;
+}
+
+TEST(ProcNetUDP, Exists) {
+ const std::string content =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/udp"));
+ const std::string header_line = StrCat(kProcNetUDPHeader, "\n");
+ EXPECT_THAT(content, ::testing::StartsWith(header_line));
+}
+
+TEST(ProcNetUDP, EntryUID) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv4UDPBidirectionalBindSocketPair(0).Create());
+ std::vector<UDPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+ UDPEntry e;
+ ASSERT_TRUE(FindByLocalAddr(entries, &e, sockets->first_addr()))
+ << DescribeFirstInetSocket(*sockets);
+ EXPECT_EQ(e.uid, geteuid());
+ ASSERT_TRUE(FindByRemoteAddr(entries, &e, sockets->first_addr()))
+ << DescribeSecondInetSocket(*sockets);
+ EXPECT_EQ(e.uid, geteuid());
+}
+
+TEST(ProcNetUDP, FindMutualEntries) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv4UDPBidirectionalBindSocketPair(0).Create());
+ std::vector<UDPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+
+ EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->first_addr()))
+ << DescribeFirstInetSocket(*sockets);
+ EXPECT_TRUE(FindByRemoteAddr(entries, nullptr, sockets->first_addr()))
+ << DescribeSecondInetSocket(*sockets);
+
+ EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->second_addr()))
+ << DescribeSecondInetSocket(*sockets);
+ EXPECT_TRUE(FindByRemoteAddr(entries, nullptr, sockets->second_addr()))
+ << DescribeFirstInetSocket(*sockets);
+}
+
+TEST(ProcNetUDP, EntriesRemovedOnClose) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv4UDPBidirectionalBindSocketPair(0).Create());
+ std::vector<UDPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+
+ EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->first_addr()))
+ << DescribeFirstInetSocket(*sockets);
+ EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->second_addr()))
+ << DescribeSecondInetSocket(*sockets);
+
+ EXPECT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+ // First socket's entry should be gone, but the second socket's entry should
+ // still exist.
+ EXPECT_FALSE(FindByLocalAddr(entries, nullptr, sockets->first_addr()))
+ << DescribeFirstInetSocket(*sockets);
+ EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->second_addr()))
+ << DescribeSecondInetSocket(*sockets);
+
+ EXPECT_THAT(close(sockets->release_second_fd()), SyscallSucceeds());
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+ // Both entries should be gone.
+ EXPECT_FALSE(FindByLocalAddr(entries, nullptr, sockets->first_addr()))
+ << DescribeFirstInetSocket(*sockets);
+ EXPECT_FALSE(FindByLocalAddr(entries, nullptr, sockets->second_addr()))
+ << DescribeSecondInetSocket(*sockets);
+}
+
+// Creates an IPv4 UDP socket and binds it to INADDR_ANY with port 0.
+//
+// Returns the bound socket, or a PosixError carrying errno if creating the
+// socket or bind() fails.
+PosixErrorOr<std::unique_ptr<FileDescriptor>> BoundUDPSocket() {
+  ASSIGN_OR_RETURN_ERRNO(std::unique_ptr<FileDescriptor> socket,
+                         IPv4UDPUnboundSocket(0).Create());
+  // Value-initialize so that sin_zero (and any padding) is zeroed rather than
+  // left indeterminate when the struct is handed to bind().
+  struct sockaddr_in addr = {};
+  addr.sin_family = AF_INET;
+  addr.sin_addr.s_addr = htonl(INADDR_ANY);
+  addr.sin_port = 0;
+
+  int res = bind(socket->get(), reinterpret_cast<const struct sockaddr*>(&addr),
+                 sizeof(addr));
+  if (res) {
+    return PosixError(errno, "bind()");
+  }
+  return socket;
+}
+
+TEST(ProcNetUDP, BoundEntry) {
+ std::unique_ptr<FileDescriptor> socket =
+ ASSERT_NO_ERRNO_AND_VALUE(BoundUDPSocket());
+ struct sockaddr addr;
+ socklen_t len = sizeof(addr);
+ ASSERT_THAT(getsockname(socket->get(), &addr, &len), SyscallSucceeds());
+ uint16_t port = PortFromInetSockaddr(&addr);
+
+ std::vector<UDPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+ UDPEntry e;
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(FindByFD(entries, &e, socket->get())));
+ EXPECT_EQ(e.local_port, port);
+ EXPECT_EQ(e.remote_addr, 0);
+ EXPECT_EQ(e.remote_port, 0);
+}
+
+TEST(ProcNetUDP, BoundSocketStateClosed) {
+ std::unique_ptr<FileDescriptor> socket =
+ ASSERT_NO_ERRNO_AND_VALUE(BoundUDPSocket());
+ std::vector<UDPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+ UDPEntry e;
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(FindByFD(entries, &e, socket->get())));
+ EXPECT_EQ(e.state, TCP_CLOSE);
+}
+
+TEST(ProcNetUDP, ConnectedSocketStateEstablished) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv4UDPBidirectionalBindSocketPair(0).Create());
+ std::vector<UDPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+
+ UDPEntry e;
+ ASSERT_TRUE(FindByLocalAddr(entries, &e, sockets->first_addr()))
+ << DescribeFirstInetSocket(*sockets);
+ EXPECT_EQ(e.state, TCP_ESTABLISHED);
+
+ ASSERT_TRUE(FindByLocalAddr(entries, &e, sockets->second_addr()))
+ << DescribeSecondInetSocket(*sockets);
+ EXPECT_EQ(e.state, TCP_ESTABLISHED);
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc
index 9b9be66ff..83dbd1364 100644
--- a/test/syscalls/linux/proc_net_unix.cc
+++ b/test/syscalls/linux/proc_net_unix.cc
@@ -56,19 +56,44 @@ struct UnixEntry {
std::string path;
};
+// Abstract socket paths can have either trailing null bytes or '@'s as padding
+// at the end, depending on the linux version. This function strips any such
+// padding in place.
+//
+// The string is left unchanged if it is empty, if its last character is
+// neither '\0' nor '@', or if it consists entirely of the padding character.
+void StripAbstractPathPadding(std::string* s) {
+  // std::string::back() on an empty string is undefined behavior, and there is
+  // nothing to strip anyway.
+  if (s->empty()) {
+    return;
+  }
+
+  const char pad_char = s->back();
+  if (pad_char != '\0' && pad_char != '@') {
+    return;
+  }
+
+  // Keep everything up to and including the last non-padding character.
+  const auto last_pos = s->find_last_not_of(pad_char);
+  if (last_pos != std::string::npos) {
+    s->resize(last_pos + 1);
+  }
+}
+
+// Precondition: addr must be a unix socket address (i.e. sockaddr_un) and
+// addr->sun_path must be null-terminated. This is always the case if addr comes
+// from Linux:
+//
+// Per man unix(7):
+//
+// "When the address of a pathname socket is returned (by [getsockname(2)]), its
+// length is
+//
+// offsetof(struct sockaddr_un, sun_path) + strlen(sun_path) + 1
+//
+// and sun_path contains the null-terminated pathname."
std::string ExtractPath(const struct sockaddr* addr) {
const char* path =
reinterpret_cast<const struct sockaddr_un*>(addr)->sun_path;
// Note: sockaddr_un.sun_path is an embedded character array of length
// UNIX_PATH_MAX, so we can always safely dereference the first 2 bytes below.
//
- // The kernel also enforces that the path is always null terminated.
+ // We also rely on the path being null-terminated.
if (path[0] == 0) {
- // Abstract socket paths are null padded to the end of the struct
- // sockaddr. However, these null bytes may or may not show up in
- // /proc/net/unix depending on the kernel version. Truncate after the first
- // null byte (by treating path as a c-string).
- return StrCat("@", &path[1]);
+ std::string abstract_path = StrCat("@", &path[1]);
+ StripAbstractPathPadding(&abstract_path);
+ return abstract_path;
}
return std::string(path);
}
@@ -96,14 +121,6 @@ PosixErrorOr<std::vector<UnixEntry>> ProcNetUnixEntries() {
continue;
}
- // Abstract socket paths can have trailing null bytes in them depending on
- // the linux version. Strip off everything after a null byte, including the
- // null byte.
- std::size_t null_pos = line.find('\0');
- if (null_pos != std::string::npos) {
- line.erase(null_pos);
- }
-
// Parse a single entry from /proc/net/unix.
//
// Sample file:
@@ -151,6 +168,7 @@ PosixErrorOr<std::vector<UnixEntry>> ProcNetUnixEntries() {
entry.path = "";
if (fields.size() > 1) {
entry.path = fields[1];
+ StripAbstractPathPadding(&entry.path);
}
entries.push_back(entry);
@@ -200,8 +218,8 @@ TEST(ProcNetUnix, FilesystemBindAcceptConnect) {
std::string path1 = ExtractPath(sockets->first_addr());
std::string path2 = ExtractPath(sockets->second_addr());
- std::cout << StreamFormat("Server socket address: %s\n", path1);
- std::cout << StreamFormat("Client socket address: %s\n", path2);
+ std::cerr << StreamFormat("Server socket address (path1): %s\n", path1);
+ std::cerr << StreamFormat("Client socket address (path2): %s\n", path2);
std::vector<UnixEntry> entries =
ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
@@ -224,8 +242,8 @@ TEST(ProcNetUnix, AbstractBindAcceptConnect) {
std::string path1 = ExtractPath(sockets->first_addr());
std::string path2 = ExtractPath(sockets->second_addr());
- std::cout << StreamFormat("Server socket address: '%s'\n", path1);
- std::cout << StreamFormat("Client socket address: '%s'\n", path2);
+ std::cerr << StreamFormat("Server socket address (path1): '%s'\n", path1);
+ std::cerr << StreamFormat("Client socket address (path2): '%s'\n", path2);
std::vector<UnixEntry> entries =
ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc
index abf2b1a04..8f3800380 100644
--- a/test/syscalls/linux/ptrace.cc
+++ b/test/syscalls/linux/ptrace.cc
@@ -27,6 +27,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/util/logging.h"
@@ -36,10 +37,10 @@
#include "test/util/thread_util.h"
#include "test/util/time_util.h"
-DEFINE_bool(ptrace_test_execve_child, false,
- "If true, run the "
- "PtraceExecveTest_Execve_GetRegs_PeekUser_SIGKILL_TraceClone_"
- "TraceExit child workload.");
+ABSL_FLAG(bool, ptrace_test_execve_child, false,
+ "If true, run the "
+ "PtraceExecveTest_Execve_GetRegs_PeekUser_SIGKILL_TraceClone_"
+ "TraceExit child workload.");
namespace gvisor {
namespace testing {
@@ -1206,7 +1207,7 @@ TEST(PtraceTest, SeizeSetOptions) {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (FLAGS_ptrace_test_execve_child) {
+ if (absl::GetFlag(FLAGS_ptrace_test_execve_child)) {
gvisor::testing::RunExecveChild();
}
diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc
index bd6907876..99a0df235 100644
--- a/test/syscalls/linux/pty.cc
+++ b/test/syscalls/linux/pty.cc
@@ -1292,10 +1292,9 @@ TEST_F(JobControlTest, ReleaseTTY) {
// Make sure we're ignoring SIGHUP, which will be sent to this process once we
// disconnect the TTY.
- struct sigaction sa = {
- .sa_handler = SIG_IGN,
- .sa_flags = 0,
- };
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_IGN;
+ sa.sa_flags = 0;
sigemptyset(&sa.sa_mask);
struct sigaction old_sa;
EXPECT_THAT(sigaction(SIGHUP, &sa, &old_sa), SyscallSucceeds());
@@ -1362,10 +1361,9 @@ TEST_F(JobControlTest, ReleaseTTYSignals) {
ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
received = 0;
- struct sigaction sa = {
- .sa_handler = sig_handler,
- .sa_flags = 0,
- };
+ struct sigaction sa = {};
+ sa.sa_handler = sig_handler;
+ sa.sa_flags = 0;
sigemptyset(&sa.sa_mask);
sigaddset(&sa.sa_mask, SIGHUP);
sigaddset(&sa.sa_mask, SIGCONT);
@@ -1403,10 +1401,9 @@ TEST_F(JobControlTest, ReleaseTTYSignals) {
// Make sure we're ignoring SIGHUP, which will be sent to this process once we
// disconnect the TTY.
- struct sigaction sighup_sa = {
- .sa_handler = SIG_IGN,
- .sa_flags = 0,
- };
+ struct sigaction sighup_sa = {};
+ sighup_sa.sa_handler = SIG_IGN;
+ sighup_sa.sa_flags = 0;
sigemptyset(&sighup_sa.sa_mask);
struct sigaction old_sa;
EXPECT_THAT(sigaction(SIGHUP, &sighup_sa, &old_sa), SyscallSucceeds());
@@ -1456,10 +1453,9 @@ TEST_F(JobControlTest, SetForegroundProcessGroup) {
ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
// Ignore SIGTTOU so that we don't stop ourself when calling tcsetpgrp.
- struct sigaction sa = {
- .sa_handler = SIG_IGN,
- .sa_flags = 0,
- };
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_IGN;
+ sa.sa_flags = 0;
sigemptyset(&sa.sa_mask);
sigaction(SIGTTOU, &sa, NULL);
@@ -1531,27 +1527,70 @@ TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) {
TEST_F(JobControlTest, SetForegroundProcessGroupDifferentSession) {
ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+ int sync_setsid[2];
+ int sync_exit[2];
+ ASSERT_THAT(pipe(sync_setsid), SyscallSucceeds());
+ ASSERT_THAT(pipe(sync_exit), SyscallSucceeds());
+
// Create a new process and put it in a new session.
pid_t child = fork();
if (!child) {
TEST_PCHECK(setsid() >= 0);
// Tell the parent we're in a new session.
- TEST_PCHECK(!raise(SIGSTOP));
- TEST_PCHECK(!pause());
- _exit(1);
+ char c = 'c';
+ TEST_PCHECK(WriteFd(sync_setsid[1], &c, 1) == 1);
+ TEST_PCHECK(ReadFd(sync_exit[0], &c, 1) == 1);
+ _exit(0);
}
// Wait for the child to tell us it's in a new session.
- int wstatus;
- EXPECT_THAT(waitpid(child, &wstatus, WUNTRACED),
- SyscallSucceedsWithValue(child));
- EXPECT_TRUE(WSTOPSIG(wstatus));
+ char c = 'c';
+ ASSERT_THAT(ReadFd(sync_setsid[0], &c, 1), SyscallSucceedsWithValue(1));
// Child is in a new session, so we can't make it the foreground process group.
EXPECT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child),
SyscallFailsWithErrno(EPERM));
- EXPECT_THAT(kill(child, SIGKILL), SyscallSucceeds());
+ EXPECT_THAT(WriteFd(sync_exit[1], &c, 1), SyscallSucceedsWithValue(1));
+
+ int wstatus;
+ EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child));
+ EXPECT_TRUE(WIFEXITED(wstatus));
+ EXPECT_EQ(WEXITSTATUS(wstatus), 0);
+}
+
+// Verify that we don't hang when creating a new session from an orphaned
+// process group (b/139968068). Calling setsid() creates an orphaned process
+// group, as process groups that contain the session's leading process are
+// orphans.
+//
+// We create 2 sessions in this test. The init process in gVisor is considered
+// not to be an orphan (see sessions.go), so we have to create a session from
+// which to create a session. The latter session is being created from an
+// orphaned process group.
+//
+// Both fork() results are checked explicitly, so a failed fork is reported as
+// such rather than as a failing waitpid(-1, ...).
+TEST_F(JobControlTest, OrphanRegression) {
+  const pid_t session_2_leader = fork();
+  if (session_2_leader == 0) {
+    // Child: only TEST_PCHECK/_exit are used here, as elsewhere in this file's
+    // forked children.
+    TEST_PCHECK(setsid() >= 0);
+
+    const pid_t session_3_leader = fork();
+    TEST_PCHECK(session_3_leader >= 0);
+    if (session_3_leader == 0) {
+      TEST_PCHECK(setsid() >= 0);
+
+      _exit(0);
+    }
+
+    int wstatus;
+    TEST_PCHECK(waitpid(session_3_leader, &wstatus, 0) == session_3_leader);
+    TEST_PCHECK(wstatus == 0);
+
+    _exit(0);
+  }
+
+  // Parent, or fork() failed: fail here rather than waiting on pid -1.
+  ASSERT_THAT(session_2_leader, SyscallSucceeds());
+
+  int wstatus;
+  ASSERT_THAT(waitpid(session_2_leader, &wstatus, 0),
+              SyscallSucceedsWithValue(session_2_leader));
+  ASSERT_EQ(wstatus, 0);
+}
} // namespace
diff --git a/test/syscalls/linux/pwritev2.cc b/test/syscalls/linux/pwritev2.cc
index db519f4e0..f6a0fc96c 100644
--- a/test/syscalls/linux/pwritev2.cc
+++ b/test/syscalls/linux/pwritev2.cc
@@ -244,8 +244,10 @@ TEST(Pwritev2Test, TestInvalidOffset) {
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+ char buf[16];
struct iovec iov;
- iov.iov_base = nullptr;
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1,
/*offset=*/static_cast<off_t>(-8), /*flags=*/0),
@@ -286,8 +288,10 @@ TEST(Pwritev2Test, TestUnseekableFileInValid) {
SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
int pipe_fds[2];
+ char buf[16];
struct iovec iov;
- iov.iov_base = nullptr;
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds());
@@ -307,8 +311,10 @@ TEST(Pwritev2Test, TestReadOnlyFile) {
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+ char buf[16];
struct iovec iov;
- iov.iov_base = nullptr;
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1,
/*offset=*/0, /*flags=*/0),
@@ -324,8 +330,10 @@ TEST(Pwritev2Test, TestInvalidFlag) {
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR | O_DIRECT));
+ char buf[16];
struct iovec iov;
- iov.iov_base = nullptr;
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1,
/*offset=*/0, /*flags=*/0xF0),
diff --git a/test/syscalls/linux/raw_socket_hdrincl.cc b/test/syscalls/linux/raw_socket_hdrincl.cc
index a070817eb..0a27506aa 100644
--- a/test/syscalls/linux/raw_socket_hdrincl.cc
+++ b/test/syscalls/linux/raw_socket_hdrincl.cc
@@ -63,7 +63,11 @@ class RawHDRINCL : public ::testing::Test {
};
void RawHDRINCL::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_INET, SOCK_RAW, IPPROTO_RAW),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
ASSERT_THAT(socket_ = socket(AF_INET, SOCK_RAW, IPPROTO_RAW),
SyscallSucceeds());
@@ -76,9 +80,10 @@ void RawHDRINCL::SetUp() {
}
void RawHDRINCL::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- EXPECT_THAT(close(socket_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(socket_), SyscallSucceeds());
+ }
}
struct iphdr RawHDRINCL::LoopbackHeader() {
@@ -123,8 +128,6 @@ bool RawHDRINCL::FillPacket(char* buf, size_t buf_size, int port,
// We should be able to create multiple IPPROTO_RAW sockets. RawHDRINCL::Setup
// creates the first one, so we only have to create one more here.
TEST_F(RawHDRINCL, MultipleCreation) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
int s2;
ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, IPPROTO_RAW), SyscallSucceeds());
@@ -133,23 +136,17 @@ TEST_F(RawHDRINCL, MultipleCreation) {
// Test that shutting down an unconnected socket fails.
TEST_F(RawHDRINCL, FailShutdownWithoutConnect) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(shutdown(socket_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
ASSERT_THAT(shutdown(socket_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN));
}
// Test that listen() fails.
TEST_F(RawHDRINCL, FailListen) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(listen(socket_, 1), SyscallFailsWithErrno(ENOTSUP));
}
// Test that accept() fails.
TEST_F(RawHDRINCL, FailAccept) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
struct sockaddr saddr;
socklen_t addrlen;
ASSERT_THAT(accept(socket_, &saddr, &addrlen),
@@ -158,8 +155,6 @@ TEST_F(RawHDRINCL, FailAccept) {
// Test that the socket is writable immediately.
TEST_F(RawHDRINCL, PollWritableImmediately) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
struct pollfd pfd = {};
pfd.fd = socket_;
pfd.events = POLLOUT;
@@ -168,8 +163,6 @@ TEST_F(RawHDRINCL, PollWritableImmediately) {
// Test that the socket isn't readable.
TEST_F(RawHDRINCL, NotReadable) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
// Try to receive data with MSG_DONTWAIT, which returns immediately if there's
// nothing to be read.
char buf[117];
@@ -179,16 +172,12 @@ TEST_F(RawHDRINCL, NotReadable) {
// Test that we can connect() to a valid IP (loopback).
TEST_F(RawHDRINCL, ConnectToLoopback) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_),
sizeof(addr_)),
SyscallSucceeds());
}
TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
struct iphdr hdr = LoopbackHeader();
ASSERT_THAT(send(socket_, &hdr, sizeof(hdr), 0),
SyscallSucceedsWithValue(sizeof(hdr)));
@@ -197,8 +186,6 @@ TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) {
// HDRINCL implies write-only. Verify that we can't read a packet sent to
// loopback.
TEST_F(RawHDRINCL, NotReadableAfterWrite) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_),
sizeof(addr_)),
SyscallSucceeds());
@@ -221,8 +208,6 @@ TEST_F(RawHDRINCL, NotReadableAfterWrite) {
}
TEST_F(RawHDRINCL, WriteTooSmall) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_),
sizeof(addr_)),
SyscallSucceeds());
@@ -235,8 +220,6 @@ TEST_F(RawHDRINCL, WriteTooSmall) {
// Bind to localhost.
TEST_F(RawHDRINCL, BindToLocalhost) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(
bind(socket_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
SyscallSucceeds());
@@ -244,8 +227,6 @@ TEST_F(RawHDRINCL, BindToLocalhost) {
// Bind to a different address.
TEST_F(RawHDRINCL, BindToInvalid) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
struct sockaddr_in bind_addr = {};
bind_addr.sin_family = AF_INET;
bind_addr.sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to.
@@ -256,8 +237,6 @@ TEST_F(RawHDRINCL, BindToInvalid) {
// Send and receive a packet.
TEST_F(RawHDRINCL, SendAndReceive) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
int port = 40000;
if (!IsRunningOnGvisor()) {
port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
@@ -302,8 +281,6 @@ TEST_F(RawHDRINCL, SendAndReceive) {
// Send and receive a packet with nonzero IP ID.
TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
int port = 40000;
if (!IsRunningOnGvisor()) {
port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
@@ -349,8 +326,6 @@ TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) {
// Send and receive a packet where the sendto address is not the same as the
// provided destination.
TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
int port = 40000;
if (!IsRunningOnGvisor()) {
port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc
index 971592d7d..8bcaba6f1 100644
--- a/test/syscalls/linux/raw_socket_icmp.cc
+++ b/test/syscalls/linux/raw_socket_icmp.cc
@@ -77,7 +77,11 @@ class RawSocketICMPTest : public ::testing::Test {
};
void RawSocketICMPTest::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_INET, SOCK_RAW, IPPROTO_ICMP),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), SyscallSucceeds());
@@ -90,9 +94,10 @@ void RawSocketICMPTest::SetUp() {
}
void RawSocketICMPTest::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- EXPECT_THAT(close(s_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(s_), SyscallSucceeds());
+ }
}
// We'll only read an echo in this case, as the kernel won't respond to the
diff --git a/test/syscalls/linux/raw_socket_ipv4.cc b/test/syscalls/linux/raw_socket_ipv4.cc
index 352037c88..cde2f07c9 100644
--- a/test/syscalls/linux/raw_socket_ipv4.cc
+++ b/test/syscalls/linux/raw_socket_ipv4.cc
@@ -67,7 +67,11 @@ class RawSocketTest : public ::testing::TestWithParam<int> {
};
void RawSocketTest::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_INET, SOCK_RAW, Protocol()),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds());
@@ -79,9 +83,10 @@ void RawSocketTest::SetUp() {
}
void RawSocketTest::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- EXPECT_THAT(close(s_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(s_), SyscallSucceeds());
+ }
}
// We should be able to create multiple raw sockets for the same protocol.
diff --git a/test/syscalls/linux/readahead.cc b/test/syscalls/linux/readahead.cc
new file mode 100644
index 000000000..09703b5c1
--- /dev/null
+++ b/test/syscalls/linux/readahead.cc
@@ -0,0 +1,91 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+
+#include "gtest/gtest.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// readahead on an invalid descriptor (-1) is expected to fail with EBADF.
+TEST(ReadaheadTest, InvalidFD) {
+ EXPECT_THAT(readahead(-1, 1, 1), SyscallFailsWithErrno(EBADF));
+}
+
+// A negative offset on an otherwise valid, freshly created file is expected
+// to fail with EINVAL.
+TEST(ReadaheadTest, InvalidOffset) {
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+ EXPECT_THAT(readahead(fd.get(), -1, 1), SyscallFailsWithErrno(EINVAL));
+}
+
+// readahead at offset 1, count 1 on a file created with the contents "123".
+// Either a return value of 0 or EINVAL is accepted (see the note below).
+TEST(ReadaheadTest, ValidOffset) {
+ constexpr char kData[] = "123";
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+
+ // N.B. The implementation of readahead is filesystem-specific, and a file
+ // backed by ram may return EINVAL because there is nothing to be read.
+ EXPECT_THAT(readahead(fd.get(), 1, 1), AnyOf(SyscallSucceedsWithValue(0),
+ SyscallFailsWithErrno(EINVAL)));
+}
+
+// readahead at offset 2, count 2 on a file created with the contents "123".
+// Either a return value of 0 or EINVAL is accepted.
+//
+// NOTE(review): assuming the file holds the 3 bytes "123", offset 2 / count 2
+// straddles the end of the file rather than starting beyond it, while
+// CrossesEnd uses offset 4, which starts beyond it. The two test names look
+// swapped relative to their offsets -- confirm.
+TEST(ReadaheadTest, PastEnd) {
+ constexpr char kData[] = "123";
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+ // As in ValidOffset: readahead is filesystem-specific, so EINVAL is accepted.
+ EXPECT_THAT(readahead(fd.get(), 2, 2), AnyOf(SyscallSucceedsWithValue(0),
+ SyscallFailsWithErrno(EINVAL)));
+}
+
+// readahead at offset 4, count 2 on a file created with the contents "123".
+// Either a return value of 0 or EINVAL is accepted.
+//
+// NOTE(review): assuming the file holds the 3 bytes "123", offset 4 starts
+// beyond the end of the file, while PastEnd uses offset 2 / count 2, which
+// straddles it. The two test names look swapped relative to their offsets --
+// confirm.
+TEST(ReadaheadTest, CrossesEnd) {
+ constexpr char kData[] = "123";
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+ // As in ValidOffset: readahead is filesystem-specific, so EINVAL is accepted.
+ EXPECT_THAT(readahead(fd.get(), 4, 2), AnyOf(SyscallSucceedsWithValue(0),
+ SyscallFailsWithErrno(EINVAL)));
+}
+
+// readahead on a descriptor opened O_WRONLY is expected to fail with EBADF.
+TEST(ReadaheadTest, WriteOnly) {
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_WRONLY));
+ EXPECT_THAT(readahead(fd.get(), 0, 1), SyscallFailsWithErrno(EBADF));
+}
+
+// A count argument of -1 at offset 0 is expected to fail with EINVAL.
+TEST(ReadaheadTest, InvalidSize) {
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+ EXPECT_THAT(readahead(fd.get(), 0, -1), SyscallFailsWithErrno(EINVAL));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc
index 421318fcb..40c57f543 100644
--- a/test/syscalls/linux/semaphore.cc
+++ b/test/syscalls/linux/semaphore.cc
@@ -15,6 +15,7 @@
#include <sys/ipc.h>
#include <sys/sem.h>
#include <sys/types.h>
+
#include <atomic>
#include <cerrno>
#include <ctime>
@@ -22,6 +23,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/base/macros.h"
+#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/clock.h"
#include "test/util/capability_util.h"
diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc
index e5d72e28a..580ab5193 100644
--- a/test/syscalls/linux/sendfile.cc
+++ b/test/syscalls/linux/sendfile.cc
@@ -13,15 +13,20 @@
// limitations under the License.
#include <fcntl.h>
+#include <sys/eventfd.h>
#include <sys/sendfile.h>
#include <unistd.h>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/strings/string_view.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/util/eventfd_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
namespace gvisor {
namespace testing {
@@ -299,10 +304,30 @@ TEST(SendFileTest, DoNotSendfileIfOutfileIsAppendOnly) {
// Open the output file as append only.
const FileDescriptor outf =
- ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_APPEND));
+ ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY | O_APPEND));
// Send data and verify that sendfile returns the correct errno.
EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, kDataSize),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// Opens the same file three ways (O_RDONLY, O_WRONLY, and O_APPEND with no
+// write access flag) and checks that sendfile reports EBADF for descriptors
+// with the wrong access mode.
+TEST(SendFileTest, AppendCheckOrdering) {
+ constexpr char kData[] = "And by opposing end them: to die, to sleep";
+ constexpr int kDataSize = sizeof(kData) - 1;
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+
+ const FileDescriptor read =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+ const FileDescriptor write =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY));
+ const FileDescriptor append =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_APPEND));
+
+ // Check that read/write file mode is verified before append.
+ EXPECT_THAT(sendfile(append.get(), read.get(), nullptr, kDataSize),
+ SyscallFailsWithErrno(EBADF));
+ // Here the write-only descriptor is used as both output and input.
+ EXPECT_THAT(sendfile(write.get(), write.get(), nullptr, kDataSize),
SyscallFailsWithErrno(EBADF));
}
@@ -422,6 +447,89 @@ TEST(SendFileTest, SendToNotARegularFile) {
EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, 0),
SyscallFailsWithErrno(EINVAL));
}
+
+// sendfile into a full, non-blocking pipe is expected to fail with
+// EWOULDBLOCK.
+TEST(SendFileTest, SendPipeWouldBlock) {
+ // Create temp file.
+ constexpr char kData[] =
+ "The fool doth think he is wise, but the wise man knows himself to be a "
+ "fool.";
+ constexpr int kDataSize = sizeof(kData) - 1;
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+
+ // Open the input file as read only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Set up the output pipe (created non-blocking via pipe2).
+ int fds[2];
+ ASSERT_THAT(pipe2(fds, O_NONBLOCK), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill up the pipe's buffer. The write offers twice the size reported by
+ // F_GETPIPE_SZ and is expected to accept exactly pipe_size bytes.
+ int pipe_size = -1;
+ ASSERT_THAT(pipe_size = fcntl(wfd.get(), F_GETPIPE_SZ), SyscallSucceeds());
+ std::vector<char> buf(2 * pipe_size);
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(pipe_size));
+
+ EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+// sendfile into a full pipe created without O_NONBLOCK is expected to
+// complete with kDataSize bytes written; a helper thread drains the pipe
+// after a short delay.
+TEST(SendFileTest, SendPipeBlocks) {
+ // Create temp file.
+ constexpr char kData[] =
+ "The fault, dear Brutus, is not in our stars, but in ourselves.";
+ constexpr int kDataSize = sizeof(kData) - 1;
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+
+ // Open the input file as read only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Set up the output pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill up the pipe's buffer.
+ int pipe_size = -1;
+ ASSERT_THAT(pipe_size = fcntl(wfd.get(), F_GETPIPE_SZ), SyscallSucceeds());
+ std::vector<char> buf(pipe_size);
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(pipe_size));
+
+ // After 100ms, read pipe_size bytes back out of the pipe from another
+ // thread so that the sendfile below has room to write.
+ ScopedThread t([&]() {
+ absl::SleepFor(absl::Milliseconds(100));
+ ASSERT_THAT(read(rfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(pipe_size));
+ });
+
+ EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize),
+ SyscallSucceedsWithValue(kDataSize));
+}
+
+// sendfile from a 0x7ff-byte file into an eventfd is expected to transfer
+// kSize rounded down to a multiple of 8.
+TEST(SendFileTest, SendToSpecialFile) {
+ // Create temp file.
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode));
+
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+ // Grow the (empty) input file to kSize bytes.
+ constexpr int kSize = 0x7ff;
+ ASSERT_THAT(ftruncate(inf.get(), kSize), SyscallSucceeds());
+
+ auto eventfd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD());
+
+ // eventfd can accept a number of bytes which is a multiple of 8. The
+ // requested count (0xfffff) is larger than the file.
+ EXPECT_THAT(sendfile(eventfd.get(), inf.get(), nullptr, 0xfffff),
+ SyscallSucceedsWithValue(kSize & (~7)));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/signalfd.cc b/test/syscalls/linux/signalfd.cc
new file mode 100644
index 000000000..9379d5878
--- /dev/null
+++ b/test/syscalls/linux/signalfd.cc
@@ -0,0 +1,350 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <poll.h>
+#include <signal.h>
+#include <stdio.h>
+#include <string.h>
+#include <sys/signalfd.h>
+#include <unistd.h>
+
+#include <functional>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "gtest/gtest.h"
+#include "absl/synchronization/mutex.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+using ::testing::KilledBySignal;
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr int kSigno = SIGUSR1;
+constexpr int kSignoAlt = SIGUSR2;
+
+// Returns a new signalfd created with signalfd(-1, mask, flags), wrapped in a
+// FileDescriptor. If the call returns a negative value, a PosixError carrying
+// errno is returned instead.
+inline PosixErrorOr<FileDescriptor> NewSignalFD(sigset_t* mask, int flags = 0) {
+ int fd = signalfd(-1, mask, flags);
+ MaybeSave();
+ if (fd < 0) {
+ return PosixError(errno, "signalfd");
+ }
+ return FileDescriptor(fd);
+}
+
+// A blocked signal sent to the calling thread can be read from a signalfd
+// whose mask contains that signal.
+TEST(Signalfd, Basic) {
+ // Create the signalfd.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
+
+ // Deliver the blocked signal.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+
+ // We should now read the signal.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, kSigno);
+}
+
+// Two signalfds with disjoint masks each report only the signal in their own
+// mask.
+TEST(Signalfd, MaskWorks) {
+ // Create two signalfds with different masks.
+ sigset_t mask1, mask2;
+ sigemptyset(&mask1);
+ sigemptyset(&mask2);
+ sigaddset(&mask1, kSigno);
+ sigaddset(&mask2, kSignoAlt);
+ FileDescriptor fd1 = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask1, 0));
+ FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask2, 0));
+
+ // Deliver the two signals.
+ const auto scoped_sigmask1 =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
+ const auto scoped_sigmask2 =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSignoAlt));
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSignoAlt), SyscallSucceeds());
+
+ // We should see the signals on the appropriate signalfds.
+ //
+ // We read in the opposite order from how the signals were sent above, so
+ // that reading the right signal from each signalfd is not just coincidence.
+ struct signalfd_siginfo rbuf1, rbuf2;
+ ASSERT_THAT(read(fd2.get(), &rbuf2, sizeof(rbuf2)),
+ SyscallSucceedsWithValue(sizeof(rbuf2)));
+ EXPECT_EQ(rbuf2.ssi_signo, kSignoAlt);
+ ASSERT_THAT(read(fd1.get(), &rbuf1, sizeof(rbuf1)),
+ SyscallSucceedsWithValue(sizeof(rbuf1)));
+ EXPECT_EQ(rbuf1.ssi_signo, kSigno);
+}
+
+// A signalfd created with SFD_CLOEXEC reports FD_CLOEXEC via F_GETFD.
+TEST(Signalfd, Cloexec) {
+ // Exec tests confirm that O_CLOEXEC has the intended effect. We just create a
+ // signalfd with the appropriate flag here and assert that the FD has it set.
+ sigset_t mask;
+ sigemptyset(&mask);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC));
+ EXPECT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC));
+}
+
+// A read on a blocking signalfd performed by a second thread completes once a
+// matching signal is sent (see the note in the body on which thread is
+// signalled).
+TEST(Signalfd, Blocking) {
+ // Create the signalfd in blocking mode.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
+
+ // Shared tid variable. has_tid must start out false: it is evaluated by
+ // mu.Await() below, possibly before the reader thread has assigned it.
+ absl::Mutex mu;
+ bool has_tid = false;
+ pid_t tid;
+
+ // Start a thread reading.
+ ScopedThread t([&] {
+ // Copy the tid and notify the caller.
+ {
+ absl::MutexLock ml(&mu);
+ tid = gettid();
+ has_tid = true;
+ }
+
+ // Read the signal from the signalfd.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, kSigno);
+ });
+
+ // Wait until the reader thread has published its tid. N.B. this does not
+ // guarantee that the thread is already blocked in read().
+ absl::MutexLock ml(&mu);
+ mu.Await(absl::Condition(&has_tid));
+
+ // Deliver the signal to either the waiting thread, or
+ // to this thread. N.B. this is a bug in the core gVisor
+ // behavior for signalfd, and needs to be fixed.
+ //
+ // See gvisor.dev/issue/139.
+ if (IsRunningOnGvisor()) {
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+ } else {
+ ASSERT_THAT(tgkill(getpid(), tid, kSigno), SyscallSucceeds());
+ }
+
+ // Ensure that it was received.
+ t.Join();
+}
+
+// A signal sent to the whole process with kill() can be read from the
+// signalfd, first by a second thread and then, after a second kill(), by the
+// main thread while that second thread is still alive.
+TEST(Signalfd, ThreadGroup) {
+ // Create the signalfd in blocking mode.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
+
+ // Shared variable.
+ absl::Mutex mu;
+ bool first = false;
+ bool second = false;
+
+ // Start a thread reading.
+ ScopedThread t([&] {
+ // Read the signal from the signalfd.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, kSigno);
+
+ // Wait for the other thread.
+ absl::MutexLock ml(&mu);
+ first = true;
+ mu.Await(absl::Condition(&second));
+ });
+
+ // Deliver the signal to the threadgroup.
+ ASSERT_THAT(kill(getpid(), kSigno), SyscallSucceeds());
+
+ // Wait for the first thread to process.
+ {
+ absl::MutexLock ml(&mu);
+ mu.Await(absl::Condition(&first));
+ }
+
+ // Deliver to the thread group again (other thread still exists).
+ ASSERT_THAT(kill(getpid(), kSigno), SyscallSucceeds());
+
+ // Ensure that we can also receive it.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, kSigno);
+
+ // Mark the test as done.
+ {
+ absl::MutexLock ml(&mu);
+ second = true;
+ }
+
+ // The other thread should be joinable.
+ t.Join();
+}
+
+// With SFD_NONBLOCK, a read fails with EWOULDBLOCK when no signal is pending
+// and succeeds once one has been sent.
+TEST(Signalfd, Nonblock) {
+ // Create the signalfd in non-blocking mode.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_NONBLOCK));
+
+ // No signal has been sent yet, so a read should fail with EWOULDBLOCK.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Block and deliver the signal.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+
+ // Ensure that a read actually works.
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, kSigno);
+
+ // The signal has been consumed, so a read should fail again.
+ EXPECT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+// Calling signalfd() again on an existing signalfd replaces its mask: a
+// pending signal that was not readable becomes readable once it is added.
+TEST(Signalfd, SetMask) {
+ // Create the signalfd matching nothing.
+ sigset_t mask;
+ sigemptyset(&mask);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_NONBLOCK));
+
+ // Block and deliver a signal.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+
+ // We should have nothing.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Change the signal mask.
+ sigaddset(&mask, kSigno);
+ ASSERT_THAT(signalfd(fd.get(), &mask, 0), SyscallSucceeds());
+
+ // We should now have the signal.
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, kSigno);
+}
+
+// poll() with POLLIN on a signalfd returns one ready descriptor once a
+// matching signal is sent from another thread.
+TEST(Signalfd, Poll) {
+ // Create the signalfd.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
+
+ // Block the signal, and start a thread to deliver it.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
+ pid_t orig_tid = gettid();
+ ScopedThread t([&] {
+ absl::SleepFor(absl::Seconds(5));
+ ASSERT_THAT(tgkill(getpid(), orig_tid, kSigno), SyscallSucceeds());
+ });
+
+ // Start polling for the signal. We expect that it is not available at the
+ // outset, but then becomes available when the signal is sent. We give a
+ // timeout of 10000ms (or the delay above + 5 seconds of additional grace
+ // time).
+ struct pollfd poll_fd = {fd.get(), POLLIN, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000),
+ SyscallSucceedsWithValue(1));
+
+ // Actually read the signal to prevent delivery.
+ struct signalfd_siginfo rbuf;
+ EXPECT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+}
+
+// SIGKILL still terminates the process even when it is included in a
+// signalfd mask and the caller attempts to block it.
+TEST(Signalfd, KillStillKills) {
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, SIGKILL);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC));
+
+ // Just because there is a signalfd, we shouldn't see any change in behavior
+ // for unblockable signals. It's easier to test this with SIGKILL.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, SIGKILL));
+ EXPECT_EXIT(tgkill(getpid(), gettid(), SIGKILL), KilledBySignal(SIGKILL), "");
+}
+
+// ppoll() on a signalfd, with the same signal set passed as its signal mask,
+// is expected to time out without reporting any events.
+TEST(Signalfd, Ppoll) {
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, SIGKILL);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC));
+
+ // Ensure that the given ppoll blocks: with a one second timeout it is
+ // expected to return 0, i.e. no descriptor became ready.
+ struct pollfd pfd = {};
+ pfd.fd = fd.get();
+ pfd.events = POLLIN;
+ struct timespec timeout = {};
+ timeout.tv_sec = 1;
+ EXPECT_THAT(RetryEINTR(ppoll)(&pfd, 1, &timeout, &mask),
+ SyscallSucceedsWithValue(0));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
+
+// Test entry point: blocks kSigno and kSignoAlt before initializing the test
+// framework, then runs all tests.
+int main(int argc, char** argv) {
+ // These tests depend on delivering signals. Block them up front so that all
+ // other threads created by TestInit will also have them blocked, and they
+ // will not interfere with the rest of the test.
+ sigset_t set;
+ sigemptyset(&set);
+ sigaddset(&set, gvisor::testing::kSigno);
+ sigaddset(&set, gvisor::testing::kSignoAlt);
+ TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
+
+ gvisor::testing::TestInit(&argc, &argv);
+
+ return RUN_ALL_TESTS();
+}
diff --git a/test/syscalls/linux/sigstop.cc b/test/syscalls/linux/sigstop.cc
index 9c7210e17..7db57d968 100644
--- a/test/syscalls/linux/sigstop.cc
+++ b/test/syscalls/linux/sigstop.cc
@@ -17,6 +17,7 @@
#include <sys/select.h>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/util/multiprocess_util.h"
@@ -24,8 +25,8 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_bool(sigstop_test_child, false,
- "If true, run the SigstopTest child workload.");
+ABSL_FLAG(bool, sigstop_test_child, false,
+ "If true, run the SigstopTest child workload.");
namespace gvisor {
namespace testing {
@@ -141,7 +142,7 @@ void RunChild() {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (FLAGS_sigstop_test_child) {
+ if (absl::GetFlag(FLAGS_sigstop_test_child)) {
gvisor::testing::RunChild();
return 1;
}
diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc
index 0404190a0..3a07ac8d2 100644
--- a/test/syscalls/linux/socket.cc
+++ b/test/syscalls/linux/socket.cc
@@ -17,6 +17,7 @@
#include "gtest/gtest.h"
#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
#include "test/util/test_util.h"
namespace gvisor {
@@ -30,12 +31,25 @@ TEST(SocketTest, UnixSocketPairProtocol) {
close(socks[1]);
}
-TEST(SocketTest, Protocol) {
+TEST(SocketTest, ProtocolUnix) {
struct {
int domain, type, protocol;
} tests[] = {
- {AF_UNIX, SOCK_STREAM, PF_UNIX}, {AF_UNIX, SOCK_SEQPACKET, PF_UNIX},
- {AF_UNIX, SOCK_DGRAM, PF_UNIX}, {AF_INET, SOCK_DGRAM, IPPROTO_UDP},
+ {AF_UNIX, SOCK_STREAM, PF_UNIX},
+ {AF_UNIX, SOCK_SEQPACKET, PF_UNIX},
+ {AF_UNIX, SOCK_DGRAM, PF_UNIX},
+ };
+ for (int i = 0; i < ABSL_ARRAYSIZE(tests); i++) {
+ ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(tests[i].domain, tests[i].type, tests[i].protocol));
+ }
+}
+
+TEST(SocketTest, ProtocolInet) {
+ struct {
+ int domain, type, protocol;
+ } tests[] = {
+ {AF_INET, SOCK_DGRAM, IPPROTO_UDP},
{AF_INET, SOCK_STREAM, IPPROTO_TCP},
};
for (int i = 0; i < ABSL_ARRAYSIZE(tests); i++) {
@@ -44,5 +58,28 @@ TEST(SocketTest, Protocol) {
}
}
+using SocketOpenTest = ::testing::TestWithParam<int>;
+
+// The filesystem path of a bound Unix domain socket (UDS) cannot be opened
+// with open(); the call is expected to fail with ENXIO. The open flags come
+// from the test parameter.
+TEST_P(SocketOpenTest, Unix) {
+ // FIXME(b/142001530): Open incorrectly succeeds on gVisor.
+ SKIP_IF(IsRunningOnGvisor());
+
+ FileDescriptor bound =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX));
+
+ // Bind to a non-abstract address so that the socket has a path to open.
+ struct sockaddr_un addr =
+ ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(/*abstract=*/false, AF_UNIX));
+
+ ASSERT_THAT(bind(bound.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+
+ EXPECT_THAT(open(addr.sun_path, GetParam()), SyscallFailsWithErrno(ENXIO));
+}
+
+INSTANTIATE_TEST_SUITE_P(OpenModes, SocketOpenTest,
+ ::testing::Values(O_RDONLY, O_RDWR));
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device.cc b/test/syscalls/linux/socket_bind_to_device.cc
new file mode 100644
index 000000000..d20821cac
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device.cc
@@ -0,0 +1,314 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+
+// Test fixture for SO_BINDTODEVICE tests.
+//
+// Parameterized on SocketKind. SetUp() creates one socket of that kind and
+// selects the interface name the tests bind to.
+class BindToDeviceTest : public ::testing::TestWithParam<SocketKind> {
+ protected:
+ // Logs the case description, requires CAP_NET_RAW, selects the interface
+ // ("eth1" when GetInterfaceNames() reports it, otherwise a tunnel created
+ // for this test) and creates the socket under test.
+ void SetUp() override {
+ printf("Testing case: %s\n", GetParam().description.c_str());
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)))
+ << "CAP_NET_RAW is required to use SO_BINDTODEVICE";
+
+ interface_name_ = "eth1";
+ auto interface_names = GetInterfaceNames();
+ if (interface_names.find(interface_name_) == interface_names.end()) {
+ // Need a tunnel.
+ tunnel_ = ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New());
+ interface_name_ = tunnel_->GetName();
+ ASSERT_FALSE(interface_name_.empty());
+ }
+ socket_ = ASSERT_NO_ERRNO_AND_VALUE(GetParam().Create());
+ }
+
+ // Name of the interface chosen by SetUp(); returned by value.
+ string interface_name() const { return interface_name_; }
+
+ // Raw descriptor of the socket created by SetUp().
+ int socket_fd() const { return socket_->get(); }
+
+ private:
+ // Only set when "eth1" was not available and a tunnel had to be created.
+ std::unique_ptr<Tunnel> tunnel_;
+ string interface_name_;
+ std::unique_ptr<FileDescriptor> socket_;
+};
+
+// Fill byte used by the tests below to poison name buffers and to build
+// device names that are expected to be rejected.
+// NOTE(review): presumably chosen because '/' can never appear in a real
+// interface name -- confirm.
+constexpr char kIllegalIfnameChar = '/';
+
+// Tests getsockopt of the default value.
+//
+// On a socket that was never bound to a device, getsockopt(SO_BINDTODEVICE)
+// is expected to succeed for every buffer length from 0 through
+// sizeof(name_buffer), report a length of 0, and leave the caller's buffer
+// untouched.
+TEST_P(BindToDeviceTest, GetsockoptDefault) {
+  char name_buffer[IFNAMSIZ * 2];
+  char original_name_buffer[IFNAMSIZ * 2];
+  socklen_t name_buffer_size;
+
+  // The memcmp below compares sizeof(name_buffer) bytes of both arrays, so
+  // they must stay the same size.
+  static_assert(sizeof(name_buffer) == sizeof(original_name_buffer),
+                "reference buffer must match the size of the probed buffer");
+
+  // Read the default SO_BINDTODEVICE.
+  // The reference copy is filled using its own size rather than the size of
+  // the other array.
+  memset(original_name_buffer, kIllegalIfnameChar,
+         sizeof(original_name_buffer));
+  for (size_t i = 0; i <= sizeof(name_buffer); i++) {
+    // Re-poison the buffer on every iteration so any write by the kernel is
+    // detectable.
+    memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+    name_buffer_size = i;
+    // Identify the failing length, as the other size-sweep tests do.
+    SCOPED_TRACE(absl::StrCat("Buffer size: ", i));
+    EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+                           name_buffer, &name_buffer_size),
+                SyscallSucceedsWithValue(0));
+    EXPECT_EQ(name_buffer_size, 0);
+    EXPECT_EQ(memcmp(name_buffer, original_name_buffer, sizeof(name_buffer)),
+              0);
+  }
+}
+
+// Tests setsockopt of invalid device name.
+//
+// Passes five kIllegalIfnameChar bytes as the device name and expects ENODEV.
+TEST_P(BindToDeviceTest, SetsockoptInvalidDeviceName) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Set an invalid device name.
+ // Only the first 5 bytes are initialized; only those 5 are passed.
+ memset(name_buffer, kIllegalIfnameChar, 5);
+ name_buffer_size = 5;
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ name_buffer_size),
+ SyscallFailsWithErrno(ENODEV));
+}
+
+// Tests setsockopt of a buffer with a valid device name but not
+// null-terminated, with different sizes of buffer.
+//
+// The buffer holds the interface name immediately followed by
+// kIllegalIfnameChar bytes, so any length other than the exact name length
+// either truncates the name or includes filler bytes.
+TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithoutNullTermination) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1);
+ // Intentionally overwrite the null at the end.
+ memset(name_buffer + interface_name().size(), kIllegalIfnameChar,
+ sizeof(name_buffer) - interface_name().size());
+ for (size_t i = 1; i <= sizeof(name_buffer); i++) {
+ name_buffer_size = i;
+ SCOPED_TRACE(absl::StrCat("Buffer size: ", i));
+ // It should only work if the size provided is exactly right.
+ if (name_buffer_size == interface_name().size()) {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallSucceeds());
+ } else {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallFailsWithErrno(ENODEV));
+ }
+ }
+}
+
+// Tests setsockopt of a buffer with a valid device name and null-terminated,
+// with different sizes of buffer.
+//
+// The buffer holds the interface name, its terminating null, and then
+// kIllegalIfnameChar filler. Lengths shorter than the name are expected to
+// fail with ENODEV; every other length is expected to succeed.
+TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithNullTermination) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1);
+ // Don't overwrite the null at the end.
+ memset(name_buffer + interface_name().size() + 1, kIllegalIfnameChar,
+ sizeof(name_buffer) - interface_name().size() - 1);
+ for (size_t i = 1; i <= sizeof(name_buffer); i++) {
+ name_buffer_size = i;
+ SCOPED_TRACE(absl::StrCat("Buffer size: ", i));
+ // It should only work if the size provided is at least the right size.
+ // "Right size" here is the name length without the null: a length equal
+ // to interface_name().size() is also expected to succeed.
+ if (name_buffer_size >= interface_name().size()) {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallSucceeds());
+ } else {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallFailsWithErrno(ENODEV));
+ }
+ }
+}
+
+// Tests that setsockopt of an invalid device name doesn't unset the previous
+// valid setsockopt.
+//
+// Sequence: bind to the fixture's interface, read the name back, attempt to
+// bind to a 5-byte name made of kIllegalIfnameChar (expected ENODEV), then
+// read the name back again and check it still matches the first binding.
+TEST_P(BindToDeviceTest, SetsockoptValidThenInvalid) {
+  char name_buffer[IFNAMSIZ * 2];
+  socklen_t name_buffer_size;
+
+  // Write successfully.
+  // strncpy zero-pads the rest of the buffer, so passing the full buffer
+  // size still yields a null-terminated name.
+  strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+  ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+                         sizeof(name_buffer)),
+              SyscallSucceeds());
+
+  // Read it back successfully.
+  memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+  name_buffer_size = sizeof(name_buffer);
+  EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+                         &name_buffer_size),
+              SyscallSucceeds());
+  EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+  EXPECT_STREQ(name_buffer, interface_name().c_str());
+
+  // Write unsuccessfully.
+  memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+  name_buffer_size = 5;
+  // Pass the length that was just set. Previously sizeof(name_buffer) was
+  // passed, which left the assignment above as a dead store.
+  EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+                         name_buffer_size),
+              SyscallFailsWithErrno(ENODEV));
+
+  // Read it back successfully, it's unchanged.
+  memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+  name_buffer_size = sizeof(name_buffer);
+  EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+                         &name_buffer_size),
+              SyscallSucceeds());
+  EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+  EXPECT_STREQ(name_buffer, interface_name().c_str());
+}
+
+// Tests that setsockopt of zero-length string correctly unsets the previous
+// value.
+//
+// Sequence: bind to the fixture's interface, read the name back, call
+// setsockopt with an option length of 0, then check that getsockopt reports
+// a length of 0.
+TEST_P(BindToDeviceTest, SetsockoptValidThenClear) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Write successfully.
+ strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallSucceeds());
+
+ // Read it back successfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+
+ // Clear it successfully.
+ // The buffer contents are irrelevant here; only the zero length is passed.
+ name_buffer_size = 0;
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ name_buffer_size),
+ SyscallSucceeds());
+
+ // Read it back successfully, it's cleared.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, 0);
+}
+
+// Tests that setsockopt of empty string correctly unsets the previous
+// value.
+//
+// Same sequence as SetsockoptValidThenClear, except that the clearing call
+// passes the full buffer length with a null in the first byte instead of a
+// zero option length.
+TEST_P(BindToDeviceTest, SetsockoptValidThenClearWithNull) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Write successfully.
+ strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallSucceeds());
+
+ // Read it back successfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+
+ // Clear it successfully.
+ // Everything after the leading null is filler, so the option value is an
+ // empty C string followed by garbage.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer[0] = 0;
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ name_buffer_size),
+ SyscallSucceeds());
+
+ // Read it back successfully, it's cleared.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, 0);
+}
+
+// Tests getsockopt with different buffer sizes.
+//
+// After binding to the fixture's interface, getsockopt is called with every
+// buffer length from 0 through sizeof(name_buffer). Lengths of at least
+// IFNAMSIZ are expected to return the name plus its null; shorter lengths
+// are expected to fail with EINVAL and leave the length argument unchanged.
+TEST_P(BindToDeviceTest, GetsockoptDevice) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Write successfully.
+ strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+ ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallSucceeds());
+
+ // Read it back at various buffer sizes.
+ for (size_t i = 0; i <= sizeof(name_buffer); i++) {
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = i;
+ SCOPED_TRACE(absl::StrCat("Buffer size: ", i));
+ // Linux only allows a buffer at least IFNAMSIZ, even if less would suffice
+ // for this interface name.
+ if (name_buffer_size >= IFNAMSIZ) {
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+ } else {
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, &name_buffer_size),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_EQ(name_buffer_size, i);
+ }
+ }
+}
+
+// Runs every BindToDeviceTest case against the two socket kinds listed
+// below.
+// NOTE(review): the meaning of the 0 argument is defined by the helpers in
+// ip_socket_test_util.h -- presumably extra socket type flags; confirm.
+INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceTest,
+ ::testing::Values(IPv4UDPUnboundSocket(0),
+ IPv4TCPUnboundSocket(0)));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc
new file mode 100644
index 000000000..4d2400328
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc
@@ -0,0 +1,381 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <atomic>
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+using std::vector;
+
+// Describes one listening socket in a distribution test case.
+struct EndpointConfig {
+ // Device name passed to SO_BINDTODEVICE; the tests treat an empty string
+ // as "no device" when deciding whether a tunnel must be created.
+ std::string bind_to_device;
+ // Fraction of the total traffic this socket is expected to receive.
+ double expected_ratio;
+};
+
+// A named set of listening endpoints together with the share of traffic
+// each one is expected to receive.
+struct DistributionTestCase {
+ // Human-readable description printed by the fixture's SetUp().
+ std::string name;
+ // One entry per listening socket created by the test.
+ std::vector<EndpointConfig> endpoints;
+};
+
+// Pairs the address the listening sockets bind to with the address the
+// sending side targets.
+struct ListenerConnector {
+ TestAddress listener;
+ TestAddress connector;
+};
+
+// Test fixture for SO_BINDTODEVICE tests the distribution of packets received
+// with varying SO_BINDTODEVICE settings.
+//
+// The parameter is a (ListenerConnector, DistributionTestCase) tuple.
+class BindToDeviceDistributionTest
+ : public ::testing::TestWithParam<
+ ::testing::tuple<ListenerConnector, DistributionTestCase>> {
+ protected:
+ // Logs the case name and the listener/connector descriptions, then
+ // requires CAP_NET_RAW.
+ void SetUp() override {
+ printf("Testing case: %s, listener=%s, connector=%s\n",
+ ::testing::get<1>(GetParam()).name.c_str(),
+ ::testing::get<0>(GetParam()).listener.description.c_str(),
+ ::testing::get<0>(GetParam()).connector.description.c_str());
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)))
+ << "CAP_NET_RAW is required to use SO_BINDTODEVICE";
+ }
+};
+
+// Reads the port field out of `addr`, interpreting the storage as a
+// sockaddr_in or sockaddr_in6 according to `family`. The field is returned
+// exactly as stored; no byte-order conversion is applied. Any other family
+// produces an EINVAL PosixError.
+PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) {
+  if (family == AF_INET) {
+    auto const* v4 = reinterpret_cast<sockaddr_in const*>(&addr);
+    return static_cast<uint16_t>(v4->sin_port);
+  }
+  if (family == AF_INET6) {
+    auto const* v6 = reinterpret_cast<sockaddr_in6 const*>(&addr);
+    return static_cast<uint16_t>(v6->sin6_port);
+  }
+  return PosixError(EINVAL, absl::StrCat("unknown socket family: ", family));
+}
+
+// Stores `port` into the port field of `addr`, interpreting the storage as a
+// sockaddr_in or sockaddr_in6 according to `family`. The value is written as
+// given; no byte-order conversion is applied. Any other family produces an
+// EINVAL PosixError and leaves `addr` untouched.
+PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) {
+  if (family == AF_INET) {
+    auto* v4 = reinterpret_cast<sockaddr_in*>(addr);
+    v4->sin_port = port;
+    return NoError();
+  }
+  if (family == AF_INET6) {
+    auto* v6 = reinterpret_cast<sockaddr_in6*>(addr);
+    v6->sin6_port = port;
+    return NoError();
+  }
+  return PosixError(EINVAL, absl::StrCat("unknown socket family: ", family));
+}
+
+// Binds sockets to different devices and then creates many TCP connections.
+// Checks that the distribution of connections received on the sockets matches
+// the expectation.
+//
+// Outline: one SO_REUSEPORT listener per endpoint is bound to a shared port,
+// one accept thread is started per listener, the main thread makes
+// kConnectAttempts connections, and the per-listener accept counts are then
+// compared with each endpoint's expected_ratio.
+TEST_P(BindToDeviceDistributionTest, Tcp) {
+ auto const& [listener_connector, test] = GetParam();
+
+ TestAddress const& listener = listener_connector.listener;
+ TestAddress const& connector = listener_connector.connector;
+ // Local copies, because the port is filled in below once it is known.
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+
+ auto interface_names = GetInterfaceNames();
+
+ // Create the listening sockets.
+ std::vector<FileDescriptor> listener_fds;
+ std::vector<std::unique_ptr<Tunnel>> all_tunnels;
+ for (auto const& endpoint : test.endpoints) {
+ // Create a tunnel for any named device that does not exist yet, and
+ // record the name so it is only created once.
+ if (!endpoint.bind_to_device.empty() &&
+ interface_names.find(endpoint.bind_to_device) ==
+ interface_names.end()) {
+ all_tunnels.push_back(
+ ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device)));
+ interface_names.insert(endpoint.bind_to_device);
+ }
+
+ listener_fds.push_back(ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)));
+ int fd = listener_fds.back().get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ // The length includes the terminating null; for an empty device name
+ // this passes a single null byte.
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE,
+ endpoint.bind_to_device.c_str(),
+ endpoint.bind_to_device.size() + 1),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(fd, 40), SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (listener_fds.size() > 1) {
+ continue;
+ }
+
+ // Get the port bound by the listening socket.
+ // getsockname overwrites listen_addr, so later listeners bind to the
+ // same port.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ constexpr int kConnectAttempts = 10000;
+ // Total connections accepted across all threads.
+ std::atomic<int> connects_received = ATOMIC_VAR_INIT(0);
+ // Per-listener totals; each thread only writes its own slot.
+ std::vector<int> accept_counts(listener_fds.size(), 0);
+ std::vector<std::unique_ptr<ScopedThread>> listen_threads(
+ listener_fds.size());
+
+ for (int i = 0; i < listener_fds.size(); i++) {
+ listen_threads[i] = absl::make_unique<ScopedThread>(
+ [&listener_fds, &accept_counts, &connects_received, i,
+ kConnectAttempts]() {
+ do {
+ auto fd = Accept(listener_fds[i].get(), nullptr, nullptr);
+ if (!fd.ok()) {
+ // Another thread has shutdown our read side causing the accept to
+ // fail.
+ ASSERT_GE(connects_received, kConnectAttempts)
+ << "errno = " << fd.error();
+ return;
+ }
+ // Receive some data from a socket to be sure that the connect()
+ // system call has been completed on another side.
+ int data;
+ EXPECT_THAT(
+ RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(sizeof(data)));
+ accept_counts[i]++;
+ } while (++connects_received < kConnectAttempts);
+
+ // Shutdown all sockets to wake up other threads.
+ for (auto const& listener_fd : listener_fds) {
+ shutdown(listener_fd.get(), SHUT_RDWR);
+ }
+ });
+ }
+
+ // Each iteration uses a fresh client socket, which is closed when `fd`
+ // goes out of scope at the end of the iteration.
+ for (int i = 0; i < kConnectAttempts; i++) {
+ FileDescriptor const fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(
+ RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
+
+ // Join threads to be sure that all connections have been counted.
+ for (auto const& listen_thread : listen_threads) {
+ listen_thread->Join();
+ }
+ // Check that connections are distributed correctly among listening sockets.
+ // NOTE(review): EquivalentWithin is defined elsewhere; 0.10 is presumably
+ // a 10% tolerance -- confirm.
+ for (int i = 0; i < accept_counts.size(); i++) {
+ EXPECT_THAT(
+ accept_counts[i],
+ EquivalentWithin(static_cast<int>(kConnectAttempts *
+ test.endpoints[i].expected_ratio),
+ 0.10))
+ << "endpoint " << i << " got the wrong number of packets";
+ }
+}
+
+// Binds sockets to different devices and then sends many UDP packets. Checks
+// that the distribution of packets received on the sockets matches the
+// expectation.
+//
+// Outline: one SO_REUSEPORT datagram socket per endpoint is bound to a
+// shared port, one receiver thread is started per socket, the main thread
+// sends kConnectAttempts packets (waiting for an echo after each), and the
+// per-socket packet counts are then compared with each endpoint's
+// expected_ratio.
+TEST_P(BindToDeviceDistributionTest, Udp) {
+ auto const& [listener_connector, test] = GetParam();
+
+ TestAddress const& listener = listener_connector.listener;
+ TestAddress const& connector = listener_connector.connector;
+ // Local copies, because the port is filled in below once it is known.
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+
+ auto interface_names = GetInterfaceNames();
+
+ // Create the listening sockets.
+ std::vector<FileDescriptor> listener_fds;
+ std::vector<std::unique_ptr<Tunnel>> all_tunnels;
+ for (auto const& endpoint : test.endpoints) {
+ // Create a tunnel for any named device that does not exist yet, and
+ // record the name so it is only created once.
+ if (!endpoint.bind_to_device.empty() &&
+ interface_names.find(endpoint.bind_to_device) ==
+ interface_names.end()) {
+ all_tunnels.push_back(
+ ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device)));
+ interface_names.insert(endpoint.bind_to_device);
+ }
+
+ listener_fds.push_back(
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0)));
+ int fd = listener_fds.back().get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ // The length includes the terminating null; for an empty device name
+ // this passes a single null byte.
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE,
+ endpoint.bind_to_device.c_str(),
+ endpoint.bind_to_device.size() + 1),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (listener_fds.size() > 1) {
+ continue;
+ }
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ // NOTE(review): this writes back the port that was just read from
+ // listen_addr, so it looks redundant (the Tcp test omits it) -- confirm.
+ ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ constexpr int kConnectAttempts = 10000;
+ // Total packets received across all threads.
+ std::atomic<int> packets_received = ATOMIC_VAR_INIT(0);
+ // Per-socket totals; each thread only writes its own slot.
+ std::vector<int> packets_per_socket(listener_fds.size(), 0);
+ std::vector<std::unique_ptr<ScopedThread>> receiver_threads(
+ listener_fds.size());
+
+ for (int i = 0; i < listener_fds.size(); i++) {
+ receiver_threads[i] = absl::make_unique<ScopedThread>(
+ [&listener_fds, &packets_per_socket, &packets_received, i]() {
+ do {
+ struct sockaddr_storage addr = {};
+ socklen_t addrlen = sizeof(addr);
+ int data;
+
+ auto ret = RetryEINTR(recvfrom)(
+ listener_fds[i].get(), &data, sizeof(data), 0,
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen);
+
+ // A failed receive is only tolerated once the overall target has
+ // been reached.
+ if (packets_received < kConnectAttempts) {
+ ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data)));
+ }
+
+ if (ret != sizeof(data)) {
+ // Another thread may have shutdown our read side causing the
+ // recvfrom to fail.
+ break;
+ }
+
+ packets_received++;
+ packets_per_socket[i]++;
+
+ // A response is required to synchronize with the main thread,
+ // otherwise the main thread can send more than can fit into receive
+ // queues.
+ EXPECT_THAT(RetryEINTR(sendto)(
+ listener_fds[i].get(), &data, sizeof(data), 0,
+ reinterpret_cast<sockaddr*>(&addr), addrlen),
+ SyscallSucceedsWithValue(sizeof(data)));
+ } while (packets_received < kConnectAttempts);
+
+ // Shutdown all sockets to wake up other threads.
+ for (auto const& listener_fd : listener_fds) {
+ shutdown(listener_fd.get(), SHUT_RDWR);
+ }
+ });
+ }
+
+ // Each iteration uses a fresh sender socket and blocks on the echo from
+ // whichever receiver got the packet before sending the next one.
+ for (int i = 0; i < kConnectAttempts; i++) {
+ FileDescriptor const fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0));
+ EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ int data;
+ EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(sizeof(data)));
+ }
+
+ // Join threads to be sure that all packets have been counted.
+ for (auto const& receiver_thread : receiver_threads) {
+ receiver_thread->Join();
+ }
+ // Check that packets are distributed correctly among listening sockets.
+ // NOTE(review): EquivalentWithin is defined elsewhere; 0.10 is presumably
+ // a 10% tolerance -- confirm.
+ for (int i = 0; i < packets_per_socket.size(); i++) {
+ EXPECT_THAT(
+ packets_per_socket[i],
+ EquivalentWithin(static_cast<int>(kConnectAttempts *
+ test.endpoints[i].expected_ratio),
+ 0.10))
+ << "endpoint " << i << " got the wrong number of packets";
+ }
+}
+
+// Returns the distribution scenarios shared by the Tcp and Udp tests.
+//
+// Each scenario lists three endpoints as {device name, expected share of
+// traffic}. An empty device name means the endpoint is not tied to a named
+// device; within every scenario the expected shares sum to 1.
+std::vector<DistributionTestCase> GetDistributionTestCases() {
+ return std::vector<DistributionTestCase>{
+ {"Even distribution among sockets not bound to device",
+ {{"", 1. / 3}, {"", 1. / 3}, {"", 1. / 3}}},
+ {"Sockets bound to other interfaces get no packets",
+ {{"eth1", 0}, {"", 1. / 2}, {"", 1. / 2}}},
+ {"Bound has priority over unbound", {{"eth1", 0}, {"", 0}, {"lo", 1}}},
+ {"Even distribution among sockets bound to device",
+ {{"eth1", 0}, {"lo", 1. / 2}, {"lo", 1. / 2}}},
+ };
+}
+
+// Runs the Tcp and Udp distribution tests for every combination of the two
+// listener/connector address pairs below and the scenarios returned by
+// GetDistributionTestCases().
+INSTANTIATE_TEST_SUITE_P(
+ BindToDeviceTest, BindToDeviceDistributionTest,
+ ::testing::Combine(::testing::Values(
+ // Listeners bound to IPv4 addresses refuse
+ // connections using IPv6 addresses.
+ ListenerConnector{V4Any(), V4Loopback()},
+ ListenerConnector{V4Loopback(), V4MappedLoopback()}),
+ ::testing::ValuesIn(GetDistributionTestCases())));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc
new file mode 100644
index 000000000..a7365d139
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc
@@ -0,0 +1,316 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/capability.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+using std::vector;
+
+// Test fixture for SO_BINDTODEVICE tests the results of sequences of socket
+// binding.
+//
+// Each test issues a series of BindSocket() calls. All of them target the
+// loopback address on one port: the first bind uses port 0 and the port the
+// kernel picked is remembered for every later bind. Sockets stay open until
+// ReleaseSocket() is called or the fixture is destroyed, so later binds
+// contend with earlier ones.
+class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> {
+ protected:
+  // Logs the case description, requires CAP_NET_RAW, and snapshots the
+  // socket factory and the set of existing interface names.
+  void SetUp() override {
+    printf("Testing case: %s\n", GetParam().description.c_str());
+    ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)))
+        << "CAP_NET_RAW is required to use SO_BINDTODEVICE";
+    socket_factory_ = GetParam();
+
+    interface_names_ = GetInterfaceNames();
+  }
+
+  // Creates a new, unbound socket of the parameterized kind.
+  PosixErrorOr<std::unique_ptr<FileDescriptor>> NewSocket() const {
+    return socket_factory_.Create();
+  }
+
+  // Gets a device by device_id. If the device_id has been seen before, stores
+  // the previously chosen device name in *device_name. If not, finds or
+  // creates a new device and stores its name.
+  //
+  // Failures are reported as fatal gtest failures rather than through the
+  // output parameter, so callers should wrap the call in
+  // ASSERT_NO_FATAL_FAILURE. After a failure *device_name may hold a
+  // candidate name that does not correspond to an existing device.
+  void GetDevice(int device_id, string *device_name) {
+    auto device = devices_.find(device_id);
+    if (device != devices_.end()) {
+      *device_name = device->second;
+      return;
+    }
+
+    // Need to pick a new device. Try ethernet first.
+    *device_name = absl::StrCat("eth", next_unused_eth_);
+    if (interface_names_.find(*device_name) != interface_names_.end()) {
+      devices_[device_id] = *device_name;
+      next_unused_eth_++;
+      return;
+    }
+
+    // Need to make a new tunnel device. gVisor tests should have enough
+    // ethernet devices to never reach here.
+    ASSERT_FALSE(IsRunningOnGvisor());
+    // Need a tunnel.
+    tunnels_.push_back(ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New()));
+    devices_[device_id] = tunnels_.back()->GetName();
+    *device_name = devices_[device_id];
+  }
+
+  // Release the socket
+  void ReleaseSocket(int socket_id) {
+    // Close the socket that was made in a previous action. The socket_id
+    // indicates which socket to close based on index into the list of actions.
+    // Erasing the map entry destroys the owning unique_ptr.
+    sockets_to_close_.erase(socket_id);
+  }
+
+  // Bind a socket with the reuse option and bind_to_device options. Checks
+  // that all steps succeed and that the bind command's error matches want.
+  // Sets the socket_id to uniquely identify the socket bound if it is not
+  // nullptr.
+  //
+  //   reuse     - set SO_REUSEPORT on the socket before binding.
+  //   device_id - test-local device label; 0 means do not set
+  //               SO_BINDTODEVICE. Labels are mapped to real device names by
+  //               GetDevice().
+  //   want      - expected errno from bind(); 0 means bind must succeed.
+  //   socket_id - optional out-parameter receiving the id to pass to
+  //               ReleaseSocket().
+  void BindSocket(bool reuse, int device_id = 0, int want = 0,
+                  int *socket_id = nullptr) {
+    next_socket_id_++;
+    sockets_to_close_[next_socket_id_] = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+    auto socket_fd = sockets_to_close_[next_socket_id_]->get();
+    if (socket_id != nullptr) {
+      *socket_id = next_socket_id_;
+    }
+
+    // If reuse is indicated, do that.
+    if (reuse) {
+      EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+                             sizeof(kSockOptOn)),
+                  SyscallSucceedsWithValue(0));
+    }
+
+    // If the device is non-zero, bind to that device.
+    if (device_id != 0) {
+      string device_name;
+      ASSERT_NO_FATAL_FAILURE(GetDevice(device_id, &device_name));
+      EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE,
+                             device_name.c_str(), device_name.size() + 1),
+                  SyscallSucceedsWithValue(0));
+      // Read the binding back and verify it. The buffer is zero-initialized
+      // so the comparison is well-defined even if getsockopt fails.
+      char get_device[100] = {};
+      socklen_t get_device_size = sizeof(get_device);
+      EXPECT_THAT(getsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, get_device,
+                             &get_device_size),
+                  SyscallSucceedsWithValue(0));
+      EXPECT_STREQ(get_device, device_name.c_str());
+    }
+
+    struct sockaddr_in addr = {};
+    addr.sin_family = AF_INET;
+    addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+    // port_ is kept exactly as getsockname() reported it, so no byte-order
+    // conversion is applied here.
+    addr.sin_port = port_;
+    if (want == 0) {
+      ASSERT_THAT(
+          bind(socket_fd, reinterpret_cast<const struct sockaddr *>(&addr),
+               sizeof(addr)),
+          SyscallSucceeds());
+    } else {
+      ASSERT_THAT(
+          bind(socket_fd, reinterpret_cast<const struct sockaddr *>(&addr),
+               sizeof(addr)),
+          SyscallFailsWithErrno(want));
+    }
+
+    if (port_ == 0) {
+      // We don't yet know what port we'll be using so we need to fetch it and
+      // remember it for future commands.
+      socklen_t addr_size = sizeof(addr);
+      ASSERT_THAT(
+          getsockname(socket_fd, reinterpret_cast<struct sockaddr *>(&addr),
+                      &addr_size),
+          SyscallSucceeds());
+      port_ = addr.sin_port;
+    }
+  }
+
+ private:
+  SocketKind socket_factory_;
+  // devices maps from the device id in the test case to the name of the device.
+  std::unordered_map<int, string> devices_;
+  // These are the tunnels that were created for the test and will be destroyed
+  // by the destructor.
+  vector<std::unique_ptr<Tunnel>> tunnels_;
+  // A list of all interface names before the test started.
+  std::unordered_set<string> interface_names_;
+  // The next ethernet device to use when requested a device.
+  int next_unused_eth_ = 1;
+  // The port for all tests. Originally 0 (any) and later set to the port that
+  // all further commands will use.
+  in_port_t port_ = 0;
+  // sockets_to_close_ is a map from action index to the socket that was
+  // created.
+  std::unordered_map<int,
+                     std::unique_ptr<gvisor::testing::FileDescriptor>>
+      sockets_to_close_;
+  // Id handed to the most recently created socket; incremented before use,
+  // so ids start at 1.
+  int next_socket_id_ = 0;
+};
+
+// Two sockets without SO_REUSEPORT bound to the same device: the first bind
+// succeeds and the second is expected to fail with EADDRINUSE.
+TEST_P(BindToDeviceSequenceTest, BindTwiceWithDeviceFails) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 3));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 3, EADDRINUSE));
+}
+
+// Two sockets without SO_REUSEPORT bound to two different devices: both
+// binds are expected to succeed.
+TEST_P(BindToDeviceSequenceTest, BindToDevice) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 1));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 2));
+}
+
+// After a socket is bound to a device without SO_REUSEPORT, a second socket
+// with no device binding is expected to fail with EADDRINUSE.
+TEST_P(BindToDeviceSequenceTest, BindToDeviceAndThenWithoutDevice) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+}
+
+// After a socket is bound with neither SO_REUSEPORT nor a device, every
+// further bind -- with or without reuse, with or without a device -- is
+// expected to fail with EADDRINUSE.
+TEST_P(BindToDeviceSequenceTest, BindWithoutDevice) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+}
+
+// Starts with a non-reuse socket bound to device 123. Binds to the same
+// device or to no device are expected to fail with EADDRINUSE regardless of
+// reuse, while binds to other devices (456 with reuse, 789 without) are
+// expected to succeed. Binds with no device are still expected to fail
+// afterwards.
+TEST_P(BindToDeviceSequenceTest, BindWithDevice) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 456, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 789, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindWithReuse) {
+  // The first bind uses reuse and no device. Later binds that also use reuse
+  // pass no expected error; the ones without reuse pass EADDRINUSE.
+  constexpr int kDevice = 123;
+  constexpr int kNoDevice = 0;
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true));
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false, kDevice, EADDRINUSE));
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, kDevice));
+  ASSERT_NO_FATAL_FAILURE(
+      BindSocket(/* reuse */ false, kNoDevice, EADDRINUSE));
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, kNoDevice));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindingWithReuseAndDevice) {
+  // Starts with a reuse bind to device A, then interleaves reuse binds (no
+  // expected error) with non-reuse binds (EADDRINUSE passed).
+  constexpr int kNoDevice = 0;
+  constexpr int kDeviceA = 123;
+  constexpr int kDeviceB = 456;
+  constexpr int kDeviceC = 789;
+  constexpr int kDeviceD = 999;
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, kDeviceA));
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false, kDeviceA, EADDRINUSE));
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, kDeviceA));
+  ASSERT_NO_FATAL_FAILURE(
+      BindSocket(/* reuse */ false, kNoDevice, EADDRINUSE));
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, kDeviceB));
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true));
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, kDeviceC));
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false, kDeviceD, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, MixingReuseAndNotReuseByBindingToDevice) {
+  // Alternates reuse and non-reuse binds, each to a distinct device; every
+  // call passes 0 as its third argument.
+  constexpr int kDeviceA = 123;
+  constexpr int kDeviceB = 456;
+  constexpr int kDeviceC = 789;
+  constexpr int kDeviceD = 999;
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, kDeviceA, 0));
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false, kDeviceB, 0));
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, kDeviceC, 0));
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false, kDeviceD, 0));
+}
+
+TEST_P(BindToDeviceSequenceTest, CannotBindTo0AfterMixingReuseAndNotReuse) {
+  // One reuse bind and one non-reuse bind on different devices, followed by a
+  // reuse bind to device 0 that passes EADDRINUSE.
+  constexpr int kReuseDevice = 123;
+  constexpr int kNoReuseDevice = 456;
+  constexpr int kNoDevice = 0;
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, kReuseDevice));
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false, kNoReuseDevice));
+  ASSERT_NO_FATAL_FAILURE(
+      BindSocket(/* reuse */ true, kNoDevice, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindAndRelease) {
+  constexpr int kNoDevice = 0;
+  constexpr int kDeviceA = 123;
+  constexpr int kDeviceB = 345;
+  constexpr int kDeviceC = 789;
+
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, kDeviceA));
+  // Keep the id of the device-0 socket so that it can be released below.
+  int to_release;
+  ASSERT_NO_FATAL_FAILURE(
+      BindSocket(/* reuse */ true, kNoDevice, 0, &to_release));
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false, kDeviceB, EADDRINUSE));
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, kDeviceC));
+  // Release the bind to device 0 and try again.
+  ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release));
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false, kDeviceB));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindTwiceWithReuseOnce) {
+  // A non-reuse bind to a device, then a reuse bind to device 0 that passes
+  // EADDRINUSE.
+  constexpr int kDevice = 123;
+  constexpr int kNoDevice = 0;
+  ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false, kDevice));
+  ASSERT_NO_FATAL_FAILURE(
+      BindSocket(/* reuse */ true, kNoDevice, EADDRINUSE));
+}
+
+// Instantiates the sequence tests with two parameters: IPv4UDPUnboundSocket(0)
+// and IPv4TCPUnboundSocket(0).
+INSTANTIATE_TEST_SUITE_P(
+    BindToDeviceTest, BindToDeviceSequenceTest,
+    ::testing::Values(IPv4UDPUnboundSocket(0), IPv4TCPUnboundSocket(0)));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device_util.cc b/test/syscalls/linux/socket_bind_to_device_util.cc
new file mode 100644
index 000000000..f4ee775bd
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_util.cc
@@ -0,0 +1,75 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+
+#include <arpa/inet.h>
+#include <fcntl.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+
+// Opens /dev/net/tun and attaches the descriptor to a TUN interface with the
+// TUNSETIFF ioctl. `tunnel_name` is the requested interface name; the name
+// found in the ifreq after the ioctl is stored as the tunnel's name. Failures
+// of open() and ioctl() are propagated through RETURN_ERROR_IF_SYSCALL_FAIL.
+PosixErrorOr<std::unique_ptr<Tunnel>> Tunnel::New(string tunnel_name) {
+  int fd;
+  RETURN_ERROR_IF_SYSCALL_FAIL(fd = open("/dev/net/tun", O_RDWR));
+
+  // Using `new` to access a non-public constructor. The descriptor is handed
+  // to the Tunnel immediately so that its destructor owns it from here on.
+  auto new_tunnel = absl::WrapUnique(new Tunnel(fd));
+
+  ifreq ifr = {};
+  ifr.ifr_flags = IFF_TUN;
+  // Copy at most sizeof - 1 bytes: ifr was zero-initialized, so the last byte
+  // stays NUL and ifr_name is terminated even if tunnel_name is too long.
+  strncpy(ifr.ifr_name, tunnel_name.c_str(), sizeof(ifr.ifr_name) - 1);
+
+  RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(fd, TUNSETIFF, &ifr));
+  new_tunnel->name_ = ifr.ifr_name;
+  return new_tunnel;
+}
+
+// Returns the names of all interfaces reported by if_nameindex(). The set is
+// empty when if_nameindex() returns nullptr.
+std::unordered_set<string> GetInterfaceNames() {
+  std::unordered_set<string> names;
+  struct if_nameindex* interfaces = if_nameindex();
+  if (interfaces == nullptr) {
+    return names;
+  }
+  // The array ends with an entry whose if_index is 0 and whose if_name is
+  // nullptr. Stop as soon as either field marks the end so that a null
+  // if_name is never passed to std::string.
+  for (auto interface = interfaces;
+       interface->if_index != 0 && interface->if_name != nullptr;
+       interface++) {
+    names.insert(interface->if_name);
+  }
+  if_freenameindex(interfaces);
+  return names;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device_util.h b/test/syscalls/linux/socket_bind_to_device_util.h
new file mode 100644
index 000000000..f941ccc86
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_util.h
@@ -0,0 +1,67 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_
+#define GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_
+
+#include <arpa/inet.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Owns the file descriptor of a tunnel device together with the interface
+// name. Instances are created through Tunnel::New(); the descriptor is closed
+// when the object is destroyed.
+class Tunnel {
+ public:
+  // Creates a tunnel. `tunnel_name` defaults to the empty string.
+  static PosixErrorOr<std::unique_ptr<Tunnel>> New(
+      std::string tunnel_name = "");
+
+  // Returns the interface name recorded for this tunnel.
+  const std::string& GetName() const { return name_; }
+
+  // Not copyable: a copy would share fd_ and close it a second time.
+  Tunnel(const Tunnel&) = delete;
+  Tunnel& operator=(const Tunnel&) = delete;
+
+  ~Tunnel() {
+    if (fd_ != -1) {
+      close(fd_);
+    }
+  }
+
+ private:
+  // Takes ownership of `fd`. Private so that New() is the only way to build a
+  // Tunnel; explicit so that an int never converts to a Tunnel implicitly.
+  explicit Tunnel(int fd) : fd_(fd) {}
+
+  int fd_ = -1;       // Descriptor closed by the destructor; -1 means none.
+  std::string name_;  // Interface name, filled in by New().
+};
+
+std::unordered_set<std::string> GetInterfaceNames();
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_
diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc
index a43cf9bce..bfa7943b1 100644
--- a/test/syscalls/linux/socket_ip_tcp_generic.cc
+++ b/test/syscalls/linux/socket_ip_tcp_generic.cc
@@ -117,7 +117,7 @@ TEST_P(TCPSocketPairTest, RSTCausesPollHUP) {
struct pollfd poll_fd3 = {sockets->first_fd(), POLLHUP, 0};
ASSERT_THAT(RetryEINTR(poll)(&poll_fd3, 1, kPollTimeoutMs),
SyscallSucceedsWithValue(1));
- ASSERT_NE(poll_fd.revents & (POLLHUP | POLLIN), 0);
+ ASSERT_NE(poll_fd3.revents & POLLHUP, 0);
}
// This test validates that even if a RST is sent the other end will not
diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc
new file mode 100644
index 000000000..fa9a9df6f
--- /dev/null
+++ b/test/syscalls/linux/socket_ip_unbound.cc
@@ -0,0 +1,379 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <netinet/in.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <cstdio>
+#include <cstring>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to unbound IP sockets.
+using IPUnboundSocketTest = SimpleSocketTest;
+
+TEST_P(IPUnboundSocketTest, TtlDefault) {
+  // A freshly created socket is expected to report an IP_TTL of 64.
+  auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+  int ttl = -1;
+  socklen_t ttl_len = sizeof(ttl);
+  EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &ttl, &ttl_len),
+              SyscallSucceedsWithValue(0));
+  EXPECT_EQ(ttl, 64);
+  EXPECT_EQ(ttl_len, sizeof(ttl));
+}
+
+TEST_P(IPUnboundSocketTest, SetTtl) {
+  auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+  // Read the current TTL so that the value written below differs from it.
+  int initial_ttl = -1;
+  socklen_t initial_len = sizeof(initial_ttl);
+  EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &initial_ttl,
+                         &initial_len),
+              SyscallSucceedsWithValue(0));
+  EXPECT_EQ(initial_len, sizeof(initial_ttl));
+
+  // Use 100, or 101 if the socket already reports 100.
+  int new_ttl = (initial_ttl == 100) ? 101 : 100;
+  EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &new_ttl,
+                         sizeof(new_ttl)),
+              SyscallSucceedsWithValue(0));
+
+  // The option must read back as the value that was just written.
+  int read_ttl = -1;
+  socklen_t read_len = sizeof(read_ttl);
+  EXPECT_THAT(
+      getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &read_ttl, &read_len),
+      SyscallSucceedsWithValue(0));
+  EXPECT_EQ(read_len, sizeof(read_ttl));
+  EXPECT_EQ(read_ttl, new_ttl);
+}
+
+TEST_P(IPUnboundSocketTest, ResetTtlToDefault) {
+  auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+  // Remember the TTL the socket starts out with.
+  int default_ttl = -1;
+  socklen_t default_len = sizeof(default_ttl);
+  EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &default_ttl,
+                         &default_len),
+              SyscallSucceedsWithValue(0));
+  EXPECT_EQ(default_len, sizeof(default_ttl));
+
+  // Move the TTL away from its initial value (100, or 101 if it was 100).
+  int custom_ttl = (default_ttl == 100) ? 101 : 100;
+  EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &custom_ttl,
+                         sizeof(custom_ttl)),
+              SyscallSucceedsWithValue(0));
+
+  // Writing -1 is expected to restore the initial value, checked below.
+  int reset_ttl = -1;
+  EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &reset_ttl,
+                         sizeof(reset_ttl)),
+              SyscallSucceedsWithValue(0));
+
+  int read_ttl = -1;
+  socklen_t read_len = sizeof(read_ttl);
+  EXPECT_THAT(
+      getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &read_ttl, &read_len),
+      SyscallSucceedsWithValue(0));
+  EXPECT_EQ(read_len, sizeof(read_ttl));
+  EXPECT_EQ(read_ttl, default_ttl);
+}
+
+TEST_P(IPUnboundSocketTest, ZeroTtl) {
+  // Setting IP_TTL to zero is expected to fail with EINVAL.
+  auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+  const int kZeroTtl = 0;
+  EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &kZeroTtl,
+                         sizeof(kZeroTtl)),
+              SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(IPUnboundSocketTest, InvalidLargeTtl) {
+  // Setting IP_TTL to 256 is expected to fail with EINVAL.
+  auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+  const int kTooLargeTtl = 256;
+  EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &kTooLargeTtl,
+                         sizeof(kTooLargeTtl)),
+              SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(IPUnboundSocketTest, InvalidNegativeTtl) {
+  // Setting IP_TTL to -2 is expected to fail with EINVAL.
+  auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+  const int kNegativeTtl = -2;
+  EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &kNegativeTtl,
+                         sizeof(kNegativeTtl)),
+              SyscallFailsWithErrno(EINVAL));
+}
+
+struct TOSOption {  // Socket option used by the TOS tests for a given domain.
+ int level;  // Option level passed to getsockopt/setsockopt.
+ int option;  // Option name passed to getsockopt/setsockopt.
+};
+
+constexpr int INET_ECN_MASK = 3;  // Low two TOS bits; cleared in the value the tests expect from TCP sockets.
+
+// Returns the option level and name that carry the TOS value for `domain`:
+// (IPPROTO_IP, IP_TOS) for AF_INET and (IPPROTO_IPV6, IPV6_TCLASS) for
+// AF_INET6. Any other domain yields a zero-initialized TOSOption instead of
+// indeterminate values.
+static TOSOption GetTOSOption(int domain) {
+  TOSOption opt = {};
+  switch (domain) {
+    case AF_INET:
+      opt.level = IPPROTO_IP;
+      opt.option = IP_TOS;
+      break;
+    case AF_INET6:
+      opt.level = IPPROTO_IPV6;
+      opt.option = IPV6_TCLASS;
+      break;
+    default:
+      // Unknown domain: leave both fields zero.
+      break;
+  }
+  return opt;
+}
+
+TEST_P(IPUnboundSocketTest, TOSDefault) {
+  // A freshly created socket is expected to report a TOS value of zero.
+  constexpr int kDefaultTOS = 0;
+  auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+  const TOSOption tos_opt = GetTOSOption(GetParam().domain);
+
+  int tos = -1;
+  socklen_t tos_len = sizeof(tos);
+  EXPECT_THAT(
+      getsockopt(socket->get(), tos_opt.level, tos_opt.option, &tos, &tos_len),
+      SyscallSucceedsWithValue(0));
+  EXPECT_EQ(tos_len, sizeof(tos));
+  EXPECT_EQ(tos, kDefaultTOS);
+}
+
+TEST_P(IPUnboundSocketTest, SetTOS) {
+  // Writes 0xC0 and expects to read the same value back.
+  auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+  const TOSOption tos_opt = GetTOSOption(GetParam().domain);
+
+  int written = 0xC0;
+  EXPECT_THAT(setsockopt(socket->get(), tos_opt.level, tos_opt.option,
+                         &written, sizeof(written)),
+              SyscallSucceedsWithValue(0));
+
+  int read_back = -1;
+  socklen_t read_len = sizeof(read_back);
+  EXPECT_THAT(getsockopt(socket->get(), tos_opt.level, tos_opt.option,
+                         &read_back, &read_len),
+              SyscallSucceedsWithValue(0));
+  EXPECT_EQ(read_len, sizeof(read_back));
+  EXPECT_EQ(read_back, written);
+}
+
+TEST_P(IPUnboundSocketTest, ZeroTOS) {
+  // Writes 0 and expects to read 0 back.
+  auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+  const TOSOption tos_opt = GetTOSOption(GetParam().domain);
+
+  int written = 0;
+  EXPECT_THAT(setsockopt(socket->get(), tos_opt.level, tos_opt.option,
+                         &written, sizeof(written)),
+              SyscallSucceedsWithValue(0));
+
+  int read_back = -1;
+  socklen_t read_len = sizeof(read_back);
+  EXPECT_THAT(getsockopt(socket->get(), tos_opt.level, tos_opt.option,
+                         &read_back, &read_len),
+              SyscallSucceedsWithValue(0));
+  EXPECT_EQ(read_len, sizeof(read_back));
+  EXPECT_EQ(read_back, written);
+}
+
+TEST_P(IPUnboundSocketTest, InvalidLargeTOS) {
+  constexpr int kDefaultTOS = 0;
+  auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+  const TOSOption tos_opt = GetTOSOption(GetParam().domain);
+
+  // 256 does not fit in a single byte. The set is expected to fail with
+  // EINVAL except on AF_INET, where it is expected to succeed.
+  int written = 256;
+  if (GetParam().domain != AF_INET) {
+    EXPECT_THAT(setsockopt(socket->get(), tos_opt.level, tos_opt.option,
+                           &written, sizeof(written)),
+                SyscallFailsWithErrno(EINVAL));
+  } else {
+    EXPECT_THAT(setsockopt(socket->get(), tos_opt.level, tos_opt.option,
+                           &written, sizeof(written)),
+                SyscallSucceedsWithValue(0));
+  }
+
+  // In both cases the value read back is expected to be the default.
+  int read_back = -1;
+  socklen_t read_len = sizeof(read_back);
+  EXPECT_THAT(getsockopt(socket->get(), tos_opt.level, tos_opt.option,
+                         &read_back, &read_len),
+              SyscallSucceedsWithValue(0));
+  EXPECT_EQ(read_len, sizeof(read_back));
+  EXPECT_EQ(read_back, kDefaultTOS);
+}
+
+TEST_P(IPUnboundSocketTest, CheckSkipECN) {
+  auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+  const TOSOption tos_opt = GetTOSOption(GetParam().domain);
+
+  int written = 0xFF;
+  EXPECT_THAT(setsockopt(socket->get(), tos_opt.level, tos_opt.option,
+                         &written, sizeof(written)),
+              SyscallSucceedsWithValue(0));
+
+  // For TCP the bits in INET_ECN_MASK are expected to read back as zero; other
+  // protocols are expected to return the full byte.
+  const int written_byte = static_cast<uint8_t>(written);
+  const bool is_tcp = GetParam().protocol == IPPROTO_TCP;
+  const int want = is_tcp ? (written_byte & ~INET_ECN_MASK) : written_byte;
+
+  int read_back = -1;
+  socklen_t read_len = sizeof(read_back);
+  EXPECT_THAT(getsockopt(socket->get(), tos_opt.level, tos_opt.option,
+                         &read_back, &read_len),
+              SyscallSucceedsWithValue(0));
+  EXPECT_EQ(read_len, sizeof(read_back));
+  EXPECT_EQ(read_back, want);
+}
+
+TEST_P(IPUnboundSocketTest, ZeroTOSOptionSize) {
+  auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+  const TOSOption tos_opt = GetTOSOption(GetParam().domain);
+
+  // A zero-length set is expected to fail with EINVAL except on AF_INET,
+  // where it is expected to succeed.
+  int written = 0xC0;
+  const socklen_t kZeroLen = 0;
+  if (GetParam().domain != AF_INET) {
+    EXPECT_THAT(setsockopt(socket->get(), tos_opt.level, tos_opt.option,
+                           &written, kZeroLen),
+                SyscallFailsWithErrno(EINVAL));
+  } else {
+    EXPECT_THAT(setsockopt(socket->get(), tos_opt.level, tos_opt.option,
+                           &written, kZeroLen),
+                SyscallSucceedsWithValue(0));
+  }
+
+  // A zero-length get is expected to succeed and leave both the length and
+  // the output buffer unchanged.
+  int read_back = -1;
+  socklen_t read_len = 0;
+  EXPECT_THAT(getsockopt(socket->get(), tos_opt.level, tos_opt.option,
+                         &read_back, &read_len),
+              SyscallSucceedsWithValue(0));
+  EXPECT_EQ(read_len, 0);
+  EXPECT_EQ(read_back, -1);
+}
+
+// Exercises set/get of the TOS option with option lengths from 1 up to (but
+// not including) sizeof(int). On AF_INET the set is expected to succeed and
+// the get to return a single byte; on other domains the set is expected to
+// fail with EINVAL and the get to return the default value.
+TEST_P(IPUnboundSocketTest, SmallTOSOptionSize) {
+  auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+  int set = 0xC0;
+  constexpr int kDefaultTOS = 0;
+  TOSOption t = GetTOSOption(GetParam().domain);
+  for (socklen_t i = 1; i < sizeof(int); i++) {
+    int expect_tos;
+    socklen_t expect_sz;
+    if (GetParam().domain == AF_INET) {
+      EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, i),
+                  SyscallSucceedsWithValue(0));
+      expect_tos = set;
+      expect_sz = sizeof(uint8_t);
+    } else {
+      EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, i),
+                  SyscallFailsWithErrno(EINVAL));
+      expect_tos = kDefaultTOS;
+      expect_sz = i;
+    }
+    // All bits set, so that bytes getsockopt does not write stay visible.
+    unsigned int get = ~0u;
+    socklen_t get_sz = i;
+    EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+                SyscallSucceedsWithValue(0));
+    EXPECT_EQ(get_sz, expect_sz);
+    // Account for partial copies by getsockopt: keep only the low get_sz
+    // bytes before comparing against expect_tos. The mask is built from an
+    // unsigned value because left-shifting a negative int is undefined, and
+    // the shift is skipped if get_sz covers the whole word.
+    const unsigned int mask =
+        get_sz < sizeof(get) ? ~(~0u << (get_sz * 8)) : ~0u;
+    EXPECT_EQ(static_cast<int>(get & mask), expect_tos);
+  }
+}
+
+TEST_P(IPUnboundSocketTest, LargeTOSOptionSize) {
+  auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+  const TOSOption tos_opt = GetTOSOption(GetParam().domain);
+  int written = 0xC0;
+
+  // Try every option length from sizeof(int) up to 9 bytes.
+  for (socklen_t len = sizeof(int); len < 10; ++len) {
+    EXPECT_THAT(
+        setsockopt(socket->get(), tos_opt.level, tos_opt.option, &written, len),
+        SyscallSucceedsWithValue(0));
+
+    int read_back = -1;
+    socklen_t read_len = len;
+    // We expect the system call handler to copy at most sizeof(int) bytes, as
+    // asserted by the check below. Hence, we do not expect the copy to
+    // overflow in getsockopt.
+    EXPECT_THAT(getsockopt(socket->get(), tos_opt.level, tos_opt.option,
+                           &read_back, &read_len),
+                SyscallSucceedsWithValue(0));
+    EXPECT_EQ(read_len, sizeof(int));
+    EXPECT_EQ(read_back, written);
+  }
+}
+
+TEST_P(IPUnboundSocketTest, NegativeTOS) {
+  auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+  const TOSOption tos_opt = GetTOSOption(GetParam().domain);
+
+  // Setting -1 is expected to succeed on every domain.
+  int written = -1;
+  EXPECT_THAT(setsockopt(socket->get(), tos_opt.level, tos_opt.option,
+                         &written, sizeof(written)),
+              SyscallSucceedsWithValue(0));
+
+  // On IPv6 TCLASS, setting -1 has the effect of resetting the TrafficClass,
+  // so zero is expected there. On AF_INET the low byte of the value is
+  // expected, minus the INET_ECN_MASK bits for TCP.
+  int want = 0;
+  if (GetParam().domain == AF_INET) {
+    want = static_cast<uint8_t>(written);
+    if (GetParam().protocol == IPPROTO_TCP) {
+      want &= ~INET_ECN_MASK;
+    }
+  }
+
+  int read_back = -1;
+  socklen_t read_len = sizeof(read_back);
+  EXPECT_THAT(getsockopt(socket->get(), tos_opt.level, tos_opt.option,
+                         &read_back, &read_len),
+              SyscallSucceedsWithValue(0));
+  EXPECT_EQ(read_len, sizeof(read_back));
+  EXPECT_EQ(read_back, want);
+}
+
+TEST_P(IPUnboundSocketTest, InvalidNegativeTOS) {
+  auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+  const TOSOption tos_opt = GetTOSOption(GetParam().domain);
+
+  // Setting -2 is expected to fail with EINVAL and leave the value at zero,
+  // except on AF_INET, where it is expected to succeed and store the low byte
+  // (minus the INET_ECN_MASK bits for TCP).
+  int written = -2;
+  int want = 0;
+  if (GetParam().domain != AF_INET) {
+    EXPECT_THAT(setsockopt(socket->get(), tos_opt.level, tos_opt.option,
+                           &written, sizeof(written)),
+                SyscallFailsWithErrno(EINVAL));
+  } else {
+    EXPECT_THAT(setsockopt(socket->get(), tos_opt.level, tos_opt.option,
+                           &written, sizeof(written)),
+                SyscallSucceedsWithValue(0));
+    want = static_cast<uint8_t>(written);
+    if (GetParam().protocol == IPPROTO_TCP) {
+      want &= ~INET_ECN_MASK;
+    }
+  }
+
+  int read_back = 0;
+  socklen_t read_len = sizeof(read_back);
+  EXPECT_THAT(getsockopt(socket->get(), tos_opt.level, tos_opt.option,
+                         &read_back, &read_len),
+              SyscallSucceedsWithValue(0));
+  EXPECT_EQ(read_len, sizeof(read_back));
+  EXPECT_EQ(read_back, want);
+}
+
+// Instantiates the suite for the IPv4/IPv6 UDP and TCP unbound socket kinds,
+// each with type flags built from {0, SOCK_NONBLOCK}.
+INSTANTIATE_TEST_SUITE_P(
+    IPUnboundSockets, IPUnboundSocketTest,
+    ::testing::ValuesIn(VecCat<SocketKind>(VecCat<SocketKind>(
+        ApplyVec<SocketKind>(
+            IPv4UDPUnboundSocket,
+            AllBitwiseCombinations(List<int>{SOCK_DGRAM},
+                                   List<int>{0, SOCK_NONBLOCK})),
+        ApplyVec<SocketKind>(
+            IPv6UDPUnboundSocket,
+            AllBitwiseCombinations(List<int>{SOCK_DGRAM},
+                                   List<int>{0, SOCK_NONBLOCK})),
+        ApplyVec<SocketKind>(
+            IPv4TCPUnboundSocket,
+            AllBitwiseCombinations(List<int>{SOCK_STREAM},
+                                   List<int>{0, SOCK_NONBLOCK})),
+        ApplyVec<SocketKind>(
+            IPv6TCPUnboundSocket,
+            AllBitwiseCombinations(List<int>{SOCK_STREAM},
+                                   List<int>{0, SOCK_NONBLOCK}))))));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
index c85ae30dc..8b8993d3d 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
@@ -42,6 +42,26 @@ TestAddress V4EmptyAddress() {
return t;
}
+constexpr char kMulticastAddress[] = "224.0.2.1";
+
+// Builds a TestAddress labelled "V4Multicast" whose IPv4 address is
+// kMulticastAddress. The port is not set here.
+TestAddress V4Multicast() {
+  TestAddress t("V4Multicast");
+  t.addr.ss_family = AF_INET;
+  t.addr_len = sizeof(sockaddr_in);
+  auto* sin = reinterpret_cast<sockaddr_in*>(&t.addr);
+  sin->sin_addr.s_addr = inet_addr(kMulticastAddress);
+  return t;
+}
+
+// Builds a TestAddress labelled "V4Broadcast" whose IPv4 address is
+// INADDR_BROADCAST in network byte order. The port is not set here.
+TestAddress V4Broadcast() {
+  TestAddress t("V4Broadcast");
+  t.addr.ss_family = AF_INET;
+  t.addr_len = sizeof(sockaddr_in);
+  auto* sin = reinterpret_cast<sockaddr_in*>(&t.addr);
+  sin->sin_addr.s_addr = htonl(INADDR_BROADCAST);
+  return t;
+}
+
void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() {
got_if_infos_ = false;
@@ -116,7 +136,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, SetUDPBroadcast) {
// Verifies that a broadcast UDP packet will arrive at all UDP sockets with
// the destination port number.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
- UDPBroadcastReceivedOnAllExpectedEndpoints) {
+ UDPBroadcastReceivedOnExpectedPort) {
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto rcvr2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -136,51 +156,134 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
sizeof(kSockOptOn)),
SyscallSucceedsWithValue(0));
- sockaddr_in rcv_addr = {};
- socklen_t rcv_addr_sz = sizeof(rcv_addr);
- rcv_addr.sin_family = AF_INET;
- rcv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
- ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast<struct sockaddr*>(&rcv_addr),
- rcv_addr_sz),
+ // Bind the first socket to the ANY address and let the system assign a port.
+ auto rcv1_addr = V4Any();
+ ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr),
+ rcv1_addr.addr_len),
SyscallSucceedsWithValue(0));
// Retrieve port number from first socket so that it can be bound to the
// second socket.
- rcv_addr = {};
+ socklen_t rcv_addr_sz = rcv1_addr.addr_len;
ASSERT_THAT(
- getsockname(rcvr1->get(), reinterpret_cast<struct sockaddr*>(&rcv_addr),
+ getsockname(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr),
&rcv_addr_sz),
SyscallSucceedsWithValue(0));
- ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast<struct sockaddr*>(&rcv_addr),
+ EXPECT_EQ(rcv_addr_sz, rcv1_addr.addr_len);
+ auto port = reinterpret_cast<sockaddr_in*>(&rcv1_addr.addr)->sin_port;
+
+ // Bind the second socket to the same address:port as the first.
+ ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr),
rcv_addr_sz),
SyscallSucceedsWithValue(0));
// Bind the non-receiving socket to an ephemeral port.
- sockaddr_in norcv_addr = {};
- norcv_addr.sin_family = AF_INET;
- norcv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
+ auto norecv_addr = V4Any();
+ ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr),
+ norecv_addr.addr_len),
+ SyscallSucceedsWithValue(0));
+
+ // Broadcast a test message.
+ auto dst_addr = V4Broadcast();
+ reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = port;
+ constexpr char kTestMsg[] = "hello, world";
+ EXPECT_THAT(
+ sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
+ reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
+
+ // Verify that the receiving sockets received the test message.
+ char buf[sizeof(kTestMsg)] = {};
+ EXPECT_THAT(recv(rcvr1->get(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
+ EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg)));
+ memset(buf, 0, sizeof(buf));
+ EXPECT_THAT(recv(rcvr2->get(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
+ EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg)));
+
+ // Verify that the non-receiving socket did not receive the test message.
+ memset(buf, 0, sizeof(buf));
+ EXPECT_THAT(RetryEINTR(recv)(norcv->get(), buf, sizeof(buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+// Verifies that a broadcast UDP packet will arrive at all UDP sockets bound to
+// the destination port number and either INADDR_ANY or INADDR_BROADCAST, but
+// not a unicast address.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
+ UDPBroadcastReceivedOnExpectedAddresses) {
+ // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its
+ // IPv4 address on eth0.
+ SKIP_IF(!got_if_infos_);
+
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto rcvr2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto norcv = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Enable SO_BROADCAST on the sending socket.
+ ASSERT_THAT(setsockopt(sender->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ // Enable SO_REUSEPORT on all sockets so that they may all be bound to the
+ // broadcast message's destination port.
+ ASSERT_THAT(setsockopt(rcvr1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(setsockopt(rcvr2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(setsockopt(norcv->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ // Bind the first socket to the ANY address and let the system assign a port.
+ auto rcv1_addr = V4Any();
+ ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr),
+ rcv1_addr.addr_len),
+ SyscallSucceedsWithValue(0));
+ // Retrieve port number from first socket so that it can be bound to the
+ // second socket.
+ socklen_t rcv_addr_sz = rcv1_addr.addr_len;
ASSERT_THAT(
- bind(norcv->get(), reinterpret_cast<struct sockaddr*>(&norcv_addr),
- sizeof(norcv_addr)),
+ getsockname(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr),
+ &rcv_addr_sz),
SyscallSucceedsWithValue(0));
+ EXPECT_EQ(rcv_addr_sz, rcv1_addr.addr_len);
+ auto port = reinterpret_cast<sockaddr_in*>(&rcv1_addr.addr)->sin_port;
+
+ // Bind the second socket to the broadcast address.
+ auto rcv2_addr = V4Broadcast();
+ reinterpret_cast<sockaddr_in*>(&rcv2_addr.addr)->sin_port = port;
+ ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast<sockaddr*>(&rcv2_addr.addr),
+ rcv2_addr.addr_len),
+ SyscallSucceedsWithValue(0));
+
+ // Bind the non-receiving socket to the unicast ethernet address.
+ auto norecv_addr = rcv1_addr;
+ reinterpret_cast<sockaddr_in*>(&norecv_addr.addr)->sin_addr =
+ eth_if_sin_addr_;
+ ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr),
+ norecv_addr.addr_len),
+ SyscallSucceedsWithValue(0));
// Broadcast a test message.
- sockaddr_in dst_addr = {};
- dst_addr.sin_family = AF_INET;
- dst_addr.sin_addr.s_addr = htonl(INADDR_BROADCAST);
- dst_addr.sin_port = rcv_addr.sin_port;
+ auto dst_addr = V4Broadcast();
+ reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = port;
constexpr char kTestMsg[] = "hello, world";
EXPECT_THAT(
sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
- reinterpret_cast<struct sockaddr*>(&dst_addr), sizeof(dst_addr)),
+ reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len),
SyscallSucceedsWithValue(sizeof(kTestMsg)));
// Verify that the receiving sockets received the test message.
char buf[sizeof(kTestMsg)] = {};
- EXPECT_THAT(read(rcvr1->get(), buf, sizeof(buf)),
+ EXPECT_THAT(recv(rcvr1->get(), buf, sizeof(buf), 0),
SyscallSucceedsWithValue(sizeof(kTestMsg)));
EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg)));
memset(buf, 0, sizeof(buf));
- EXPECT_THAT(read(rcvr2->get(), buf, sizeof(buf)),
+ EXPECT_THAT(recv(rcvr2->get(), buf, sizeof(buf), 0),
SyscallSucceedsWithValue(sizeof(kTestMsg)));
EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg)));
@@ -190,10 +293,12 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
SyscallFailsWithErrno(EAGAIN));
}
-// Verifies that a UDP broadcast sent via the loopback interface is not received
-// by the sender.
+// Verifies that a UDP broadcast can be sent and then received back on the same
+// socket that is bound to the broadcast address (255.255.255.255).
+// FIXME(b/141938460): This can be combined with the next test
+// (UDPBroadcastSendRecvOnSocketBoundToAny).
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
- UDPBroadcastViaLoopbackFails) {
+ UDPBroadcastSendRecvOnSocketBoundToBroadcast) {
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Enable SO_BROADCAST.
@@ -201,33 +306,73 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
sizeof(kSockOptOn)),
SyscallSucceedsWithValue(0));
- // Bind the sender to the loopback interface.
- sockaddr_in src = {};
- socklen_t src_sz = sizeof(src);
- src.sin_family = AF_INET;
- src.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(
- bind(sender->get(), reinterpret_cast<struct sockaddr*>(&src), src_sz),
- SyscallSucceedsWithValue(0));
+ // Bind the sender to the broadcast address.
+ auto src_addr = V4Broadcast();
+ ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&src_addr.addr),
+ src_addr.addr_len),
+ SyscallSucceedsWithValue(0));
+ socklen_t src_sz = src_addr.addr_len;
ASSERT_THAT(getsockname(sender->get(),
- reinterpret_cast<struct sockaddr*>(&src), &src_sz),
+ reinterpret_cast<sockaddr*>(&src_addr.addr), &src_sz),
SyscallSucceedsWithValue(0));
- ASSERT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK));
+ EXPECT_EQ(src_sz, src_addr.addr_len);
// Send the message.
- sockaddr_in dst = {};
- dst.sin_family = AF_INET;
- dst.sin_addr.s_addr = htonl(INADDR_BROADCAST);
- dst.sin_port = src.sin_port;
+ auto dst_addr = V4Broadcast();
+ reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&src_addr.addr)->sin_port;
constexpr char kTestMsg[] = "hello, world";
- EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
- reinterpret_cast<struct sockaddr*>(&dst), sizeof(dst)),
+ EXPECT_THAT(
+ sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
+ reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
+
+ // Verify that the message was received.
+ char buf[sizeof(kTestMsg)] = {};
+ EXPECT_THAT(RetryEINTR(recv)(sender->get(), buf, sizeof(buf), 0),
SyscallSucceedsWithValue(sizeof(kTestMsg)));
+ EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg)));
+}
+
+// Verifies that a UDP broadcast can be sent and then received back on the same
+// socket that is bound to the ANY address (0.0.0.0).
+// FIXME(b/141938460): This can be combined with the previous test
+// (UDPBroadcastSendRecvOnSocketBoundToBroadcast).
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
+ UDPBroadcastSendRecvOnSocketBoundToAny) {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- // Verify that the message was not received by the sender (loopback).
+ // Enable SO_BROADCAST.
+ ASSERT_THAT(setsockopt(sender->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ // Bind the sender to the ANY address.
+ auto src_addr = V4Any();
+ ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&src_addr.addr),
+ src_addr.addr_len),
+ SyscallSucceedsWithValue(0));
+ socklen_t src_sz = src_addr.addr_len;
+ ASSERT_THAT(getsockname(sender->get(),
+ reinterpret_cast<sockaddr*>(&src_addr.addr), &src_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(src_sz, src_addr.addr_len);
+
+ // Send the message.
+ auto dst_addr = V4Broadcast();
+ reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&src_addr.addr)->sin_port;
+ constexpr char kTestMsg[] = "hello, world";
+ EXPECT_THAT(
+ sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
+ reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
+
+ // Verify that the message was received.
char buf[sizeof(kTestMsg)] = {};
- EXPECT_THAT(RetryEINTR(recv)(sender->get(), buf, sizeof(buf), MSG_DONTWAIT),
- SyscallFailsWithErrno(EAGAIN));
+ EXPECT_THAT(RetryEINTR(recv)(sender->get(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
+ EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg)));
}
// Verifies that a UDP broadcast fails to send on a socket with SO_BROADCAST
@@ -237,15 +382,12 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendBroadcast) {
// Broadcast a test message without having enabled SO_BROADCAST on the sending
// socket.
- sockaddr_in addr = {};
- socklen_t addr_sz = sizeof(addr);
- addr.sin_family = AF_INET;
- addr.sin_port = htons(12345);
- addr.sin_addr.s_addr = htonl(INADDR_BROADCAST);
+ auto addr = V4Broadcast();
+ reinterpret_cast<sockaddr_in*>(&addr.addr)->sin_port = htons(12345);
constexpr char kTestMsg[] = "hello, world";
EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
- reinterpret_cast<struct sockaddr*>(&addr), addr_sz),
+ reinterpret_cast<sockaddr*>(&addr.addr), addr.addr_len),
SyscallFailsWithErrno(EACCES));
}
@@ -274,21 +416,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendUnicastOnUnbound) {
reinterpret_cast<struct sockaddr*>(&addr), addr_sz),
SyscallSucceedsWithValue(sizeof(kTestMsg)));
char buf[sizeof(kTestMsg)] = {};
- ASSERT_THAT(read(rcvr->get(), buf, sizeof(buf)),
+ ASSERT_THAT(recv(rcvr->get(), buf, sizeof(buf), 0),
SyscallSucceedsWithValue(sizeof(kTestMsg)));
}
-constexpr char kMulticastAddress[] = "224.0.2.1";
-
-TestAddress V4Multicast() {
- TestAddress t("V4Multicast");
- t.addr.ss_family = AF_INET;
- t.addr_len = sizeof(sockaddr_in);
- reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr =
- inet_addr(kMulticastAddress);
- return t;
-}
-
// Check that multicast packets won't be delivered to the sending socket with no
// set interface or group membership.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
@@ -609,8 +740,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
}
// Check that two sockets can join the same multicast group at the same time,
-// and both will receive data on it.
-TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwo) {
+// and both will receive data on it when bound to the ANY address.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
+ TestSendMulticastToTwoBoundToAny) {
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
std::unique_ptr<FileDescriptor> receivers[2] = {
ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
@@ -624,8 +756,72 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwo) {
ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT,
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
- // Bind the receiver to the v4 any address to ensure that we can receive the
- // multicast packet.
+ // Bind to ANY to receive multicast packets.
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+ EXPECT_EQ(
+ htonl(INADDR_ANY),
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr);
+ // On the first iteration, save the port we are bound to. On the second
+ // iteration, verify the port is the same as the one from the first
+ // iteration. In other words, both sockets listen on the same port.
+ if (bound_port == 0) {
+ bound_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ } else {
+ EXPECT_EQ(bound_port,
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port);
+ }
+
+ // Register to receive multicast packets.
+ ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
+ &group, sizeof(group)),
+ SyscallSucceeds());
+ }
+
+ // Send a multicast packet to the group and verify both receivers get it.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+ for (auto& receiver : receivers) {
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(
+ RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+ }
+}
+
+// Check that two sockets can join the same multicast group at the same time,
+// and both will receive data on it when bound to the multicast address.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
+ TestSendMulticastToTwoBoundToMulticastAddress) {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ std::unique_ptr<FileDescriptor> receivers[2] = {
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocket())};
+
+ ip_mreq group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ auto receiver_addr = V4Multicast();
+ int bound_port = 0;
+ for (auto& receiver : receivers) {
+ ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
ASSERT_THAT(
bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
receiver_addr.addr_len),
@@ -636,6 +832,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwo) {
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+ EXPECT_EQ(
+ inet_addr(kMulticastAddress),
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr);
// On the first iteration, save the port we are bound to. On the second
// iteration, verify the port is the same as the one from the first
// iteration. In other words, both sockets listen on the same port.
@@ -643,6 +842,83 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwo) {
bound_port =
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
} else {
+ EXPECT_EQ(
+ inet_addr(kMulticastAddress),
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr);
+ EXPECT_EQ(bound_port,
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port);
+ }
+
+ // Register to receive multicast packets.
+ ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
+ &group, sizeof(group)),
+ SyscallSucceeds());
+ }
+
+ // Send a multicast packet to the group and verify both receivers get it.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+ for (auto& receiver : receivers) {
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(
+ RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+ }
+}
+
+// Check that two sockets can join the same multicast group at the same time,
+// and with one bound to the wildcard address and the other bound to the
+// multicast address, both will receive data.
+TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
+ TestSendMulticastToTwoBoundToAnyAndMulticastAddress) {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ std::unique_ptr<FileDescriptor> receivers[2] = {
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocket())};
+
+ ip_mreq group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ // The first receiver binds to the wildcard address.
+ auto receiver_addr = V4Any();
+ int bound_port = 0;
+ for (auto& receiver : receivers) {
+ ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+ // On the first iteration, save the port we are bound to and change the
+ // receiver address from V4Any to V4Multicast so the second receiver binds
+ // to that. On the second iteration, verify the port is the same as the one
+ // from the first iteration but the address is different.
+ if (bound_port == 0) {
+ EXPECT_EQ(
+ htonl(INADDR_ANY),
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr);
+ bound_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ receiver_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port =
+ bound_port;
+ } else {
+ EXPECT_EQ(
+ inet_addr(kMulticastAddress),
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr);
EXPECT_EQ(bound_port,
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port);
}
diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc
index 32fe0d6d1..dd4a11655 100644
--- a/test/syscalls/linux/socket_netlink_route.cc
+++ b/test/syscalls/linux/socket_netlink_route.cc
@@ -539,6 +539,159 @@ TEST(NetlinkRouteTest, GetRouteDump) {
EXPECT_TRUE(dstFound);
}
+// RecvmsgTrunc tests the recvmsg MSG_TRUNC flag with zero length output
+// buffer. MSG_TRUNC with a zero length buffer should consume subsequent
+// messages off the socket.
+TEST(NetlinkRouteTest, RecvmsgTrunc) {
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket());
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct rtgenmsg rgm;
+ };
+
+ constexpr uint32_t kSeq = 12345;
+
+ struct request req;
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETADDR;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = kSeq;
+ req.rgm.rtgen_family = AF_UNSPEC;
+
+ struct iovec iov = {};
+ iov.iov_base = &req;
+ iov.iov_len = sizeof(req);
+
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ iov.iov_base = NULL;
+ iov.iov_len = 0;
+
+ int trunclen, trunclen2;
+
+ // Note: This test assumes at least two messages are returned by the
+ // RTM_GETADDR request. That means at least one RTM_NEWADDR message and one
+ // NLMSG_DONE message. We cannot read all the messages without blocking
+ // because we would need to read the message into a buffer and check the
+ // nlmsg_type for NLMSG_DONE. However, the test depends on reading into a
+ // zero-length buffer.
+
+ // First, call recvmsg with MSG_TRUNC. This will read the full message from
+ // the socket and return its full length. Subsequent calls to recvmsg will
+ // read the next messages from the socket.
+ ASSERT_THAT(trunclen = RetryEINTR(recvmsg)(fd.get(), &msg, MSG_TRUNC),
+ SyscallSucceeds());
+
+ // Message should always be truncated. However, while the destination iov is
+ // zero length, MSG_TRUNC returns the size of the next message so it should
+ // not be zero.
+ ASSERT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
+ ASSERT_NE(trunclen, 0);
+ // Returned length is at least the header and ifaddrmsg.
+ EXPECT_GE(trunclen, sizeof(struct nlmsghdr) + sizeof(struct ifaddrmsg));
+
+ // Reset the msg_flags to make sure that the recvmsg call is setting them
+ // properly.
+ msg.msg_flags = 0;
+
+ // Make a second recvmsg call to get the next message.
+ ASSERT_THAT(trunclen2 = RetryEINTR(recvmsg)(fd.get(), &msg, MSG_TRUNC),
+ SyscallSucceeds());
+ ASSERT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
+ ASSERT_NE(trunclen2, 0);
+
+ // Assert that the received messages are not the same.
+ //
+ // We are calling recvmsg with a zero length buffer so we have no way to
+ // inspect the messages to make sure they are not equal in value. The best
+ // we can do is to compare their lengths.
+ ASSERT_NE(trunclen, trunclen2);
+}
+
+// RecvmsgTruncPeek tests recvmsg with the combination of the MSG_TRUNC and
+// MSG_PEEK flags and a zero length output buffer. This is normally used to
+// read the full length of the next message on the socket without consuming
+// it, so a properly sized buffer can be allocated to store the message. This
+// test tests that scenario.
+TEST(NetlinkRouteTest, RecvmsgTruncPeek) {
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket());
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct rtgenmsg rgm;
+ };
+
+ constexpr uint32_t kSeq = 12345;
+
+ struct request req;
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETADDR;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = kSeq;
+ req.rgm.rtgen_family = AF_UNSPEC;
+
+ struct iovec iov = {};
+ iov.iov_base = &req;
+ iov.iov_len = sizeof(req);
+
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ int type = -1;
+ do {
+ int peeklen;
+ int len;
+
+ iov.iov_base = NULL;
+ iov.iov_len = 0;
+
+ // Call recvmsg with MSG_PEEK and MSG_TRUNC. This will peek at the message
+ // and return its full length.
+ // See: MSG_TRUNC http://man7.org/linux/man-pages/man2/recv.2.html
+ ASSERT_THAT(
+ peeklen = RetryEINTR(recvmsg)(fd.get(), &msg, MSG_PEEK | MSG_TRUNC),
+ SyscallSucceeds());
+
+ // Message should always be truncated.
+ ASSERT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
+ ASSERT_NE(peeklen, 0);
+
+ // Reset the message flags for the next call.
+ msg.msg_flags = 0;
+
+ // Make the actual call to recvmsg to get the actual data. We will use
+ // the length returned from the peek call for the allocated buffer size.
+ std::vector<char> buf(peeklen);
+ iov.iov_base = buf.data();
+ iov.iov_len = buf.size();
+ ASSERT_THAT(len = RetryEINTR(recvmsg)(fd.get(), &msg, 0),
+ SyscallSucceeds());
+
+ // Message should not be truncated since we allocated the correct buffer
+ // size.
+ EXPECT_NE(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
+
+ // MSG_PEEK should have left data on the socket and the subsequent call
+ // should have retrieved the same data. Both calls should have
+ // returned the message's full length so they should be equal.
+ ASSERT_NE(len, 0);
+ ASSERT_EQ(peeklen, len);
+
+ for (struct nlmsghdr* hdr = reinterpret_cast<struct nlmsghdr*>(buf.data());
+ NLMSG_OK(hdr, len); hdr = NLMSG_NEXT(hdr, len)) {
+ type = hdr->nlmsg_type;
+ }
+ } while (type != NLMSG_DONE && type != NLMSG_ERROR);
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc
index 36b6560c2..fcb8f8a88 100644
--- a/test/syscalls/linux/socket_netlink_util.cc
+++ b/test/syscalls/linux/socket_netlink_util.cc
@@ -91,6 +91,13 @@ PosixError NetlinkRequestResponse(
NLMSG_OK(hdr, len); hdr = NLMSG_NEXT(hdr, len)) {
fn(hdr);
type = hdr->nlmsg_type;
+ // Done should include an integer payload for dump_done_errno.
+ // See net/netlink/af_netlink.c:netlink_dump
+ // Some tools like the 'ip' tool check the minimum length of the
+ // NLMSG_DONE message.
+ if (type == NLMSG_DONE) {
+ EXPECT_GE(hdr->nlmsg_len, NLMSG_LENGTH(sizeof(int)));
+ }
}
} while (type != NLMSG_DONE && type != NLMSG_ERROR);
diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc
index 3c716235b..eff7d577e 100644
--- a/test/syscalls/linux/socket_test_util.cc
+++ b/test/syscalls/linux/socket_test_util.cc
@@ -588,8 +588,9 @@ ssize_t SendLargeSendMsg(const std::unique_ptr<SocketPair>& sockets,
return RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0);
}
-PosixErrorOr<int> PortAvailable(int port, AddressFamily family, SocketType type,
- bool reuse_addr) {
+namespace internal {
+PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family,
+ SocketType type, bool reuse_addr) {
if (port < 0) {
return PosixError(EINVAL, "Invalid port");
}
@@ -664,10 +665,7 @@ PosixErrorOr<int> PortAvailable(int port, AddressFamily family, SocketType type,
return available_port;
}
-
-PosixError FreeAvailablePort(int port) {
- return NoError();
-}
+} // namespace internal
PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size) {
struct iovec iov;
diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h
index ae0da2679..70710195c 100644
--- a/test/syscalls/linux/socket_test_util.h
+++ b/test/syscalls/linux/socket_test_util.h
@@ -83,6 +83,8 @@ inline ssize_t SendFd(int fd, void* buf, size_t count, int flags) {
count);
}
+PosixErrorOr<struct sockaddr_un> UniqueUnixAddr(bool abstract, int domain);
+
// A Creator<T> is a function that attempts to create and return a new T. (This
// is copy/pasted from cloud/gvisor/api/sandbox_util.h and is just duplicated
// here for clarity.)
@@ -492,6 +494,11 @@ uint16_t UDPChecksum(struct iphdr iphdr, struct udphdr udphdr,
uint16_t ICMPChecksum(struct icmphdr icmphdr, const char* payload,
ssize_t payload_len);
+namespace internal {
+PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family,
+ SocketType type, bool reuse_addr);
+} // namespace internal
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_test_util_impl.cc b/test/syscalls/linux/socket_test_util_impl.cc
new file mode 100644
index 000000000..ef661a0e3
--- /dev/null
+++ b/test/syscalls/linux/socket_test_util_impl.cc
@@ -0,0 +1,28 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+PosixErrorOr<int> PortAvailable(int port, AddressFamily family, SocketType type,
+ bool reuse_addr) {
+ return internal::TryPortAvailable(port, family, type, reuse_addr);
+}
+
+PosixError FreeAvailablePort(int port) { return NoError(); }
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc
index be661c2b6..8f38ed92f 100644
--- a/test/syscalls/linux/socket_unix_stream.cc
+++ b/test/syscalls/linux/socket_unix_stream.cc
@@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <poll.h>
#include <stdio.h>
#include <sys/un.h>
-#include <poll.h>
+
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "test/syscalls/linux/socket_test_util.h"
@@ -52,9 +53,9 @@ TEST_P(StreamUnixSocketPairTest, RecvmsgOneSideClosed) {
struct timeval tv {
.tv_sec = 0, .tv_usec = 10
};
- EXPECT_THAT(
- setsockopt(sockets->second_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
- SyscallSucceeds());
+ EXPECT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv,
+ sizeof(tv)),
+ SyscallSucceeds());
ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
@@ -78,8 +79,7 @@ TEST_P(StreamUnixSocketPairTest, ReadOneSideClosedWithUnreadData) {
ASSERT_THAT(RetryEINTR(write)(sockets->second_fd(), buf, sizeof(buf)),
SyscallSucceedsWithValue(sizeof(buf)));
- ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR),
- SyscallSucceeds());
+ ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds());
ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)),
SyscallSucceedsWithValue(0));
diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc
index e25f264f6..85232cb1f 100644
--- a/test/syscalls/linux/splice.cc
+++ b/test/syscalls/linux/splice.cc
@@ -14,12 +14,16 @@
#include <fcntl.h>
#include <sys/eventfd.h>
+#include <sys/resource.h>
#include <sys/sendfile.h>
+#include <sys/time.h>
#include <unistd.h>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/strings/string_view.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
#include "test/util/file_descriptor.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
@@ -36,23 +40,23 @@ TEST(SpliceTest, TwoRegularFiles) {
const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
// Open the input file as read only.
- const FileDescriptor inf =
+ const FileDescriptor in_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
// Open the output file as write only.
- const FileDescriptor outf =
+ const FileDescriptor out_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY));
// Verify that it is rejected as expected; regardless of offsets.
loff_t in_offset = 0;
loff_t out_offset = 0;
- EXPECT_THAT(splice(inf.get(), &in_offset, outf.get(), &out_offset, 1, 0),
+ EXPECT_THAT(splice(in_fd.get(), &in_offset, out_fd.get(), &out_offset, 1, 0),
SyscallFailsWithErrno(EINVAL));
- EXPECT_THAT(splice(inf.get(), nullptr, outf.get(), &out_offset, 1, 0),
+ EXPECT_THAT(splice(in_fd.get(), nullptr, out_fd.get(), &out_offset, 1, 0),
SyscallFailsWithErrno(EINVAL));
- EXPECT_THAT(splice(inf.get(), &in_offset, outf.get(), nullptr, 1, 0),
+ EXPECT_THAT(splice(in_fd.get(), &in_offset, out_fd.get(), nullptr, 1, 0),
SyscallFailsWithErrno(EINVAL));
- EXPECT_THAT(splice(inf.get(), nullptr, outf.get(), nullptr, 1, 0),
+ EXPECT_THAT(splice(in_fd.get(), nullptr, out_fd.get(), nullptr, 1, 0),
SyscallFailsWithErrno(EINVAL));
}
@@ -75,8 +79,6 @@ TEST(SpliceTest, SamePipe) {
}
TEST(TeeTest, SamePipe) {
- SKIP_IF(IsRunningOnGvisor());
-
// Create a new pipe.
int fds[2];
ASSERT_THAT(pipe(fds), SyscallSucceeds());
@@ -95,11 +97,9 @@ TEST(TeeTest, SamePipe) {
}
TEST(TeeTest, RegularFile) {
- SKIP_IF(IsRunningOnGvisor());
-
// Open some file.
const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
- const FileDescriptor inf =
+ const FileDescriptor in_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
// Create a new pipe.
@@ -109,9 +109,9 @@ TEST(TeeTest, RegularFile) {
const FileDescriptor wfd(fds[1]);
// Attempt to tee from the file.
- EXPECT_THAT(tee(inf.get(), wfd.get(), kPageSize, 0),
+ EXPECT_THAT(tee(in_fd.get(), wfd.get(), kPageSize, 0),
SyscallFailsWithErrno(EINVAL));
- EXPECT_THAT(tee(rfd.get(), inf.get(), kPageSize, 0),
+ EXPECT_THAT(tee(rfd.get(), in_fd.get(), kPageSize, 0),
SyscallFailsWithErrno(EINVAL));
}
@@ -142,7 +142,7 @@ TEST(SpliceTest, FromEventFD) {
constexpr uint64_t kEventFDValue = 1;
int efd;
ASSERT_THAT(efd = eventfd(kEventFDValue, 0), SyscallSucceeds());
- const FileDescriptor inf(efd);
+ const FileDescriptor in_fd(efd);
// Create a new pipe.
int fds[2];
@@ -152,7 +152,7 @@ TEST(SpliceTest, FromEventFD) {
// Splice 8-byte eventfd value to pipe.
constexpr int kEventFDSize = 8;
- EXPECT_THAT(splice(inf.get(), nullptr, wfd.get(), nullptr, kEventFDSize, 0),
+ EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, kEventFDSize, 0),
SyscallSucceedsWithValue(kEventFDSize));
// Contents should be equal.
@@ -166,7 +166,7 @@ TEST(SpliceTest, FromEventFD) {
TEST(SpliceTest, FromEventFDOffset) {
int efd;
ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds());
- const FileDescriptor inf(efd);
+ const FileDescriptor in_fd(efd);
// Create a new pipe.
int fds[2];
@@ -179,7 +179,7 @@ TEST(SpliceTest, FromEventFDOffset) {
// This is not allowed because eventfd doesn't support pread.
constexpr int kEventFDSize = 8;
loff_t in_off = 0;
- EXPECT_THAT(splice(inf.get(), &in_off, wfd.get(), nullptr, kEventFDSize, 0),
+ EXPECT_THAT(splice(in_fd.get(), &in_off, wfd.get(), nullptr, kEventFDSize, 0),
SyscallFailsWithErrno(EINVAL));
}
@@ -200,28 +200,29 @@ TEST(SpliceTest, ToEventFDOffset) {
int efd;
ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds());
- const FileDescriptor outf(efd);
+ const FileDescriptor out_fd(efd);
// Attempt to splice 8-byte eventfd value to pipe with offset.
//
// This is not allowed because eventfd doesn't support pwrite.
loff_t out_off = 0;
- EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), &out_off, kEventFDSize, 0),
- SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(
+ splice(rfd.get(), nullptr, out_fd.get(), &out_off, kEventFDSize, 0),
+ SyscallFailsWithErrno(EINVAL));
}
TEST(SpliceTest, ToPipe) {
// Open the input file.
const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
- const FileDescriptor inf =
+ const FileDescriptor in_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
// Fill with some random data.
std::vector<char> buf(kPageSize);
RandomizeBuffer(buf.data(), buf.size());
- ASSERT_THAT(write(inf.get(), buf.data(), buf.size()),
+ ASSERT_THAT(write(in_fd.get(), buf.data(), buf.size()),
SyscallSucceedsWithValue(kPageSize));
- ASSERT_THAT(lseek(inf.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
+ ASSERT_THAT(lseek(in_fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
// Create a new pipe.
int fds[2];
@@ -230,7 +231,7 @@ TEST(SpliceTest, ToPipe) {
const FileDescriptor wfd(fds[1]);
// Splice to the pipe.
- EXPECT_THAT(splice(inf.get(), nullptr, wfd.get(), nullptr, kPageSize, 0),
+ EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, kPageSize, 0),
SyscallSucceedsWithValue(kPageSize));
// Contents should be equal.
@@ -243,13 +244,13 @@ TEST(SpliceTest, ToPipe) {
TEST(SpliceTest, ToPipeOffset) {
// Open the input file.
const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
- const FileDescriptor inf =
+ const FileDescriptor in_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
// Fill with some random data.
std::vector<char> buf(kPageSize);
RandomizeBuffer(buf.data(), buf.size());
- ASSERT_THAT(write(inf.get(), buf.data(), buf.size()),
+ ASSERT_THAT(write(in_fd.get(), buf.data(), buf.size()),
SyscallSucceedsWithValue(kPageSize));
// Create a new pipe.
@@ -261,7 +262,7 @@ TEST(SpliceTest, ToPipeOffset) {
// Splice to the pipe.
loff_t in_offset = kPageSize / 2;
EXPECT_THAT(
- splice(inf.get(), &in_offset, wfd.get(), nullptr, kPageSize / 2, 0),
+ splice(in_fd.get(), &in_offset, wfd.get(), nullptr, kPageSize / 2, 0),
SyscallSucceedsWithValue(kPageSize / 2));
// Contents should be equal to only the second part.
@@ -286,22 +287,22 @@ TEST(SpliceTest, FromPipe) {
// Open the input file.
const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
- const FileDescriptor outf =
+ const FileDescriptor out_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR));
// Splice to the output file.
- EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), nullptr, kPageSize, 0),
+ EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), nullptr, kPageSize, 0),
SyscallSucceedsWithValue(kPageSize));
// The offset of the output should be equal to kPageSize. We assert that and
// reset to zero so that we can read the contents and ensure they match.
- EXPECT_THAT(lseek(outf.get(), 0, SEEK_CUR),
+ EXPECT_THAT(lseek(out_fd.get(), 0, SEEK_CUR),
SyscallSucceedsWithValue(kPageSize));
- ASSERT_THAT(lseek(outf.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
+ ASSERT_THAT(lseek(out_fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
// Contents should be equal.
std::vector<char> rbuf(kPageSize);
- ASSERT_THAT(read(outf.get(), rbuf.data(), rbuf.size()),
+ ASSERT_THAT(read(out_fd.get(), rbuf.data(), rbuf.size()),
SyscallSucceedsWithValue(kPageSize));
EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0);
}
@@ -321,18 +322,19 @@ TEST(SpliceTest, FromPipeOffset) {
// Open the input file.
const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
- const FileDescriptor outf =
+ const FileDescriptor out_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR));
// Splice to the output file.
loff_t out_offset = kPageSize / 2;
- EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), &out_offset, kPageSize, 0),
- SyscallSucceedsWithValue(kPageSize));
+ EXPECT_THAT(
+ splice(rfd.get(), nullptr, out_fd.get(), &out_offset, kPageSize, 0),
+ SyscallSucceedsWithValue(kPageSize));
// Content should reflect the splice. We write to a specific offset in the
// file, so the internals should now be allocated sparsely.
std::vector<char> rbuf(kPageSize);
- ASSERT_THAT(read(outf.get(), rbuf.data(), rbuf.size()),
+ ASSERT_THAT(read(out_fd.get(), rbuf.data(), rbuf.size()),
SyscallSucceedsWithValue(kPageSize));
std::vector<char> zbuf(kPageSize / 2);
memset(zbuf.data(), 0, zbuf.size());
@@ -404,8 +406,6 @@ TEST(SpliceTest, Blocking) {
}
TEST(TeeTest, Blocking) {
- SKIP_IF(IsRunningOnGvisor());
-
// Create two new pipes.
int first[2], second[2];
ASSERT_THAT(pipe(first), SyscallSucceeds());
@@ -440,6 +440,49 @@ TEST(TeeTest, Blocking) {
EXPECT_EQ(memcmp(rbuf.data(), buf.data(), kPageSize), 0);
}
+TEST(TeeTest, BlockingWrite) {
+ // Create two new pipes.
+ int first[2], second[2];
+ ASSERT_THAT(pipe(first), SyscallSucceeds());
+ const FileDescriptor rfd1(first[0]);
+ const FileDescriptor wfd1(first[1]);
+ ASSERT_THAT(pipe(second), SyscallSucceeds());
+ const FileDescriptor rfd2(second[0]);
+ const FileDescriptor wfd2(second[1]);
+
+ // Make some data available to be read.
+ std::vector<char> buf1(kPageSize);
+ RandomizeBuffer(buf1.data(), buf1.size());
+ ASSERT_THAT(write(wfd1.get(), buf1.data(), buf1.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Fill up the write pipe's buffer.
+ int pipe_size = -1;
+ ASSERT_THAT(pipe_size = fcntl(wfd2.get(), F_GETPIPE_SZ), SyscallSucceeds());
+ std::vector<char> buf2(pipe_size);
+ ASSERT_THAT(write(wfd2.get(), buf2.data(), buf2.size()),
+ SyscallSucceedsWithValue(pipe_size));
+
+ ScopedThread t([&]() {
+ absl::SleepFor(absl::Milliseconds(100));
+ ASSERT_THAT(read(rfd2.get(), buf2.data(), buf2.size()),
+ SyscallSucceedsWithValue(pipe_size));
+ });
+
+ // Attempt a tee immediately; it should block.
+ EXPECT_THAT(tee(rfd1.get(), wfd2.get(), kPageSize, 0),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Thread should be joinable.
+ t.Join();
+
+ // Content should reflect the tee.
+ std::vector<char> rbuf(kPageSize);
+ ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ EXPECT_EQ(memcmp(rbuf.data(), buf1.data(), kPageSize), 0);
+}
+
TEST(SpliceTest, NonBlocking) {
// Create two new pipes.
int first[2], second[2];
@@ -457,8 +500,6 @@ TEST(SpliceTest, NonBlocking) {
}
TEST(TeeTest, NonBlocking) {
- SKIP_IF(IsRunningOnGvisor());
-
// Create two new pipes.
int first[2], second[2];
ASSERT_THAT(pipe(first), SyscallSucceeds());
@@ -473,6 +514,79 @@ TEST(TeeTest, NonBlocking) {
SyscallFailsWithErrno(EAGAIN));
}
+TEST(TeeTest, MultiPage) {
+ // Create two new pipes.
+ int first[2], second[2];
+ ASSERT_THAT(pipe(first), SyscallSucceeds());
+ const FileDescriptor rfd1(first[0]);
+ const FileDescriptor wfd1(first[1]);
+ ASSERT_THAT(pipe(second), SyscallSucceeds());
+ const FileDescriptor rfd2(second[0]);
+ const FileDescriptor wfd2(second[1]);
+
+ // Make some data available to be read.
+ std::vector<char> wbuf(8 * kPageSize);
+ RandomizeBuffer(wbuf.data(), wbuf.size());
+ ASSERT_THAT(write(wfd1.get(), wbuf.data(), wbuf.size()),
+ SyscallSucceedsWithValue(wbuf.size()));
+
+ // Attempt a tee immediately; it should complete.
+ EXPECT_THAT(tee(rfd1.get(), wfd2.get(), wbuf.size(), 0),
+ SyscallSucceedsWithValue(wbuf.size()));
+
+ // Content should reflect the tee.
+ std::vector<char> rbuf(wbuf.size());
+ ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(rbuf.size()));
+ EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), rbuf.size()), 0);
+ ASSERT_THAT(read(rfd1.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(rbuf.size()));
+ EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), rbuf.size()), 0);
+}
+
+TEST(SpliceTest, FromPipeMaxFileSize) {
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill with some random data.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Open the input file.
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor out_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR));
+
+ EXPECT_THAT(ftruncate(out_fd.get(), 13 << 20), SyscallSucceeds());
+ EXPECT_THAT(lseek(out_fd.get(), 0, SEEK_END),
+ SyscallSucceedsWithValue(13 << 20));
+
+ // Set our file size limit.
+ sigset_t set;
+ sigemptyset(&set);
+ sigaddset(&set, SIGXFSZ);
+ TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
+ rlimit rlim = {};
+ rlim.rlim_cur = rlim.rlim_max = (13 << 20);
+ EXPECT_THAT(setrlimit(RLIMIT_FSIZE, &rlim), SyscallSucceeds());
+
+ // Splice to the output file.
+ EXPECT_THAT(
+ splice(rfd.get(), nullptr, out_fd.get(), nullptr, 3 * kPageSize, 0),
+ SyscallFailsWithErrno(EFBIG));
+
+ // Contents should be equal.
+ std::vector<char> rbuf(kPageSize);
+ ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0);
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/sticky.cc b/test/syscalls/linux/sticky.cc
index 59fb5dfe6..7e73325bf 100644
--- a/test/syscalls/linux/sticky.cc
+++ b/test/syscalls/linux/sticky.cc
@@ -17,9 +17,11 @@
#include <sys/prctl.h>
#include <sys/types.h>
#include <unistd.h>
+
#include <vector>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
@@ -27,8 +29,8 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_int32(scratch_uid, 65534, "first scratch UID");
-DEFINE_int32(scratch_gid, 65534, "first scratch GID");
+ABSL_FLAG(int32_t, scratch_uid, 65534, "first scratch UID");
+ABSL_FLAG(int32_t, scratch_gid, 65534, "first scratch GID");
namespace gvisor {
namespace testing {
@@ -52,10 +54,12 @@ TEST(StickyTest, StickyBitPermDenied) {
}
// Change EUID and EGID.
- EXPECT_THAT(syscall(SYS_setresgid, -1, FLAGS_scratch_gid, -1),
- SyscallSucceeds());
- EXPECT_THAT(syscall(SYS_setresuid, -1, FLAGS_scratch_uid, -1),
- SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1),
+ SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid), -1),
+ SyscallSucceeds());
EXPECT_THAT(rmdir(path.c_str()), SyscallFailsWithErrno(EPERM));
});
@@ -78,8 +82,9 @@ TEST(StickyTest, StickyBitSameUID) {
}
// Change EGID.
- EXPECT_THAT(syscall(SYS_setresgid, -1, FLAGS_scratch_gid, -1),
- SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1),
+ SyscallSucceeds());
// We still have the same EUID.
EXPECT_THAT(rmdir(path.c_str()), SyscallSucceeds());
@@ -101,10 +106,12 @@ TEST(StickyTest, StickyBitCapFOWNER) {
EXPECT_THAT(prctl(PR_SET_KEEPCAPS, 1, 0, 0, 0), SyscallSucceeds());
// Change EUID and EGID.
- EXPECT_THAT(syscall(SYS_setresgid, -1, FLAGS_scratch_gid, -1),
- SyscallSucceeds());
- EXPECT_THAT(syscall(SYS_setresuid, -1, FLAGS_scratch_uid, -1),
- SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1),
+ SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid), -1),
+ SyscallSucceeds());
EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, true));
EXPECT_THAT(rmdir(path.c_str()), SyscallSucceeds());
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index 8f4d3f386..bfa031bce 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -579,7 +579,7 @@ TEST_P(TcpSocketTest, TcpInq) {
if (size == sizeof(buf)) {
break;
}
- usleep(10000);
+ absl::SleepFor(absl::Milliseconds(10));
}
struct msghdr msg = {};
@@ -610,6 +610,25 @@ TEST_P(TcpSocketTest, TcpInq) {
}
}
+TEST_P(TcpSocketTest, Tiocinq) {
+ char buf[1024];
+ size_t size = sizeof(buf);
+ ASSERT_THAT(RetryEINTR(write)(s_, buf, size), SyscallSucceedsWithValue(size));
+
+ uint32_t seed = time(nullptr);
+ const size_t max_chunk = size / 10;
+ while (size > 0) {
+ size_t chunk = (rand_r(&seed) % max_chunk) + 1;
+ ssize_t read = RetryEINTR(recvfrom)(t_, buf, chunk, 0, nullptr, nullptr);
+ ASSERT_THAT(read, SyscallSucceeds());
+ size -= read;
+
+ int inq = 0;
+ ASSERT_THAT(ioctl(t_, TIOCINQ, &inq), SyscallSucceeds());
+ ASSERT_EQ(inq, size);
+ }
+}
+
TEST_P(TcpSocketTest, TcpSCMPriority) {
char buf[1024];
ASSERT_THAT(RetryEINTR(write)(s_, buf, sizeof(buf)),
diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc
index fd42e81e1..3db18d7ac 100644
--- a/test/syscalls/linux/timers.cc
+++ b/test/syscalls/linux/timers.cc
@@ -23,6 +23,7 @@
#include <atomic>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/util/cleanup.h"
@@ -33,8 +34,8 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_bool(timers_test_sleep, false,
- "If true, sleep forever instead of running tests.");
+ABSL_FLAG(bool, timers_test_sleep, false,
+ "If true, sleep forever instead of running tests.");
using ::testing::_;
using ::testing::AnyOf;
@@ -635,7 +636,7 @@ TEST(IntervalTimerTest, IgnoredSignalCountsAsOverrun) {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (FLAGS_timers_test_sleep) {
+ if (absl::GetFlag(FLAGS_timers_test_sleep)) {
while (true) {
absl::SleepFor(absl::Seconds(10));
}
diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc
index bf1ca8679..6218fbce1 100644
--- a/test/syscalls/linux/uidgid.cc
+++ b/test/syscalls/linux/uidgid.cc
@@ -18,17 +18,19 @@
#include <unistd.h>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "test/util/capability_util.h"
#include "test/util/posix_error.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
+#include "test/util/uid_util.h"
-DEFINE_int32(scratch_uid1, 65534, "first scratch UID");
-DEFINE_int32(scratch_uid2, 65533, "second scratch UID");
-DEFINE_int32(scratch_gid1, 65534, "first scratch GID");
-DEFINE_int32(scratch_gid2, 65533, "second scratch GID");
+ABSL_FLAG(int32_t, scratch_uid1, 65534, "first scratch UID");
+ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID");
+ABSL_FLAG(int32_t, scratch_gid1, 65534, "first scratch GID");
+ABSL_FLAG(int32_t, scratch_gid2, 65533, "second scratch GID");
using ::testing::UnorderedElementsAreArray;
@@ -67,30 +69,6 @@ TEST(UidGidTest, Getgroups) {
// here; see the setgroups test below.
}
-// If the caller's real/effective/saved user/group IDs are all 0, IsRoot returns
-// true. Otherwise IsRoot logs an explanatory message and returns false.
-PosixErrorOr<bool> IsRoot() {
- uid_t ruid, euid, suid;
- int rc = getresuid(&ruid, &euid, &suid);
- MaybeSave();
- if (rc < 0) {
- return PosixError(errno, "getresuid");
- }
- if (ruid != 0 || euid != 0 || suid != 0) {
- return false;
- }
- gid_t rgid, egid, sgid;
- rc = getresgid(&rgid, &egid, &sgid);
- MaybeSave();
- if (rc < 0) {
- return PosixError(errno, "getresgid");
- }
- if (rgid != 0 || egid != 0 || sgid != 0) {
- return false;
- }
- return true;
-}
-
// Checks that the calling process' real/effective/saved user IDs are
// ruid/euid/suid respectively.
PosixError CheckUIDs(uid_t ruid, uid_t euid, uid_t suid) {
@@ -146,7 +124,7 @@ TEST(UidGidRootTest, Setuid) {
// real UID.
EXPECT_THAT(syscall(SYS_setuid, -1), SyscallFailsWithErrno(EINVAL));
- const uid_t uid = FLAGS_scratch_uid1;
+ const uid_t uid = absl::GetFlag(FLAGS_scratch_uid1);
EXPECT_THAT(syscall(SYS_setuid, uid), SyscallSucceeds());
// "If the effective UID of the caller is root (more precisely: if the
// caller has the CAP_SETUID capability), the real UID and saved set-user-ID
@@ -160,7 +138,7 @@ TEST(UidGidRootTest, Setgid) {
EXPECT_THAT(setgid(-1), SyscallFailsWithErrno(EINVAL));
- const gid_t gid = FLAGS_scratch_gid1;
+ const gid_t gid = absl::GetFlag(FLAGS_scratch_gid1);
ASSERT_THAT(setgid(gid), SyscallSucceeds());
EXPECT_NO_ERRNO(CheckGIDs(gid, gid, gid));
}
@@ -168,7 +146,7 @@ TEST(UidGidRootTest, Setgid) {
TEST(UidGidRootTest, SetgidNotFromThreadGroupLeader) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot()));
- const gid_t gid = FLAGS_scratch_gid1;
+ const gid_t gid = absl::GetFlag(FLAGS_scratch_gid1);
// NOTE(b/64676707): Do setgid in a separate thread so that we can test if
// info.si_pid is set correctly.
ScopedThread([gid] { ASSERT_THAT(setgid(gid), SyscallSucceeds()); });
@@ -189,8 +167,8 @@ TEST(UidGidRootTest, Setreuid) {
// cannot be opened by the `uid` set below after the test. After calling
// setuid(non-zero-UID), there is no way to get root privileges back.
ScopedThread([&] {
- const uid_t ruid = FLAGS_scratch_uid1;
- const uid_t euid = FLAGS_scratch_uid2;
+ const uid_t ruid = absl::GetFlag(FLAGS_scratch_uid1);
+ const uid_t euid = absl::GetFlag(FLAGS_scratch_uid2);
// Use syscall instead of glibc setuid wrapper because we want this setuid
// call to only apply to this task. posix threads, however, require that all
@@ -211,8 +189,8 @@ TEST(UidGidRootTest, Setregid) {
EXPECT_THAT(setregid(-1, -1), SyscallSucceeds());
EXPECT_NO_ERRNO(CheckGIDs(0, 0, 0));
- const gid_t rgid = FLAGS_scratch_gid1;
- const gid_t egid = FLAGS_scratch_gid2;
+ const gid_t rgid = absl::GetFlag(FLAGS_scratch_gid1);
+ const gid_t egid = absl::GetFlag(FLAGS_scratch_gid2);
ASSERT_THAT(setregid(rgid, egid), SyscallSucceeds());
EXPECT_NO_ERRNO(CheckGIDs(rgid, egid, egid));
}
diff --git a/test/syscalls/linux/uname.cc b/test/syscalls/linux/uname.cc
index 0a5d91017..d8824b171 100644
--- a/test/syscalls/linux/uname.cc
+++ b/test/syscalls/linux/uname.cc
@@ -41,6 +41,19 @@ TEST(UnameTest, Sanity) {
TEST(UnameTest, SetNames) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+ char hostname[65];
+ ASSERT_THAT(sethostname("0123456789", 3), SyscallSucceeds());
+ EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
+ EXPECT_EQ(absl::string_view(hostname), "012");
+
+ ASSERT_THAT(sethostname("0123456789\0xxx", 11), SyscallSucceeds());
+ EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
+ EXPECT_EQ(absl::string_view(hostname), "0123456789");
+
+ ASSERT_THAT(sethostname("0123456789\0xxx", 12), SyscallSucceeds());
+ EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
+ EXPECT_EQ(absl::string_view(hostname), "0123456789");
+
constexpr char kHostname[] = "wubbalubba";
ASSERT_THAT(sethostname(kHostname, sizeof(kHostname)), SyscallSucceeds());
@@ -54,7 +67,6 @@ TEST(UnameTest, SetNames) {
EXPECT_EQ(absl::string_view(buf.domainname), kDomainname);
// These should just be glibc wrappers that also call uname(2).
- char hostname[65];
EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
EXPECT_EQ(absl::string_view(hostname), kHostname);
diff --git a/test/syscalls/linux/unlink.cc b/test/syscalls/linux/unlink.cc
index b6f65e027..2040375c9 100644
--- a/test/syscalls/linux/unlink.cc
+++ b/test/syscalls/linux/unlink.cc
@@ -123,6 +123,8 @@ TEST(UnlinkTest, AtBad) {
SyscallSucceeds());
EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile", AT_REMOVEDIR),
SyscallFailsWithErrno(ENOTDIR));
+ EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile/", 0),
+ SyscallFailsWithErrno(ENOTDIR));
ASSERT_THAT(close(fd), SyscallSucceeds());
EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile", 0), SyscallSucceeds());
diff --git a/test/syscalls/linux/vfork.cc b/test/syscalls/linux/vfork.cc
index f67b06f37..0aaba482d 100644
--- a/test/syscalls/linux/vfork.cc
+++ b/test/syscalls/linux/vfork.cc
@@ -22,14 +22,15 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/time/time.h"
#include "test/util/logging.h"
#include "test/util/multiprocess_util.h"
#include "test/util/test_util.h"
#include "test/util/time_util.h"
-DEFINE_bool(vfork_test_child, false,
- "If true, run the VforkTest child workload.");
+ABSL_FLAG(bool, vfork_test_child, false,
+ "If true, run the VforkTest child workload.");
namespace gvisor {
namespace testing {
@@ -186,7 +187,7 @@ int RunChild() {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (FLAGS_vfork_test_child) {
+ if (absl::GetFlag(FLAGS_vfork_test_child)) {
return gvisor::testing::RunChild();
}
diff --git a/test/syscalls/syscall_test_runner.go b/test/syscalls/syscall_test_runner.go
index 32408f021..c1e9ce22c 100644
--- a/test/syscalls/syscall_test_runner.go
+++ b/test/syscalls/syscall_test_runner.go
@@ -20,12 +20,10 @@ import (
"flag"
"fmt"
"io/ioutil"
- "math"
"os"
"os/exec"
"os/signal"
"path/filepath"
- "strconv"
"strings"
"syscall"
"testing"
@@ -35,7 +33,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/specutils"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
"gvisor.dev/gvisor/test/syscalls/gtest"
)
@@ -358,32 +356,14 @@ func main() {
fatalf("ParseTestCases(%q) failed: %v", testBin, err)
}
- // If sharding, then get the subset of tests to run based on the shard index.
- if indexStr, totalStr := os.Getenv("TEST_SHARD_INDEX"), os.Getenv("TEST_TOTAL_SHARDS"); indexStr != "" && totalStr != "" {
- // Parse index and total to ints.
- index, err := strconv.Atoi(indexStr)
- if err != nil {
- fatalf("invalid TEST_SHARD_INDEX %q: %v", indexStr, err)
- }
- total, err := strconv.Atoi(totalStr)
- if err != nil {
- fatalf("invalid TEST_TOTAL_SHARDS %q: %v", totalStr, err)
- }
- // Calculate subslice of tests to run.
- shardSize := int(math.Ceil(float64(len(testCases)) / float64(total)))
- begin := index * shardSize
- // Set end as begin of next subslice.
- end := ((index + 1) * shardSize)
- if begin > len(testCases) {
- // Nothing to run.
- return
- }
- if end > len(testCases) {
- end = len(testCases)
- }
- testCases = testCases[begin:end]
+ // Get subset of tests corresponding to shard.
+ begin, end, err := testutil.TestBoundsForShard(len(testCases))
+ if err != nil {
+ fatalf("TestsForShard() failed: %v", err)
}
+ testCases = testCases[begin:end]
+ // Run the tests.
var tests []testing.InternalTest
for _, tc := range testCases {
// Capture tc.
diff --git a/test/util/BUILD b/test/util/BUILD
index c124cef34..5d2a9cc2c 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -1,3 +1,6 @@
+load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test")
+load("//test/syscalls:build_defs.bzl", "select_for_linux")
+
package(
default_visibility = ["//:sandbox"],
licenses = ["notice"],
@@ -139,7 +142,11 @@ cc_library(
cc_library(
name = "save_util",
testonly = 1,
- srcs = ["save_util.cc"],
+ srcs = ["save_util.cc"] +
+ select_for_linux(
+ ["save_util_linux.cc"],
+ ["save_util_other.cc"],
+ ),
hdrs = ["save_util.h"],
)
@@ -232,8 +239,9 @@ cc_library(
":logging",
":posix_error",
":save_util",
- "@com_github_gflags_gflags//:gflags",
"@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
@@ -316,3 +324,14 @@ cc_library(
":test_util",
],
)
+
+cc_library(
+ name = "uid_util",
+ testonly = 1,
+ srcs = ["uid_util.cc"],
+ hdrs = ["uid_util.h"],
+ deps = [
+ ":posix_error",
+ ":save_util",
+ ],
+)
diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc
index ae49725a0..f7d231b14 100644
--- a/test/util/fs_util.cc
+++ b/test/util/fs_util.cc
@@ -105,6 +105,15 @@ PosixErrorOr<struct stat> Stat(absl::string_view path) {
return stat_buf;
}
+PosixErrorOr<struct stat> Fstat(int fd) {
+ struct stat stat_buf;
+ int res = fstat(fd, &stat_buf);
+ if (res < 0) {
+ return PosixError(errno, absl::StrCat("fstat ", fd));
+ }
+ return stat_buf;
+}
+
PosixErrorOr<bool> Exists(absl::string_view path) {
struct stat stat_buf;
int res = stat(std::string(path).c_str(), &stat_buf);
diff --git a/test/util/fs_util.h b/test/util/fs_util.h
index 3969f8309..e5b555891 100644
--- a/test/util/fs_util.h
+++ b/test/util/fs_util.h
@@ -35,6 +35,9 @@ PosixErrorOr<bool> Exists(absl::string_view path);
// Returns a stat structure for the given path or an error.
PosixErrorOr<struct stat> Stat(absl::string_view path);
+// Returns a stat struct for the given fd.
+PosixErrorOr<struct stat> Fstat(int fd);
+
// Deletes the file or directory at path or returns an error.
PosixError Delete(absl::string_view path);
diff --git a/test/util/memory_util.h b/test/util/memory_util.h
index 190c469b5..e189b73e8 100644
--- a/test/util/memory_util.h
+++ b/test/util/memory_util.h
@@ -118,6 +118,18 @@ inline PosixErrorOr<Mapping> MmapAnon(size_t length, int prot, int flags) {
return Mmap(nullptr, length, prot, flags | MAP_ANONYMOUS, -1, 0);
}
+// Wrapper for mremap that returns a PosixErrorOr<>, since the return type of
+// void* isn't directly compatible with SyscallSucceeds.
+inline PosixErrorOr<void*> Mremap(void* old_address, size_t old_size,
+ size_t new_size, int flags,
+ void* new_address) {
+ void* rv = mremap(old_address, old_size, new_size, flags, new_address);
+ if (rv == MAP_FAILED) {
+ return PosixError(errno, "mremap failed");
+ }
+ return rv;
+}
+
// Returns true if the page containing addr is mapped.
inline bool IsMapped(uintptr_t addr) {
int const rv = msync(reinterpret_cast<void*>(addr & ~(kPageSize - 1)),
diff --git a/test/util/proc_util.cc b/test/util/proc_util.cc
index 75b24da37..34d636ba9 100644
--- a/test/util/proc_util.cc
+++ b/test/util/proc_util.cc
@@ -88,7 +88,7 @@ PosixErrorOr<std::vector<ProcMapsEntry>> ParseProcMaps(
std::vector<ProcMapsEntry> entries;
auto lines = absl::StrSplit(contents, '\n', absl::SkipEmpty());
for (const auto& l : lines) {
- std::cout << "line: " << l;
+ std::cout << "line: " << l << std::endl;
ASSIGN_OR_RETURN_ERRNO(auto entry, ParseProcMapsLine(l));
entries.push_back(entry);
}
diff --git a/test/util/save_util.cc b/test/util/save_util.cc
index 05f52b80d..384d626f0 100644
--- a/test/util/save_util.cc
+++ b/test/util/save_util.cc
@@ -16,8 +16,8 @@
#include <stddef.h>
#include <stdlib.h>
-#include <sys/syscall.h>
#include <unistd.h>
+
#include <atomic>
#include <cerrno>
@@ -61,13 +61,11 @@ void DisableSave::reset() {
}
}
-void MaybeSave() {
- if (CooperativeSaveEnabled() && !save_disable.load()) {
- int orig_errno = errno;
- syscall(SYS_create_module, nullptr, 0);
- errno = orig_errno;
- }
+namespace internal {
+bool ShouldSave() {
+ return CooperativeSaveEnabled() && (save_disable.load() == 0);
}
+} // namespace internal
} // namespace testing
} // namespace gvisor
diff --git a/test/util/save_util.h b/test/util/save_util.h
index 90460701e..bddad6120 100644
--- a/test/util/save_util.h
+++ b/test/util/save_util.h
@@ -41,6 +41,11 @@ class DisableSave {
//
// errno is guaranteed to be preserved.
void MaybeSave();
+
+namespace internal {
+bool ShouldSave();
+} // namespace internal
+
} // namespace testing
} // namespace gvisor
diff --git a/test/util/save_util_linux.cc b/test/util/save_util_linux.cc
new file mode 100644
index 000000000..7a0f14342
--- /dev/null
+++ b/test/util/save_util_linux.cc
@@ -0,0 +1,33 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include "test/util/save_util.h"
+
+namespace gvisor {
+namespace testing {
+
+void MaybeSave() {
+ if (internal::ShouldSave()) {
+ int orig_errno = errno;
+ syscall(SYS_create_module, nullptr, 0);
+ errno = orig_errno;
+ }
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/runsc/test/testutil/testutil_race.go b/test/util/save_util_other.cc
index 86db6ffa1..1aca663b7 100644
--- a/runsc/test/testutil/testutil_race.go
+++ b/test/util/save_util_other.cc
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build race
+namespace gvisor {
+namespace testing {
-package testutil
-
-func init() {
- RaceEnabled = true
+void MaybeSave() {
+ // Saving is never available in a non-linux environment.
}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/test_util.cc b/test/util/test_util.cc
index e42bba04a..ba0dcf7d0 100644
--- a/test/util/test_util.cc
+++ b/test/util/test_util.cc
@@ -28,6 +28,8 @@
#include <vector>
#include "absl/base/attributes.h"
+#include "absl/flags/flag.h" // IWYU pragma: keep
+#include "absl/flags/parse.h" // IWYU pragma: keep
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
@@ -224,7 +226,7 @@ bool Equivalent(uint64_t current, uint64_t target, double tolerance) {
void TestInit(int* argc, char*** argv) {
::testing::InitGoogleTest(argc, *argv);
- ::gflags::ParseCommandLineFlags(argc, argv, true);
+ ::absl::ParseCommandLine(*argc, *argv);
// Always mask SIGPIPE as it's common and tests aren't expected to handle it.
struct sigaction sa = {};
diff --git a/test/util/test_util.h b/test/util/test_util.h
index cdbe8bfd1..b9d2dc2ba 100644
--- a/test/util/test_util.h
+++ b/test/util/test_util.h
@@ -185,7 +185,6 @@
#include <utility>
#include <vector>
-#include <gflags/gflags.h>
#include "gmock/gmock.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
diff --git a/test/util/thread_util.h b/test/util/thread_util.h
index 860e77531..923c4fe10 100644
--- a/test/util/thread_util.h
+++ b/test/util/thread_util.h
@@ -16,7 +16,9 @@
#define GVISOR_TEST_UTIL_THREAD_UTIL_H_
#include <pthread.h>
+#ifdef __linux__
#include <sys/syscall.h>
+#endif
#include <unistd.h>
#include <functional>
@@ -66,13 +68,13 @@ class ScopedThread {
private:
void CreateThread() {
- TEST_PCHECK_MSG(
- pthread_create(&pt_, /* attr = */ nullptr,
- +[](void* arg) -> void* {
- return static_cast<ScopedThread*>(arg)->f_();
- },
- this) == 0,
- "thread creation failed");
+ TEST_PCHECK_MSG(pthread_create(
+ &pt_, /* attr = */ nullptr,
+ +[](void* arg) -> void* {
+ return static_cast<ScopedThread*>(arg)->f_();
+ },
+ this) == 0,
+ "thread creation failed");
}
std::function<void*()> f_;
@@ -81,7 +83,9 @@ class ScopedThread {
void* retval_ = nullptr;
};
+#ifdef __linux__
inline pid_t gettid() { return syscall(SYS_gettid); }
+#endif
} // namespace testing
} // namespace gvisor
diff --git a/test/util/uid_util.cc b/test/util/uid_util.cc
new file mode 100644
index 000000000..b131b4b99
--- /dev/null
+++ b/test/util/uid_util.cc
@@ -0,0 +1,44 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+
+namespace gvisor {
+namespace testing {
+
+PosixErrorOr<bool> IsRoot() {
+ uid_t ruid, euid, suid;
+ int rc = getresuid(&ruid, &euid, &suid);
+ MaybeSave();
+ if (rc < 0) {
+ return PosixError(errno, "getresuid");
+ }
+ if (ruid != 0 || euid != 0 || suid != 0) {
+ return false;
+ }
+ gid_t rgid, egid, sgid;
+ rc = getresgid(&rgid, &egid, &sgid);
+ MaybeSave();
+ if (rc < 0) {
+ return PosixError(errno, "getresgid");
+ }
+ if (rgid != 0 || egid != 0 || sgid != 0) {
+ return false;
+ }
+ return true;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/uid_util.h b/test/util/uid_util.h
new file mode 100644
index 000000000..2cd387fb0
--- /dev/null
+++ b/test/util/uid_util.h
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_UID_UTIL_H_
+#define GVISOR_TEST_SYSCALLS_UID_UTIL_H_
+
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+// Returns true if the caller's real/effective/saved user/group IDs are all 0.
+PosixErrorOr<bool> IsRoot();
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_UID_UTIL_H_
diff --git a/third_party/gvsync/downgradable_rwmutex_unsafe.go b/third_party/gvsync/downgradable_rwmutex_unsafe.go
index 069939033..1f6007aa1 100644
--- a/third_party/gvsync/downgradable_rwmutex_unsafe.go
+++ b/third_party/gvsync/downgradable_rwmutex_unsafe.go
@@ -57,9 +57,6 @@ func (rw *DowngradableRWMutex) RLock() {
// RUnlock undoes a single RLock call.
func (rw *DowngradableRWMutex) RUnlock() {
if RaceEnabled {
- // TODO(jamieliu): Why does this need to be ReleaseMerge instead of
- // Release? IIUC this establishes Unlock happens-before RUnlock, which
- // seems unnecessary.
RaceReleaseMerge(unsafe.Pointer(&rw.writerSem))
RaceDisable()
}
diff --git a/tools/go_branch.sh b/tools/go_branch.sh
index d9e79401d..ddb9b6e7b 100755
--- a/tools/go_branch.sh
+++ b/tools/go_branch.sh
@@ -59,7 +59,11 @@ git checkout -b go "${go_branch}"
# Start working on a merge commit that combines the previous history with the
# current history. Note that we don't actually want any changes yet.
-git merge --allow-unrelated-histories --no-commit --strategy ours ${head}
+#
+# N.B. The git behavior changed at some point and the relevant flag was added
+# to allow for override, so try the old behavior first, then pass the flag.
+git merge --no-commit --strategy ours ${head} || \
+ git merge --allow-unrelated-histories --no-commit --strategy ours ${head}
# Sync the entire gopath_dir and go.mod.
rsync --recursive --verbose --delete --exclude .git --exclude README.md -L "${gopath_dir}/" .
diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD
new file mode 100644
index 000000000..c862b277c
--- /dev/null
+++ b/tools/go_marshal/BUILD
@@ -0,0 +1,14 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "go_marshal",
+ srcs = ["main.go"],
+ visibility = [
+ "//:sandbox",
+ ],
+ deps = [
+ "//tools/go_marshal/gomarshal",
+ ],
+)
diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md
new file mode 100644
index 000000000..481575bd3
--- /dev/null
+++ b/tools/go_marshal/README.md
@@ -0,0 +1,164 @@
+This package implements the go_marshal utility.
+
+# Overview
+
+`go_marshal` is a code generation utility similar to `go_stateify` for
+automatically generating code to marshal go data structures to memory.
+
+`go_marshal` attempts to improve on `binary.Write` and the sentry's
+`binary.Marshal` by moving the go runtime reflection necessary to marshal a
+struct to compile-time.
+
+`go_marshal` automatically generates implementations for `abi.Marshallable` and
+`safemem.{Reader,Writer}`. Call-sites for serialization (typically syscall
+implementations) can directly invoke `safemem.Reader.ReadToBlocks` and
+`safemem.Writer.WriteFromBlocks`. Data structures that require custom
+serialization will have manual implementations for these interfaces.
+
+Data structures can be flagged for code generation by adding a struct-level
+comment `// +marshal`.
+
+# Usage
+
+See `defs.bzl`: two new rules are provided, `go_marshal` and `go_library`.
+
+The recommended way to generate a go library with marshalling is to use the
+`go_library` rule, configured mostly the same way as the native go_library rule.
+
+```
+load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "foo",
+ srcs = ["foo.go"],
+)
+```
+
+Under the hood, the `go_marshal` rule is used to generate a file that will
+appear in a Go target; the output file should appear explicitly in a srcs list.
+For example (note that the above is the preferred method):
+
+```
+load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_marshal")
+
+go_marshal(
+ name = "foo_abi",
+ srcs = ["foo.go"],
+ out = "foo_abi.go",
+ package = "foo",
+)
+
+go_library(
+ name = "foo",
+ srcs = [
+ "foo.go",
+ "foo_abi.go",
+ ],
+ deps = [
+ "<PKGPATH>/gvisor/pkg/abi",
+ "<PKGPATH>/gvisor/pkg/sentry/safemem/safemem",
+ "<PKGPATH>/gvisor/pkg/sentry/usermem/usermem",
+ ],
+)
+```
+
+As part of the interface generation, `go_marshal` also generates some tests for
+sanity checking the struct definitions for potential alignment issues, and a
+simple round-trip test through Marshal/Unmarshal to verify the implementation.
+These tests use reflection to verify properties of the ABI struct, and should be
+considered part of the generated interfaces (but are too expensive to execute at
+runtime). Ensure these tests run at some point.
+
+```
+$ cat BUILD
+load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "foo",
+ srcs = ["foo.go"],
+)
+$ blaze build :foo
+$ blaze query ...
+<path-to-dir>:foo_abi_autogen
+<path-to-dir>:foo_abi_autogen_test
+$ blaze test :foo_abi_autogen_test
+<test-output>
+```
+
+# Restrictions
+
+Not all valid go type definitions can be used with `go_marshal`. `go_marshal` is
+intended for ABI structs, which have these additional restrictions:
+
+- At the moment, `go_marshal` only supports struct declarations.
+
+- Structs are marshalled as packed types. This means no implicit padding is
+ inserted between fields shorter than the platform register size. For
+ alignment, manually insert padding fields.
+
+- Structs used with `go_marshal` must have a compile-time static size. This
+ means no dynamically sized fields like slices or strings. Use statically
+ sized arrays (byte arrays for strings) instead.
+
+- No pointers, channel, map or function pointer fields, and no fields that are
+ arrays of these types. These don't make sense in an ABI data structure.
+
+- We could support opaque pointers as `uintptr`, but this is currently not
+ implemented. Implementing this would require handling the architecture
+ dependent native pointer size.
+
+- Fields must either be a primitive integer type (`byte`,
+ `[u]int{8,16,32,64}`), or of a type that implements abi.Marshallable.
+
+- `int` and `uint` fields are not allowed. Use an explicitly-sized numeric
+ type.
+
+- `float*` fields are currently not supported, but could be if necessary.
+
+# Appendix
+
+## Working with Non-Packed Structs
+
+ABI structs must generally be packed types, meaning they should have no implicit
+padding between short fields. However, if a field is tagged
+`marshal:"unaligned"`, `go_marshal` will fall back to a safer but slower
+mechanism to deal with potentially unaligned fields.
+
+Note that the non-packed property is inherited by any other struct that embeds
+this struct, since the `go_marshal` tool currently can't reason about alignments
+for embedded structs that are not aligned.
+
+Because of this, it's generally best to avoid using `marshal:"unaligned"` and
+insert explicit padding fields instead.
+
+## Debugging go_marshal
+
+To enable debugging output from the go marshal tool, pass the `-debug` flag to
+the tool. When using the build rules from above, add a `debug = True` field to
+the build rule like this:
+
+```
+load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "foo",
+ srcs = ["foo.go"],
+ debug = True,
+)
+```
+
+## Modifying the `go_marshal` Tool
+
+The following are some guidelines for modifying the `go_marshal` tool:
+
+- The `go_marshal` tool currently does a single pass over all types requesting
+ code generation, in arbitrary order. This means the generated code can't
+ directly obtain information about embedded marshallable types at
+ compile-time. One way to work around this restriction is to add a new
+ Marshallable interface method providing this piece of information, and
+ calling it from the generated code. Use this sparingly, as we want to rely
+ on compile-time information as much as possible for performance.
+
+- No runtime reflection in the code generated for the marshallable interface.
+ The entire point of the tool is to avoid runtime reflection. The generated
+ tests may use reflection.
diff --git a/tools/go_marshal/analysis/BUILD b/tools/go_marshal/analysis/BUILD
new file mode 100644
index 000000000..c859ced77
--- /dev/null
+++ b/tools/go_marshal/analysis/BUILD
@@ -0,0 +1,13 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "analysis",
+ testonly = 1,
+ srcs = ["analysis_unsafe.go"],
+ importpath = "gvisor.dev/gvisor/tools/go_marshal/analysis",
+ visibility = [
+ "//:sandbox",
+ ],
+)
diff --git a/tools/go_marshal/analysis/analysis_unsafe.go b/tools/go_marshal/analysis/analysis_unsafe.go
new file mode 100644
index 000000000..9a9a4f298
--- /dev/null
+++ b/tools/go_marshal/analysis/analysis_unsafe.go
@@ -0,0 +1,175 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package analysis implements common functionality used by generated
+// go_marshal tests.
+package analysis
+
+// All functions in this package are unsafe and are not intended for general
+// consumption. They contain sharp edge cases and the caller is responsible for
+// ensuring none of them are hit. Callers must be careful to pass in only sane
+// arguments. Failure to do so may cause panics at best and arbitrary memory
+// corruption at worst.
+//
+// Never use outside of tests.
+
+import (
+ "fmt"
+ "math/rand"
+ "reflect"
+ "testing"
+ "unsafe"
+)
+
+// RandomizeValue assigns random value(s) to an arbitrary type. This is intended
+// for use with ABI structs from go_marshal, meaning the typical restrictions
+// apply (fixed-size types, no pointers, maps, channels, etc), and should only
+// be used on zeroed values to avoid overwriting pointers to active go objects.
+//
+// Internally, we populate the type with random data by doing an unsafe cast to
+// access the underlying memory of the type and filling it as if it were a byte
+// slice. This almost gets us what we want, but padding fields named "_" are
+// normally not accessible, so we walk the type and recursively zero all "_"
+// fields.
+//
+// Precondition: x must be a pointer. x must not contain any valid
+// pointers to active go objects (pointer fields aren't allowed in ABI
+// structs anyways), or we'd be violating the go runtime contract and
+// the GC may malfunction.
+//
+// Panics if x is not settable (i.e. it was passed by value), or if the
+// pointed-to type contains a kind that reflectZeroPaddingFields rejects.
+func RandomizeValue(x interface{}) {
+	v := reflect.Indirect(reflect.ValueOf(x))
+	if !v.CanSet() {
+		panic("RandomizeValue() called with an unaddressable value. You probably need to pass a pointer to the argument")
+	}
+
+	// Cast the underlying memory for the type into a byte slice.
+	var b []byte
+	hdr := (*reflect.SliceHeader)(unsafe.Pointer(&b))
+	// Note: v.UnsafeAddr panics if x is passed by value. x should be a pointer.
+	hdr.Data = v.UnsafeAddr()
+	hdr.Len = int(v.Type().Size())
+	hdr.Cap = hdr.Len
+
+	// Fill the byte slice with random data, which in effect fills the type with
+	// random values.
+	n, err := rand.Read(b)
+	if err != nil || n != len(b) {
+		panic("unreachable")
+	}
+
+	// Normally, padding fields are not accessible, so zero them out.
+	reflectZeroPaddingFields(v.Type(), b, false)
+}
+
+// reflectZeroPaddingFields clears the padding inside a value of type r whose
+// backing memory is data. A padding field is any struct field named "_".
+//
+// When zero is true, every byte of data is cleared before anything else
+// happens. The type is then walked recursively: struct fields are visited one
+// by one (padding fields are cleared in full), array elements are visited in
+// order, and fixed-width integer kinds need no further work. Any other kind
+// causes a panic, since it isn't allowed in an ABI struct.
+//
+// This is used for zeroing padding fields after calling RandomizeValue.
+func reflectZeroPaddingFields(r reflect.Type, data []byte, zero bool) {
+	if zero {
+		for idx := range data {
+			data[idx] = 0
+		}
+	}
+
+	switch kind := r.Kind(); kind {
+	case reflect.Struct:
+		for idx := 0; idx < r.NumField(); idx++ {
+			field := r.Field(idx)
+			start := field.Offset
+			end := start + field.Type.Size()
+			reflectZeroPaddingFields(field.Type, data[start:end], field.Name == "_")
+		}
+	case reflect.Array:
+		elemType := r.Elem()
+		elemSize := int(elemType.Size())
+		count := r.Len()
+		if int(r.Size()) != elemSize*count {
+			panic("Array has unexpected size?")
+		}
+		for idx := 0; idx < count; idx++ {
+			reflectZeroPaddingFields(elemType, data[idx*elemSize:(idx+1)*elemSize], false)
+		}
+	case reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64:
+		// Scalar types explicitly allowed in an ABI type; there is nothing
+		// inside them to recurse into.
+	default:
+		panic(fmt.Sprintf("Type %v not allowed in ABI struct", kind))
+	}
+}
+
+// AlignmentCheck ensures the definition of the type represented by typ doesn't
+// cause the go compiler to emit implicit padding between elements of the type
+// (i.e. fields in a struct).
+//
+// AlignmentCheck doesn't explicitly recurse for embedded structs because any
+// struct present in an ABI struct must also be Marshallable, and therefore
+// they're aligned by definition (or their alignment check would have failed).
+//
+// Every violation (implicit padding before a field, implicit padding at the
+// end of the struct, or an unsupported kind) is reported through t.Fatalf.
+// When the function returns normally, ok is always true and delta is the size
+// of typ in bytes.
+func AlignmentCheck(t *testing.T, typ reflect.Type) (ok bool, delta uint64) {
+ switch typ.Kind() {
+ case reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64:
+ // Primitive types are always considered well aligned. Primitive types
+ // that are fields in structs are checked independently, this branch
+ // exists to handle recursive calls to AlignmentCheck.
+ case reflect.Struct:
+ // xOff is the offset at which the current field is expected to start
+ // (the end of the previous field); nextXOff is where the current field
+ // ends.
+ xOff := 0
+ nextXOff := 0
+ // skipNext is set when a field tagged `marshal:"unaligned"` is seen. It
+ // suppresses the offset check of the next field that is neither padding
+ // nor itself tagged; padding fields in between do not clear it.
+ skipNext := false
+ for i, numFields := 0, typ.NumField(); i < numFields; i++ {
+ xOff = nextXOff
+ f := typ.Field(i)
+ fmt.Printf("Checking alignment of %s.%s @ %d [+%d]...\n", typ.Name(), f.Name, f.Offset, f.Type.Size())
+ nextXOff = int(f.Offset + f.Type.Size())
+
+ if f.Name == "_" {
+ // Padding fields need not be aligned.
+ fmt.Printf("Padding field of type %v\n", f.Type)
+ continue
+ }
+
+ // Note: this ok is scoped to the if statement and shadows the named
+ // result ok. The tagged field itself is not offset-checked either.
+ if tag, ok := f.Tag.Lookup("marshal"); ok && tag == "unaligned" {
+ skipNext = true
+ continue
+ }
+
+ if skipNext {
+ skipNext = false
+ fmt.Printf("Skipping alignment check for field %s.%s explicitly marked as unaligned.\n", typ.Name(), f.Name)
+ continue
+ }
+
+ // A gap between the end of the previous field and the start of this
+ // one means the compiler inserted implicit padding.
+ if xOff != int(f.Offset) {
+ implicitPad := int(f.Offset) - xOff
+ t.Fatalf("Suspect offset for field %s.%s, detected an implicit %d byte padding from offset %d to %d; either add %d bytes of explicit padding before this field or tag it as `marshal:\"unaligned\"`.", typ.Name(), f.Name, implicitPad, xOff, f.Offset, implicitPad)
+ }
+ }
+
+ // Ensure structs end on a byte explicitly defined by the type.
+ if typ.NumField() > 0 && nextXOff != int(typ.Size()) {
+ implicitPad := int(typ.Size()) - nextXOff
+ f := typ.Field(typ.NumField() - 1) // Final field
+ t.Fatalf("Suspect offset for field %s.%s at the end of %s, detected an implicit %d byte padding from offset %d to %d at the end of the struct; either add %d bytes of explict padding at end of the struct or tag the final field %s as `marshal:\"unaligned\"`.",
+ typ.Name(), f.Name, typ.Name(), implicitPad, nextXOff, typ.Size(), implicitPad, f.Name)
+ }
+ case reflect.Array:
+ // Independent arrays are also always considered well aligned. We only
+ // need to worry about their alignment when they're embedded in structs,
+ // which we handle above.
+ default:
+ t.Fatalf("Unsupported type in ABI struct while checking for field alignment for type: %v", typ.Kind())
+ }
+ return true, uint64(typ.Size())
+}
diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl
new file mode 100644
index 000000000..c32eb559f
--- /dev/null
+++ b/tools/go_marshal/defs.bzl
@@ -0,0 +1,152 @@
+"""Marshal is a tool for generating marshalling interfaces for Go types.
+
+The recommended way is to use the go_library rule defined below with mostly
+identical configuration as the native go_library rule.
+
+load("//tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "foo",
+ srcs = ["foo.go"],
+)
+
+Under the hood, the go_marshal rule is used to generate a file that will
+appear in a Go target; the output file should appear explicitly in a srcs list.
+The generated file is named after the rule ("<name>_unsafe.go"); a matching
+"<name>_test.go" is generated as well. For example (the above is still the
+preferred way):
+
+load("//tools/go_marshal:defs.bzl", "go_marshal")
+
+go_marshal(
+ name = "foo_abi",
+ libname = "foo",
+ srcs = ["foo.go"],
+ package = "foo",
+)
+
+go_library(
+ name = "foo",
+ srcs = [
+ "foo.go",
+ "foo_abi_unsafe.go",
+ ],
+ deps = [
+ "//tools/go_marshal/marshal",
+ "//pkg/sentry/platform/safecopy",
+ "//pkg/sentry/usermem",
+ ],
+)
+"""
+
+load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library", _go_test = "go_test")
+
+def _go_marshal_impl(ctx):
+    """Runs the go_marshal tool over the rule's sources.
+
+    Produces the two declared outputs of the rule: the generated marshalling
+    implementation ("lib") and the generated tests ("test").
+    """
+    lib_out = ctx.outputs.lib
+    test_out = ctx.outputs.test
+
+    # The package declaring the types is derived from the BUILD file location.
+    (build_dir, _, _) = ctx.build_file_path.rpartition("/BUILD")
+    decl = "/".join(["gvisor.dev/gvisor", build_dir])
+
+    # Flags first, then "--", then the input source paths.
+    args = [
+        "-output=%s" % lib_out.path,
+        "-pkg=%s" % ctx.attr.package,
+        "-output_test=%s" % test_out.path,
+        "-declarationPkg=%s" % decl,
+    ]
+    if ctx.attr.debug:
+        args.append("-debug")
+    args.append("--")
+    for src in ctx.attr.srcs:
+        args.extend([f.path for f in src.files.to_list()])
+
+    ctx.actions.run(
+        inputs = ctx.files.srcs,
+        outputs = [lib_out, test_out],
+        mnemonic = "GoMarshal",
+        progress_message = "go_marshal: %s" % ctx.label,
+        arguments = args,
+        executable = ctx.executable._tool,
+    )
+
+# Generates marshalling code, and tests for it, from a set of Go files.
+#
+# Args:
+# name: the name of the rule. The outputs are named after it:
+# "<name>_unsafe.go" (generated implementation) and "<name>_test.go"
+# (generated tests). The former must not conflict with any other files and
+# must be added to the srcs of the relevant go_library.
+# srcs: the input source files. These files should include all structs in the
+# package that need marshalling code generated.
+# libname: mandatory; the go_library wrapper in this file passes the library
+# name. NOTE(review): _go_marshal_impl does not read this attribute.
+# imports: an optional list of extra, non-aliased, Go-style absolute import
+# paths. NOTE(review): _go_marshal_impl does not forward these to the tool.
+# package: the package name for the input sources.
+# debug: enables debugging output from the go_marshal tool.
+go_marshal = rule(
+ implementation = _go_marshal_impl,
+ attrs = {
+ "srcs": attr.label_list(mandatory = True, allow_files = True),
+ "libname": attr.string(mandatory = True),
+ "imports": attr.string_list(mandatory = False),
+ "package": attr.string(mandatory = True),
+ "debug": attr.bool(doc = "enable debugging output from the go_marshal tool"),
+ "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_marshal:go_marshal")),
+ },
+ outputs = {
+ "lib": "%{name}_unsafe.go",
+ "test": "%{name}_test.go",
+ },
+)
+
+def go_library(name, srcs, deps = [], imports = [], debug = False, **kwargs):
+ """wraps the standard go_library and does marshalling interface generation.
+
+ Three targets are created: a go_marshal target named name + "_abi_autogen",
+ the go_library itself (with the generated source added to srcs), and a
+ go_test named name + "_abi_autogen_test" running the generated tests.
+
+ Args:
+ name: Same as native go_library.
+ srcs: Same as native go_library.
+ deps: Same as native go_library.
+ imports: Extra import paths to pass to the go_marshal tool.
+ debug: Enables debugging output from the go_marshal tool.
+ **kwargs: Remaining args to pass to the native go_library rule unmodified.
+ """
+
+ # Only .go files are handed to the generator; the library name doubles as
+ # the Go package name.
+ go_marshal(
+ name = name + "_abi_autogen",
+ libname = name,
+ srcs = [src for src in srcs if src.endswith(".go")],
+ debug = debug,
+ imports = imports,
+ package = name,
+ )
+
+ # Dependencies that the library always gets in addition to its own deps.
+ extra_deps = [
+ "//tools/go_marshal/marshal",
+ "//pkg/sentry/platform/safecopy",
+ "//pkg/sentry/usermem",
+ ]
+
+ all_srcs = srcs + [name + "_abi_autogen_unsafe.go"]
+ all_deps = deps + [] # Copy, so the appends below don't mutate the caller's list.
+
+ # Add each extra dep unless the caller already listed it.
+ for extra in extra_deps:
+ if extra not in deps:
+ all_deps.append(extra)
+
+ _go_library(
+ name = name,
+ srcs = all_srcs,
+ deps = all_deps,
+ **kwargs
+ )
+
+ # Don't pass importpath arg to go_test.
+ kwargs.pop("importpath", "")
+
+ _go_test(
+ name = name + "_abi_autogen_test",
+ srcs = [name + "_abi_autogen_test.go"],
+ # Generated test has a fixed set of dependencies since we generate these
+ # tests. They should only depend on the library generated above, and the
+ # Marshallable interface.
+ deps = [
+ ":" + name,
+ "//tools/go_marshal/analysis",
+ ],
+ **kwargs
+ )
diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD
new file mode 100644
index 000000000..a0eae6492
--- /dev/null
+++ b/tools/go_marshal/gomarshal/BUILD
@@ -0,0 +1,17 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "gomarshal",
+ srcs = [
+ "generator.go",
+ "generator_interfaces.go",
+ "generator_tests.go",
+ "util.go",
+ ],
+ importpath = "gvisor.dev/gvisor/tools/go_marshal/gomarshal",
+ visibility = [
+ "//:sandbox",
+ ],
+)
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
new file mode 100644
index 000000000..641ccd938
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -0,0 +1,382 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gomarshal implements the go_marshal code generator. See README.md.
+package gomarshal
+
+import (
+ "bytes"
+ "fmt"
+ "go/ast"
+ "go/parser"
+ "go/token"
+ "os"
+ "sort"
+)
+
+const (
+ marshalImport = "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ usermemImport = "gvisor.dev/gvisor/pkg/sentry/usermem"
+ safecopyImport = "gvisor.dev/gvisor/pkg/sentry/platform/safecopy"
+)
+
+// List of identifiers we use in generated code that may conflict with a
+// similarly-named source identifier. Avoid problems by refusing to generate
+// code when we see these.
+//
+// This only applies to import aliases at the moment. All other identifiers
+// are qualified by a receiver argument, since they're struct fields.
+//
+// All receivers are single letters, so we don't allow import aliases to be a
+// single letter.
+var badIdents = []string{
+ "src", "srcs", "dst", "dsts", "blk", "buf", "err",
+ // Single-letter identifiers are also rejected, but they aren't listed here;
+ // collectImports checks the length of the import name instead.
+}
+
+// Generator drives code generation for a single invocation of the go_marshal
+// utility.
+//
+// The Generator holds arguments passed to the tool, and drives parsing,
+// processing and code generation for all types marked with +marshal declared
+// in the input files.
+//
+// See Generator.Run() as the entry point.
+type Generator struct {
+ // Paths to input go source files.
+ inputs []string
+ // Output file to write generated go source.
+ output *os.File
+ // Output file to write generated tests.
+ outputTest *os.File
+ // Package name for the generated file.
+ pkg string
+ // Go import path for package we're processing. This package should directly
+ // declare the type we're generating code for.
+ declaration string
+ // Set of extra packages to import in the generated file.
+ imports *importTable
+}
+
+// NewGenerator creates a new code Generator.
+//
+// srcs are the paths of the input Go source files. out and outTest are the
+// paths of the generated source and generated test files; both are created
+// (or truncated) here. pkg is the package name for the generated file and
+// declaration is the import path of the package being processed. Every entry
+// in imports is added to the generated file unconditionally.
+//
+// Returns an error if either output file can't be opened.
+func NewGenerator(srcs []string, out, outTest, pkg, declaration string, imports []string) (*Generator, error) {
+	f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+	if err != nil {
+		return nil, fmt.Errorf("Couldn't open output file %q: %v", out, err)
+	}
+	fTest, err := os.OpenFile(outTest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+	if err != nil {
+		// Don't leak the handle to the main output file on failure.
+		f.Close()
+		return nil, fmt.Errorf("Couldn't open test output file %q: %v", outTest, err)
+	}
+	g := Generator{
+		inputs:      srcs,
+		output:      f,
+		outputTest:  fTest,
+		pkg:         pkg,
+		declaration: declaration,
+		imports:     newImportTable(),
+	}
+	for _, i := range imports {
+		// All imports on the extra imports list are unconditionally marked as
+		// used, so they're always added to the generated code.
+		g.imports.add(i).markUsed()
+	}
+	g.imports.add(marshalImport).markUsed()
+	// The following imports may or may not be used by the generated
+	// code, depending what's required for the target types. Don't
+	// mark these imports as used by default.
+	g.imports.add(usermemImport)
+	g.imports.add(safecopyImport)
+	g.imports.add("unsafe")
+
+	return &g, nil
+}
+
+// writeHeader emits the preamble of the generated source file to g.output: a
+// comment marking the file as generated, the package clause, and finally the
+// import statements held in g.imports.
+func (g *Generator) writeHeader() error {
+	var preamble sourceBuffer
+	preamble.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n")
+	preamble.emit("package %s\n\n", g.pkg)
+
+	err := preamble.write(g.output)
+	if err != nil {
+		return err
+	}
+	return g.imports.write(g.output)
+}
+
+// writeTypeChecks emits one compile-time assertion per Marshallable type
+// referenced by the generated code, so that the compiler verifies each of
+// them really implements marshal.Marshallable. Type names are emitted in
+// sorted order. Nothing is written when ms is empty.
+func (g *Generator) writeTypeChecks(ms map[string]struct{}) error {
+	if len(ms) == 0 {
+		return nil
+	}
+
+	names := make([]string, 0, len(ms))
+	for name := range ms {
+		names = append(names, name)
+	}
+	sort.Strings(names)
+
+	var out bytes.Buffer
+	out.WriteString("// Marshallable types used by this file.\n")
+	for _, name := range names {
+		fmt.Fprintf(&out, "var _ marshal.Marshallable = (*%s)(nil)\n", name)
+	}
+	out.WriteString("\n")
+
+	_, err := g.output.Write(out.Bytes())
+	return err
+}
+
+// parse runs the Go parser (with comments retained) over every input file of
+// this generator. It returns the parsed ASTs together with the FileSet used
+// for each one, in input order; index i of both slices refers to the same
+// input. The first file that fails to parse aborts the whole operation with an
+// error.
+func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) {
+	debugf("go_marshal invoked with %d input files:\n", len(g.inputs))
+	for _, in := range g.inputs {
+		debugf(" %s\n", in)
+	}
+
+	var (
+		files = make([]*ast.File, 0, len(g.inputs))
+		fsets = make([]*token.FileSet, 0, len(g.inputs))
+	)
+	for _, in := range g.inputs {
+		fset := token.NewFileSet()
+		parsed, err := parser.ParseFile(fset, in, nil, parser.ParseComments)
+		if err != nil {
+			// Not a valid input file?
+			return nil, nil, fmt.Errorf("Input %q can't be parsed: %v", in, err)
+		}
+
+		if debugEnabled() {
+			debugf("AST for %q:\n", in)
+			ast.Print(fset, parsed)
+		}
+
+		files = append(files, parsed)
+		fsets = append(fsets, fset)
+	}
+
+	return files, fsets, nil
+}
+
+// collectMarshallabeTypes scans the top-level declarations of a and returns
+// the struct type specs for which the Marshallable interface must be
+// generated. A declaration qualifies when it is a type declaration whose doc
+// comment contains a line that is exactly "// +marshal"; within it, only specs
+// that declare a struct are collected.
+func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec {
+	var collected []*ast.TypeSpec
+	for _, decl := range a.Decls {
+		gdecl, isGen := decl.(*ast.GenDecl)
+		// Type declaration?
+		if !isGen || gdecl.Tok != token.TYPE {
+			debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n")
+			continue
+		}
+		// Does it have a comment?
+		if gdecl.Doc == nil {
+			debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment.\n")
+			continue
+		}
+		// Does the comment contain a "+marshal" line?
+		hasMarker := false
+		for idx := 0; idx < len(gdecl.Doc.List) && !hasMarker; idx++ {
+			hasMarker = gdecl.Doc.List[idx].Text == "// +marshal"
+		}
+		if !hasMarker {
+			debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment containing +marshal line.\n")
+			continue
+		}
+		for _, spec := range gdecl.Specs {
+			// We already confirmed we're in a type declaration earlier.
+			ts := spec.(*ast.TypeSpec)
+			if _, isStruct := ts.Type.(*ast.StructType); !isStruct {
+				debugf("Skipping declaration %v since it's not a struct declaration.\n", gdecl)
+				continue
+			}
+			debugfAt(f.Position(ts.Pos()), "Collected marshallable type %s.\n", ts.Name.Name)
+			collected = append(collected, ts)
+		}
+	}
+	return collected
+}
+
+// collectImports collects all imports from all input source files. Some of
+// these imports are copied to the generated output, if they're referenced by
+// the generated code.
+//
+// Each import spec in a is registered in g.imports via addFromSpec, and
+// collectImports ensures identifiers in the generated code don't conflict with
+// any imported package names: single-character local names and names listed in
+// badIdents are rejected through abortAt.
+//
+// NOTE(review): the returned map is created empty and nothing is ever inserted
+// into it here, so callers always receive an empty map; the collected imports
+// live in g.imports.
+func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt {
+ // Set form of badIdents, for constant-time lookups below.
+ badImportNames := make(map[string]bool)
+ for _, i := range badIdents {
+ badImportNames[i] = true
+ }
+
+ is := make(map[string]importStmt)
+ for _, decl := range a.Decls {
+ gdecl, ok := decl.(*ast.GenDecl)
+ // Import statement?
+ if !ok || gdecl.Tok != token.IMPORT {
+ continue
+ }
+ for _, spec := range gdecl.Specs {
+ i := g.imports.addFromSpec(spec.(*ast.ImportSpec), f)
+ debugf("Collected import '%s' as '%s'\n", i.path, i.name)
+
+ // Make sure we have an import that doesn't use any local names that
+ // would conflict with identifiers in the generated code.
+ if len(i.name) == 1 {
+ abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name))
+ }
+ if badImportNames[i.name] {
+ abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name))
+ }
+ }
+ }
+ return is
+
+}
+
+// generateOne builds an interfaceGenerator for the struct type t, validates
+// the type and emits its Marshallable implementation into the returned
+// generator.
+func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
+	// Only struct type specs reach this point; see
+	// Generator.collectMarshallabeTypes.
+	gen := newInterfaceGenerator(t, fset)
+	gen.validate()
+	gen.emitMarshallable()
+	return gen
+}
+
+// generateOneTestSuite generates a test suite for the automatically generated
+// implementations for type t.
+func (g *Generator) generateOneTestSuite(t *ast.TypeSpec) *testGenerator {
+	suite := newTestGenerator(t, g.declaration)
+	suite.emitTests()
+	return suite
+}
+
+// Run is the entry point to code generation using g.
+//
+// Run parses all input source files specified in g and emits generated code.
+// It proceeds in this order: parse the inputs, collect imports from every
+// file, generate an implementation and a test suite for each marked type,
+// then write the header, the type checks, the implementations and finally the
+// tests. The first error encountered while parsing or writing is returned.
+func (g *Generator) Run() error {
+ // Parse our input source files into ASTs and token sets.
+ asts, fsets, err := g.parse()
+ if err != nil {
+ return err
+ }
+
+ if len(asts) != len(fsets) {
+ panic("ASTs and FileSets don't match")
+ }
+
+ // Map of imports in source files; key = local package name, value = import
+ // path.
+ //
+ // NOTE(review): this map is only written to in this function and never
+ // read afterwards; the imports that matter are recorded in g.imports by
+ // collectImports.
+ is := make(map[string]importStmt)
+ for i, a := range asts {
+ // Collect all imports from the source files. We may need to copy some
+ // of these to the generated code if they're referenced. This has to be
+ // done before the loop below because we need to process all ASTs before
+ // we start requesting imports to be copied one by one as we encounter
+ // them in each generated source.
+ for name, i := range g.collectImports(a, fsets[i]) {
+ is[name] = i
+ }
+ }
+
+ var impls []*interfaceGenerator
+ var ts []*testGenerator
+ // Set of Marshallable types referenced by generated code.
+ ms := make(map[string]struct{})
+ for i, a := range asts {
+ // Collect type declarations marked for code generation and generate
+ // Marshallable interfaces.
+ for _, t := range g.collectMarshallabeTypes(a, fsets[i]) {
+ impl := g.generateOne(t, fsets[i])
+ // Collect Marshallable types referenced by the generated code.
+ for ref, _ := range impl.ms {
+ ms[ref] = struct{}{}
+ }
+ impls = append(impls, impl)
+ // Collect imports referenced by the generated code and add them to
+ // the list of imports we need to copy to the generated code.
+ for name, _ := range impl.is {
+ if !g.imports.markUsed(name) {
+ panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'", impl.typeName(), name))
+ }
+ }
+ ts = append(ts, g.generateOneTestSuite(t))
+ }
+ }
+
+ // Tool was invoked with input files with no data structures marked for code
+ // generation. This is probably not what the user intended.
+ if len(impls) == 0 {
+ var buf bytes.Buffer
+ fmt.Fprintf(&buf, "go_marshal invoked on these files, but they don't contain any types requiring code generation. Perhaps mark some with \"// +marshal\"?:\n")
+ for _, i := range g.inputs {
+ fmt.Fprintf(&buf, " %s\n", i)
+ }
+ abort(buf.String())
+ }
+
+ // Write output file header. These include things like package name and
+ // import statements.
+ if err := g.writeHeader(); err != nil {
+ return err
+ }
+
+ // Write type checks for referenced marshallable types to output file.
+ if err := g.writeTypeChecks(ms); err != nil {
+ return err
+ }
+
+ // Write generated interfaces to output file.
+ for _, i := range impls {
+ if err := i.write(g.output); err != nil {
+ return err
+ }
+ }
+
+ // Write generated tests to test file.
+ return g.writeTests(ts)
+}
+
+// writeTests writes the generated test file to g.outputTest: the package
+// clause (the generated package name with a "_test" suffix), the union of the
+// imports needed by every test suite, and then each suite in order. The first
+// write error is returned.
+func (g *Generator) writeTests(ts []*testGenerator) error {
+	var header sourceBuffer
+	header.emit("package %s_test\n\n", g.pkg)
+	if err := header.write(g.outputTest); err != nil {
+		return err
+	}
+
+	merged := newImportTable()
+	for _, suite := range ts {
+		merged.merge(suite.imports)
+	}
+	if err := merged.write(g.outputTest); err != nil {
+		return err
+	}
+
+	for idx := range ts {
+		if err := ts[idx].write(g.outputTest); err != nil {
+			return err
+		}
+	}
+	return nil
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
new file mode 100644
index 000000000..a712c14dc
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -0,0 +1,507 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+ "go/token"
+ "strings"
+)
+
+// interfaceGenerator generates marshalling interfaces for a single type.
+//
+// interfaceGenerator is not thread-safe.
+type interfaceGenerator struct {
+ // Buffer accumulating the generated source for this type.
+ sourceBuffer
+
+ // The type we're serializing.
+ t *ast.TypeSpec
+
+ // Receiver argument for generated methods.
+ r string
+
+ // FileSet containing the tokens for the type we're processing.
+ f *token.FileSet
+
+ // is records external packages referenced by the generated implementation.
+ is map[string]struct{}
+
+ // ms records Marshallable types referenced by the generated implementation
+ // of t's interfaces.
+ ms map[string]struct{}
+
+ // as records embedded fields in t that are potentially not packed. The key
+ // is the accessor for the field.
+ as map[string]struct{}
+}
+
+// typeName reports the declared name of the type that g generates code for.
+func (g *interfaceGenerator) typeName() string {
+	ident := g.t.Name
+	return ident.Name
+}
+
+// newInterfaceGenerator creates a new interface generator for the struct type
+// t, whose tokens live in fset. It panics if t does not declare a struct. The
+// type itself is recorded as a used Marshallable type.
+func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
+	switch t.Type.(type) {
+	case *ast.StructType:
+		// The only supported kind of type.
+	default:
+		panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t))
+	}
+
+	gen := new(interfaceGenerator)
+	gen.t = t
+	gen.r = receiverName(t)
+	gen.f = fset
+	gen.is = make(map[string]struct{})
+	gen.ms = make(map[string]struct{})
+	gen.as = make(map[string]struct{})
+
+	gen.recordUsedMarshallable(gen.typeName())
+	return gen
+}
+
+// recordUsedMarshallable adds m to the set of Marshallable types referenced by
+// the generated implementation.
+func (g *interfaceGenerator) recordUsedMarshallable(m string) {
+	var present struct{}
+	g.ms[m] = present
+}
+
+// recordUsedImport adds i to the set of external packages referenced by the
+// generated implementation.
+func (g *interfaceGenerator) recordUsedImport(i string) {
+	var present struct{}
+	g.is[i] = present
+}
+
+// recordPotentiallyNonPackedField adds fieldName (a field accessor) to the set
+// of fields that are potentially not packed.
+func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) {
+	var present struct{}
+	g.as[fieldName] = present
+}
+
+// forEachField calls fn once for every field entry of the struct g.t, in
+// declaration order.
+func (g *interfaceGenerator) forEachField(fn func(f *ast.Field)) {
+	// The type assertion cannot fail because g.t is always a struct.
+	fields := g.t.Type.(*ast.StructType).Fields.List
+	for idx := range fields {
+		fn(fields[idx])
+	}
+}
+
+// fieldAccessor returns the expression used in generated code to access the
+// field named by n: the receiver name, a dot, then the field name.
+func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string {
+	receiver, field := g.r, n.Name
+	return fmt.Sprintf("%s.%s", receiver, field)
+}
+
+// abortAt aborts the go_marshal tool with the given error message, with a
+// reference position to the input source. It behaves like the package-level
+// abortAt, but takes a token.Pos and resolves it to a position using g's
+// FileSet.
+func (g *interfaceGenerator) abortAt(p token.Pos, msg string) {
+	pos := g.f.Position(p)
+	abortAt(pos, msg)
+}
+
+// validate ensures the type we're working with can be marshalled. These checks
+// are done ahead of time and in one place so we can make assumptions later.
+//
+// Two passes are made over the fields. The first rejects embedded (unnamed)
+// fields. The second dispatches on the kind of each field and rejects
+// ambiguous-width integers, strings, slices, arrays whose length isn't a
+// literal or whose elements aren't plain identifiers, zero-length arrays, and
+// any field kind the dispatcher doesn't handle. Every rejection is reported
+// through g.abortAt with the position of the offending code.
+func (g *interfaceGenerator) validate() {
+	g.forEachField(func(f *ast.Field) {
+		if len(f.Names) == 0 {
+			g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
+		}
+	})
+
+	g.forEachField(func(f *ast.Field) {
+		fieldDispatcher{
+			primitive: func(_, t *ast.Ident) {
+				switch t.Name {
+				case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
+					// These are the only primitive types we allow. Below, we
+					// provide suggestions for some disallowed types and reject
+					// them, then attempt to marshal any remaining types by
+					// invoking the marshal.Marshallable interface on them. If
+					// these types don't actually implement
+					// marshal.Marshallable, compilation of the generated code
+					// will fail with an appropriate error message.
+					return
+				case "int":
+					g.abortAt(f.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
+				case "uint":
+					g.abortAt(f.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
+				case "string":
+					g.abortAt(f.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
+				default:
+					// Pass the arguments straight to debugfAt instead of
+					// pre-formatting the message and using it as a format.
+					debugfAt(g.f.Position(f.Pos()), "Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)
+				}
+			},
+			selector: func(_, _, _ *ast.Ident) {
+				// No validation to perform on selector fields. However this
+				// callback must still be provided.
+			},
+			array: func(n, _ *ast.Ident, length int) {
+				a := f.Type.(*ast.ArrayType)
+				if a.Len == nil {
+					g.abortAt(f.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
+				}
+
+				if _, ok := a.Len.(*ast.BasicLit); !ok {
+					g.abortAt(a.Len.Pos(), "Array size must be a literal, don't use consts or expressions")
+				}
+
+				if _, ok := a.Elt.(*ast.Ident); !ok {
+					g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
+				}
+
+				if length <= 0 {
+					g.abortAt(a.Len.Pos(), "Marshalling not supported for zero length arrays, why does an ABI struct have one?")
+				}
+			},
+			unhandled: func(_ *ast.Ident) {
+				g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
+			},
+		}.dispatch(f)
+	})
+}
+
+// scalarSize reports the size in bytes of the type named by t. Only the
+// fixed-width integer types (and byte) have a size known at code generation
+// time; for every other name unknownSize is true and size is 0, meaning the
+// size must be resolved via the marshal.Marshallable interface.
+func (g *interfaceGenerator) scalarSize(t *ast.Ident) (size int, unknownSize bool) {
+	widths := map[string]int{
+		"int8": 1, "uint8": 1, "byte": 1,
+		"int16": 2, "uint16": 2,
+		"int32": 4, "uint32": 4,
+		"int64": 8, "uint64": 8,
+	}
+	if width, known := widths[t.Name]; known {
+		return width, false
+	}
+	return 0, true
+}
+
+// shift emits a statement that re-slices the buffer variable named bufVar so
+// that it starts n bytes further in.
+func (g *interfaceGenerator) shift(bufVar string, n int) {
+ g.emit("%s = %s[%d:]\n", bufVar, bufVar, n)
+}
+
+// shiftDynamic emits a statement that re-slices the buffer variable named
+// bufVar so that it starts name.SizeBytes() bytes further in. Unlike shift,
+// the amount is computed by the generated code rather than known here.
+func (g *interfaceGenerator) shiftDynamic(bufVar, name string) {
+ g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name)
+}
+
+// marshalScalar emits the statements that write the value reached through
+// accessor into the front of the buffer variable named bufVar, followed by a
+// statement advancing bufVar past the bytes written. typ is the name of the
+// value's type.
+//
+// Single-byte types are stored directly. 16, 32 and 64-bit integers are
+// written through usermem.ByteOrder, and the "usermem" import is recorded as
+// used. Any other type name is handled by emitting calls to MarshalBytes and
+// SizeBytes on the accessor.
+func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) {
+ switch typ {
+ case "int8", "uint8", "byte":
+ g.emit("%s[0] = byte(%s)\n", bufVar, accessor)
+ g.shift(bufVar, 1)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(%s))\n", bufVar, accessor)
+ g.shift(bufVar, 2)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(%s))\n", bufVar, accessor)
+ g.shift(bufVar, 4)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(%s))\n", bufVar, accessor)
+ g.shift(bufVar, 8)
+ default:
+ // Not a primitive: defer to the type's own marshalling methods; its size
+ // is only known to the generated code.
+ g.emit("%s.MarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
+ g.shiftDynamic(bufVar, accessor)
+ }
+}
+
+func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string) {
+ switch typ {
+ case "int8":
+ g.emit("%s = int8(%s[0])\n", accessor, bufVar)
+ g.shift(bufVar, 1)
+ case "uint8":
+ g.emit("%s = uint8(%s[0])\n", accessor, bufVar)
+ g.shift(bufVar, 1)
+ case "byte":
+ g.emit("%s = %s[0]\n", accessor, bufVar)
+ g.shift(bufVar, 1)
+
+ case "int16":
+ g.recordUsedImport("usermem")
+ g.emit("%s = int16(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, bufVar)
+ g.shift(bufVar, 2)
+ case "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("%s = usermem.ByteOrder.Uint16(%s[:2])\n", accessor, bufVar)
+ g.shift(bufVar, 2)
+
+ case "int32":
+ g.recordUsedImport("usermem")
+ g.emit("%s = int32(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, bufVar)
+ g.shift(bufVar, 4)
+ case "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("%s = usermem.ByteOrder.Uint32(%s[:4])\n", accessor, bufVar)
+ g.shift(bufVar, 4)
+
+ case "int64":
+ g.recordUsedImport("usermem")
+ g.emit("%s = int64(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, bufVar)
+ g.shift(bufVar, 8)
+ case "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("%s = usermem.ByteOrder.Uint64(%s[:8])\n", accessor, bufVar)
+ g.shift(bufVar, 8)
+ default:
+ g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
+ g.shiftDynamic(bufVar, accessor)
+ g.recordPotentiallyNonPackedField(accessor)
+ }
+}
+
+// areFieldsPackedExpression returns a go expression checking whether g.t's
+// fields are packed. Returns "", false if no accessors have been recorded in
+// g.as, otherwise returns <clause>, true, where <clause> is an expression
+// like "t.a.Packed() && t.b.Packed() && t.c.Packed()".
+//
+// The terms are ordered lexicographically. g.as is a map, and Go deliberately
+// randomizes map iteration order, so without an explicit order the generated
+// source would differ between otherwise identical runs.
+func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
+	if len(g.as) == 0 {
+		return "", false
+	}
+
+	cs := make([]string, 0, len(g.as))
+	for accessor := range g.as {
+		term := fmt.Sprintf("%s.Packed()", accessor)
+		// Sorted insert: grow by one, then slide larger terms right until
+		// term's slot is found. The set of fields is small, so the quadratic
+		// worst case is irrelevant.
+		pos := len(cs)
+		cs = append(cs, term)
+		for pos > 0 && cs[pos-1] > term {
+			cs[pos] = cs[pos-1]
+			pos--
+		}
+		cs[pos] = term
+	}
+	return strings.Join(cs, " && "), true
+}
+
+// emitMarshallable emits the marshal.Marshallable method set for g.t, in this
+// order: SizeBytes, MarshalBytes, UnmarshalBytes, Packed, MarshalUnsafe and
+// UnmarshalUnsafe.
+//
+// NOTE(review): the Packed/MarshalUnsafe/UnmarshalUnsafe sections call
+// areFieldsPackedExpression(), which reads g.as. That set appears to be
+// populated by recordPotentiallyNonPackedField() while UnmarshalBytes is being
+// emitted, so the emission order below presumably must not be changed --
+// confirm before reordering.
+func (g *interfaceGenerator) emitMarshallable() {
+ // Is g.t a packed struct without considering field types?
+ thisPacked := true
+ g.forEachField(func(f *ast.Field) {
+ if f.Tag != nil {
+ // Only an exact match of the whole raw tag literal counts; a tag
+ // carrying additional keys is not recognized by this comparison.
+ if f.Tag.Value == "`marshal:\"unaligned\"`" {
+ if thisPacked {
+ debugfAt(g.f.Position(g.t.Pos()),
+ fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name))
+ thisPacked = false
+ }
+ }
+ }
+ })
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ // Sizes of primitive fields are summed here, at generation time; every
+ // other field contributes a SizeBytes() call to the generated return
+ // expression.
+ primitiveSize := 0
+ var dynamicSizeTerms []string
+
+ g.forEachField(fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ primitiveSize += size
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%s.SizeBytes()", g.fieldAccessor(n)))
+ }
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
+ g.recordUsedImport(tX.Name)
+ g.recordUsedMarshallable(tName)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
+ },
+ array: func(n, t *ast.Ident, len int) {
+ if len < 1 {
+ // Zero-length arrays should've been rejected by validate().
+ panic("unreachable")
+ }
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ primitiveSize += size * len
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len))
+ }
+ },
+ }.dispatch)
+ // Emits "return <primitiveSize>" followed by one " +\n<term>" per
+ // dynamic term, with the continuation lines indented one level deeper.
+ g.emit("return %d", primitiveSize)
+ if len(dynamicSizeTerms) > 0 {
+ g.incIndent()
+ }
+ {
+ for _, d := range dynamicSizeTerms {
+ g.emitNoIndent(" +\n")
+ g.emit(d)
+ }
+ }
+ if len(dynamicSizeTerms) > 0 {
+ g.decIndent()
+ }
+ })
+ // The return statement above was emitted without a trailing newline.
+ g.emit("\n}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.forEachField(fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ // Fields named "_" are padding: nothing is written for them, dst
+ // is only advanced past them.
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("dst", len)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
+ }
+ return
+ }
+ g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
+ },
+ array: func(n, t *ast.Ident, size int) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("dst", len*size)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
+ }
+ return
+ }
+
+ // Non-padding arrays are marshalled element by element in a
+ // generated loop.
+ g.emit("for i := 0; i < %d; i++ {\n", size)
+ g.inIndent(func() {
+ g.marshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "dst")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.forEachField(fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ // Fields named "_" are padding: nothing is read into them, src is
+ // only advanced past them.
+ if n.Name == "_" {
+ g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("src", len)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name)
+ g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
+ }
+ return
+ }
+ g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
+ },
+ array: func(n, t *ast.Ident, size int) {
+ if n.Name == "_" {
+ g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("src", len*size)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ //
+ // NOTE(review): unlike the primitive padding case above,
+ // this branch does not call
+ // recordPotentiallyNonPackedField -- confirm whether that
+ // is intentional.
+ g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
+ }
+ return
+ }
+
+ // Non-padding arrays are unmarshalled element by element in a
+ // generated loop.
+ g.emit("for i := 0; i < %d; i++ {\n", size)
+ g.inIndent(func() {
+ g.unmarshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "src")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ // Three outcomes: constant false when the type itself is tagged
+ // unaligned, a runtime conjunction over recorded fields when there are
+ // any, and constant true otherwise.
+ expr, fieldsMaybePacked := g.areFieldsPackedExpression()
+ switch {
+ case !thisPacked:
+ g.emit("return false\n")
+ case fieldsMaybePacked:
+ g.emit("return %s\n", expr)
+ default:
+ g.emit("return true\n")
+
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ // The generated code picks the bulk copy or the field-by-field
+ // path at runtime, depending on the fields' Packed() results.
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(func() {
+ g.emit("%s.MarshalBytes(dst)\n", g.r)
+ })
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ }
+ } else {
+ g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName())
+ g.emit("%s.MarshalBytes(dst)\n", g.r)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ // Same runtime choice as in MarshalUnsafe, in the other
+ // direction.
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(func() {
+ g.emit("%s.UnmarshalBytes(src)\n", g.r)
+ })
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ }
+ } else {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("%s.UnmarshalBytes(src)\n", g.r)
+ }
+ })
+ g.emit("}\n\n")
+
+}
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
new file mode 100644
index 000000000..df25cb5b2
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -0,0 +1,154 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+ "io"
+ "strings"
+)
+
+// standardImports lists the import paths that newTestGenerator adds to every
+// test generator's import table and marks as used.
+var standardImports = []string{
+ "fmt",
+ "reflect",
+ "testing",
+ "gvisor.dev/gvisor/tools/go_marshal/analysis",
+}
+
+// testGenerator accumulates the source of the generated tests for a single
+// struct type. The generated code is collected in the embedded sourceBuffer.
+type testGenerator struct {
+ sourceBuffer
+
+ // The type we're serializing.
+ t *ast.TypeSpec
+
+ // Receiver argument for generated methods.
+ r string
+
+ // Imports used by generated code.
+ imports *importTable
+
+ // Import statement for the package declaring the type we generated code
+ // for. We need this to construct test instances for the type, since the
+ // tests aren't written in the same package.
+ decl *importStmt
+}
+
+// newTestGenerator creates a test generator for the struct type t.
+// declaration is the import path of the package declaring t; it is added to
+// the generator's import table, together with standardImports, and all of
+// them are marked used. Panics if t is not a struct type.
+func newTestGenerator(t *ast.TypeSpec, declaration string) *testGenerator {
+	_, isStruct := t.Type.(*ast.StructType)
+	if !isStruct {
+		panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t))
+	}
+
+	gen := &testGenerator{
+		t:       t,
+		r:       receiverName(t),
+		imports: newImportTable(),
+	}
+	for _, importPath := range standardImports {
+		gen.imports.add(importPath).markUsed()
+	}
+
+	gen.decl = gen.imports.add(declaration)
+	gen.decl.markUsed()
+	return gen
+}
+
+// typeName returns the name of the type under test, qualified with the local
+// name of the package that declares it.
+func (g *testGenerator) typeName() string {
+	return g.decl.name + "." + g.t.Name.Name
+}
+
+// forEachField invokes fn on every field declaration of g.t, in declaration
+// order.
+func (g *testGenerator) forEachField(fn func(f *ast.Field)) {
+	// g.t is always a struct (checked in newTestGenerator), so the assertion
+	// can't fail.
+	fields := g.t.Type.(*ast.StructType).Fields.List
+	for idx := 0; idx < len(fields); idx++ {
+		fn(fields[idx])
+	}
+}
+
+// testFuncName returns base followed by the title-cased name of the type under
+// test.
+func (g *testGenerator) testFuncName(base string) string {
+	return base + strings.Title(g.t.Name.Name)
+}
+
+// inTestFunction emits a test function named after name and the type under
+// test (see testFuncName). body is invoked, one indentation level deeper, to
+// emit the function's statements.
+func (g *testGenerator) inTestFunction(name string, body func()) {
+	fullName := g.testFuncName(name)
+	g.emit("func %s(t *testing.T) {\n", fullName)
+	g.inIndent(body)
+	g.emit("}\n\n")
+}
+
+// emitTestNonZeroSize emits a test asserting that SizeBytes() of a
+// zero-valued instance of the type under test is not zero.
+func (g *testGenerator) emitTestNonZeroSize() {
+	g.inTestFunction("TestSizeNonZero", func() {
+		g.emit("x := &%s{}\n", g.typeName())
+		g.emit("if x.SizeBytes() == 0 {\n")
+		g.inIndent(func() {
+			// The failure message names the method that is actually being
+			// checked; Marshallable has no Size() method.
+			g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n")
+		})
+		g.emit("}\n")
+	})
+}
+
+// emitTestSuspectAlignment emits a test that passes the reflected type of a
+// zero-valued instance of the type under test to analysis.AlignmentCheck.
+func (g *testGenerator) emitTestSuspectAlignment() {
+	body := func() {
+		g.emit("x := %s{}\n", g.typeName())
+		g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n")
+	}
+	g.inTestFunction("TestSuspectAlignment", body)
+}
+
+// emitTestMarshalUnmarshalPreservesData emits a test which randomizes a value
+// of the type under test, marshals it with both MarshalBytes and
+// MarshalUnsafe, and checks that every marshal/unmarshal pairing reproduces
+// the original value.
+func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() {
+	// emitCheck emits a reflect.DeepEqual comparison of x against the variable
+	// named got, failing the generated test with a message naming cycle.
+	//
+	// emit only expands its format string when given extra arguments. The
+	// message is therefore emitted through the formatting path, so that "%%+v"
+	// lands in the generated source as the "%+v" verb rather than as a literal
+	// "%%+v", which fmt.Sprintf would print verbatim followed by %!(EXTRA ...).
+	emitCheck := func(got, cycle string) {
+		g.emit("if !reflect.DeepEqual(x, %s) {\n", got)
+		g.inIndent(func() {
+			g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across %s cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, %s))\n", cycle, got)
+		})
+		g.emit("}\n")
+	}
+
+	g.inTestFunction("TestSafeMarshalUnmarshalPreservesData", func() {
+		g.emit("var x, y, z, yUnsafe, zUnsafe %s\n", g.typeName())
+		g.emit("analysis.RandomizeValue(&x)\n\n")
+
+		g.emit("buf := make([]byte, x.SizeBytes())\n")
+		g.emit("x.MarshalBytes(buf)\n")
+		g.emit("bufUnsafe := make([]byte, x.SizeBytes())\n")
+		g.emit("x.MarshalUnsafe(bufUnsafe)\n\n")
+
+		g.emit("y.UnmarshalBytes(buf)\n")
+		emitCheck("y", "Marshal/Unmarshal")
+		g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n")
+		emitCheck("yUnsafe", "MarshalUnsafe/Unmarshal")
+		// Blank line between the safe and the unsafe unmarshal checks.
+		g.emitNoIndent("\n")
+
+		g.emit("z.UnmarshalUnsafe(buf)\n")
+		emitCheck("z", "Marshal/UnmarshalUnsafe")
+		g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n")
+		emitCheck("zUnsafe", "MarshalUnsafe/UnmarshalUnsafe")
+	})
+}
+
+// emitTests emits every generated test for the type under test: the non-zero
+// size test, the alignment test and the marshal/unmarshal round-trip test, in
+// that order.
+func (g *testGenerator) emitTests() {
+	emitters := []func(){
+		g.emitTestNonZeroSize,
+		g.emitTestSuspectAlignment,
+		g.emitTestMarshalUnmarshalPreservesData,
+	}
+	for _, emitOne := range emitters {
+		emitOne()
+	}
+}
+
+// write copies everything emitted so far to out.
+func (g *testGenerator) write(out io.Writer) error {
+	err := g.sourceBuffer.write(out)
+	return err
+}
diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go
new file mode 100644
index 000000000..967537abf
--- /dev/null
+++ b/tools/go_marshal/gomarshal/util.go
@@ -0,0 +1,387 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gomarshal
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/token"
+ "io"
+ "os"
+ "path"
+ "reflect"
+ "sort"
+ "strconv"
+ "strings"
+)
+
+// debug is the -debug command line flag; it defaults to false and is read via
+// debugEnabled().
+var debug = flag.Bool("debug", false, "enables debugging output")
+
+// receiverName returns an appropriate receiver name given a type spec: the
+// lower-cased first character of the type's name. Panics if the type name is
+// empty.
+func receiverName(t *ast.TypeSpec) string {
+	if len(t.Name.Name) < 1 {
+		// Zero length type name?
+		panic("unreachable")
+	}
+	// Take the first rune rather than the first byte: Go identifiers may
+	// start with a multi-byte letter, and slicing off a single byte of it
+	// would produce an invalid name.
+	first := []rune(t.Name.Name)[:1]
+	return strings.ToLower(string(first))
+}
+
+// kindString returns a user-friendly representation of an AST expr type.
+// A nil expression is reported as "nil". Node types without a friendly name
+// are reported by their Go type name, e.g. "*ast.SelectorExpr".
+func kindString(e ast.Expr) string {
+	switch e.(type) {
+	case nil:
+		// reflect.TypeOf(nil) returns a nil Type, so falling through to the
+		// default branch would panic.
+		return "nil"
+	case *ast.Ident:
+		return "scalar"
+	case *ast.ArrayType:
+		return "array"
+	case *ast.StructType:
+		return "struct"
+	case *ast.StarExpr:
+		return "pointer"
+	case *ast.FuncType:
+		return "function"
+	case *ast.InterfaceType:
+		return "interface"
+	case *ast.MapType:
+		return "map"
+	case *ast.ChanType:
+		return "channel"
+	default:
+		return reflect.TypeOf(e).String()
+	}
+}
+
+// fieldDispatcher is a collection of callbacks for handling different types of
+// fields in a struct declaration. See dispatch for when each one is invoked.
+type fieldDispatcher struct {
+ // Called for fields whose type is a plain identifier; n is the field name
+ // and t the type identifier.
+ primitive func(n, t *ast.Ident)
+ // Called for fields whose type is a selector expression X.Sel; n is the
+ // field name, tX and tSel are the two halves of the selector.
+ selector func(n, tX, tSel *ast.Ident)
+ // Called for array-typed fields; n is the field name, t the element type
+ // (nil when the element type is not an identifier) and size the array
+ // length (0 when the length is absent or not a literal).
+ array func(n, t *ast.Ident, size int)
+ // Called for fields of any other type; n is the field name.
+ unhandled func(n *ast.Ident)
+}
+
+// dispatch invokes the callback matching the type of field f, once per name
+// declared by f.
+//
+// Precondition: All dispatch callbacks that will be invoked must be
+// provided. Embedded fields are not allowed, len(f.Names) >= 1.
+func (fd fieldDispatcher) dispatch(f *ast.Field) {
+	// Each field declaration may actually be multiple declarations of the same
+	// type. For example, consider:
+	//
+	// type Point struct {
+	//     x, y, z int
+	// }
+	//
+	// We invoke the call-backs once per such instance. Embedded fields are not
+	// allowed, and result in a panic.
+	if len(f.Names) < 1 {
+		panic("Precondition not met: attempted to dispatch on embedded field")
+	}
+
+	for _, name := range f.Names {
+		switch v := f.Type.(type) {
+		case *ast.Ident:
+			fd.primitive(name, v)
+		case *ast.SelectorExpr:
+			fd.selector(name, v.X.(*ast.Ident), v.Sel)
+		case *ast.ArrayType:
+			// size stays 0 for slices (nil Len) and for non-literal lengths.
+			// Non-literal array length is handled by
+			// generatorInterfaces.validate().
+			size := 0
+			if lenLit, ok := v.Len.(*ast.BasicLit); ok {
+				// Base 0 makes ParseInt follow Go's integer literal syntax:
+				// 0x/0o/0b prefixes, a leading 0 meaning octal, and '_'
+				// digit separators. Atoi would reject or misread those.
+				n, err := strconv.ParseInt(lenLit.Value, 0, 0)
+				if err != nil {
+					panic(err)
+				}
+				size = int(n)
+			}
+			// elt is nil when the element type isn't a plain identifier.
+			elt, _ := v.Elt.(*ast.Ident)
+			fd.array(name, elt, size)
+		default:
+			fd.unhandled(name)
+		}
+	}
+}
+
+// debugEnabled reports the current value of the -debug flag.
+func debugEnabled() bool {
+	enabled := *debug
+	return enabled
+}
+
+// abort aborts the go_marshal tool with the given error message. The message
+// is written to standard error, newline-terminated if it isn't already, and
+// the process exits with status 1.
+func abort(msg string) {
+	if !strings.HasSuffix(msg, "\n") {
+		msg += "\n"
+	}
+	// Diagnostics belong on stderr so they are not mixed into stdout.
+	fmt.Fprint(os.Stderr, msg)
+	os.Exit(1)
+}
+
+// abortAt aborts the go_marshal tool like abort does, prefixing msg with the
+// position p in the input source it refers to.
+func abortAt(p token.Position, msg string) {
+	located := fmt.Sprintf("%v:\n %s\n", p, msg)
+	abort(located)
+}
+
+// debugf prints a Printf-style formatted message to standard output, but only
+// when debugging is enabled.
+func debugf(f string, a ...interface{}) {
+	if !debugEnabled() {
+		return
+	}
+	fmt.Printf(f, a...)
+}
+
+// debugfAt prints a Printf-style formatted message to standard output,
+// preceded by the input source position p it refers to, but only when
+// debugging is enabled.
+func debugfAt(p token.Position, f string, a ...interface{}) {
+	if !debugEnabled() {
+		return
+	}
+	msg := fmt.Sprintf(f, a...)
+	fmt.Printf("%s:\n %s", p, msg)
+}
+
+// emit writes one fragment of generated code to out, preceded by indent
+// levels of indentation.
+//
+// The variadic arguments select one of two modes:
+//
+// (1) emit(out, indent, "some string")
+// A lone string is copied to out verbatim; it is NOT treated as a format
+// string, so any '%' in it is written as-is.
+// (2) emit(out, indent, fmtString, args...)
+// With further arguments the first one is a format string, expanded like
+// the *Printf() functions do.
+//
+// emit panics when called without arguments, when the first argument is not a
+// string (the caller's intent would be ambiguous), or when writing to out
+// fails.
+func emit(out io.Writer, indent int, a ...interface{}) {
+	const spacesPerIndentLevel = 4
+
+	if len(a) == 0 {
+		panic("emit() called with no arguments")
+	}
+
+	if indent > 0 {
+		pad := strings.Repeat(" ", indent*spacesPerIndentLevel)
+		// Writing to the emit output should not fail. Typically the output is
+		// a bytes.Buffer; writes to these never fail.
+		if _, err := fmt.Fprint(out, pad); err != nil {
+			panic(err)
+		}
+	}
+
+	// The first argument is either the text to copy (mode 1) or the format
+	// string (mode 2); either way it has to be a string.
+	text, isString := a[0].(string)
+	if !isString {
+		panic(fmt.Sprintf("First argument to emit() is not a string: %+v", a[0]))
+	}
+
+	var err error
+	if args := a[1:]; len(args) == 0 {
+		// Mode 1: no formatting requested.
+		_, err = fmt.Fprint(out, text)
+	} else {
+		// Mode 2: formatting requested.
+		_, err = fmt.Fprintf(out, text, args...)
+	}
+	if err != nil {
+		// Writing to out should not fail.
+		panic(err)
+	}
+}
+
+// sourceBuffer accumulates fragments of generated go source code in memory,
+// keeping track of the current indentation level.
+//
+// The zero value is ready to use. Not thread-safe.
+type sourceBuffer struct {
+	// Current indentation level.
+	indent int
+
+	// Memory buffer containing contents while they're being generated.
+	b bytes.Buffer
+}
+
+// incIndent makes subsequent emit calls indent one level deeper.
+func (b *sourceBuffer) incIndent() {
+	b.indent += 1
+}
+
+// decIndent undoes one incIndent. Panics if there is none left to undo.
+func (b *sourceBuffer) decIndent() {
+	if b.indent < 1 {
+		panic("decIndent() without matching incIndent()")
+	}
+	b.indent -= 1
+}
+
+// emit appends a fragment at the current indentation level; see the
+// package-level emit for the meaning of the arguments.
+func (b *sourceBuffer) emit(a ...interface{}) {
+	level := b.indent
+	emit(&b.b, level, a...)
+}
+
+// emitNoIndent appends a fragment without any indentation, regardless of the
+// current level.
+func (b *sourceBuffer) emitNoIndent(a ...interface{}) {
+	const noIndent = 0
+	emit(&b.b, noIndent, a...)
+}
+
+// inIndent runs body with the indentation level raised by one, restoring it
+// once body returns.
+func (b *sourceBuffer) inIndent(body func()) {
+	b.incIndent()
+	body()
+	b.decIndent()
+}
+
+// write copies the accumulated contents to out; the buffer itself is left
+// untouched.
+func (b *sourceBuffer) write(out io.Writer) error {
+	contents := b.b.String()
+	if _, err := fmt.Fprint(out, contents); err != nil {
+		return err
+	}
+	return nil
+}
+
+// Write implements io.Writer.Write.
+func (b *sourceBuffer) Write(buf []byte) (int, error) {
+	n, err := b.b.Write(buf)
+	return n, err
+}
+
+// importStmt describes one import statement of the generated file.
+type importStmt struct {
+	// Local name of the imported package.
+	name string
+	// Import path.
+	path string
+	// Indicates whether the local name is an alias, or simply the final
+	// component of the path.
+	aliased bool
+	// Indicates whether this import was referenced by generated code.
+	used bool
+}
+
+// newImport creates an un-aliased, not yet used import of path p. Its local
+// name is the final component of p.
+func newImport(p string) *importStmt {
+	stmt := new(importStmt)
+	stmt.name = path.Base(p)
+	stmt.path = p
+	stmt.aliased = false
+	return stmt
+}
+
+// newImportFromSpec creates a not yet used import from a parsed import spec.
+// The local name is the spec's alias if it has one, the final component of the
+// import path otherwise. f is only used to report the position of spec when
+// no local name can be derived from the path, in which case newImportFromSpec
+// panics.
+func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
+	quoted := spec.Path.Value
+	p := quoted[1 : len(quoted)-1] // Strip the " quotes around path.
+
+	name := path.Base(p)
+	switch name {
+	case "", "/", ".":
+		panic(fmt.Sprintf("Couldn't process local package name for import at %s, (processed as %s)",
+			f.Position(spec.Path.Pos()), name))
+	}
+
+	hasAlias := spec.Name != nil
+	if hasAlias {
+		name = spec.Name.Name
+	}
+	return &importStmt{
+		name:    name,
+		path:    p,
+		aliased: hasAlias,
+	}
+}
+
+// String renders the statement the way it appears inside an import block: the
+// quoted path, preceded by the local name when that name is an alias.
+func (i *importStmt) String() string {
+	if !i.aliased {
+		return fmt.Sprintf("\"%s\"", i.path)
+	}
+	return fmt.Sprintf("%s \"%s\"", i.name, i.path)
+}
+
+// markUsed records that generated code references this import.
+func (i *importStmt) markUsed() {
+	i.used = true
+}
+
+// equivalent reports whether i and other describe the same import: same local
+// name, same import path and same aliasing. Whether either has been marked
+// used is not taken into account. Two nil statements are equivalent; a nil and
+// a non-nil statement are not.
+func (i *importStmt) equivalent(other *importStmt) bool {
+	if i == nil || other == nil {
+		return i == other
+	}
+	return i.name == other.name && i.path == other.path && i.aliased == other.aliased
+}
+
+// importTable represents a collection of importStmts.
+type importTable struct {
+	// Imports keyed by their local package name. Each entry records whether it
+	// should be copied to the output.
+	is map[string]*importStmt
+}
+
+// newImportTable creates an empty import table.
+func newImportTable() *importTable {
+	return &importTable{
+		is: make(map[string]*importStmt),
+	}
+}
+
+// Merges import statements from other into i. Collisions in import statements
+// result in a panic.
+//
+// A collision is two imports that share a local name without being
+// equivalent, since the generated file could not contain both. An import that
+// is merely present in both tables is not a collision; other's entry replaces
+// ours in that case.
+func (i *importTable) merge(other *importTable) {
+	for name, im := range other.is {
+		if dup, ok := i.is[name]; ok && !dup.equivalent(im) {
+			panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im))
+		}
+
+		i.is[name] = im
+	}
+}
+
+// add registers an un-aliased import of path s under its local name, replacing
+// any import previously registered under that name, and returns the new
+// statement.
+func (i *importTable) add(s string) *importStmt {
+	stmt := newImport(s)
+	i.is[stmt.name] = stmt
+	return stmt
+}
+
+// addFromSpec registers the import described by a parsed import spec under its
+// local name, replacing any import previously registered under that name, and
+// returns the new statement. f is passed through to newImportFromSpec.
+func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
+	stmt := newImportFromSpec(spec, f)
+	i.is[stmt.name] = stmt
+	return stmt
+}
+
+// markUsed marks the import with local name n as used. It returns false,
+// changing nothing, when the table holds no import of that name.
+func (i *importTable) markUsed(n string) bool {
+	stmt, found := i.is[n]
+	if !found {
+		return false
+	}
+	stmt.markUsed()
+	return true
+}
+
+// clear resets the used mark of every import in the table. The imports
+// themselves stay registered.
+func (i *importTable) clear() {
+	for name := range i.is {
+		i.is[name].used = false
+	}
+}
+
+// write emits an import block listing, in sorted order, every import of the
+// table that has been marked used. Nothing at all is written when there is no
+// such import, which avoids an empty "import ( )" group in the generated
+// file.
+func (i *importTable) write(out io.Writer) error {
+	imports := make([]string, 0, len(i.is))
+	for _, im := range i.is {
+		if im.used {
+			imports = append(imports, im.String())
+		}
+	}
+	if len(imports) == 0 {
+		// Nothing to import, we're done.
+		return nil
+	}
+	// Map iteration order is random; sort for a stable output.
+	sort.Strings(imports)
+
+	var b sourceBuffer
+	b.emit("import (\n")
+	b.inIndent(func() {
+		for _, line := range imports {
+			b.emit("%s\n", line)
+		}
+	})
+	b.emit(")\n\n")
+
+	return b.write(out)
+}
diff --git a/tools/go_marshal/main.go b/tools/go_marshal/main.go
new file mode 100644
index 000000000..3d12eb93c
--- /dev/null
+++ b/tools/go_marshal/main.go
@@ -0,0 +1,73 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// go_marshal is a code generation utility for automatically generating code to
+// marshal go data structures to memory.
+//
+// This binary is typically run as part of the build process, and is invoked by
+// the go_marshal bazel rule defined in defs.bzl.
+//
+// See README.md.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "os"
+ "strings"
+
+ "gvisor.dev/gvisor/tools/go_marshal/gomarshal"
+)
+
+// Command line flags. All of them are passed to gomarshal.NewGenerator by
+// main; only -pkg is checked for being non-empty.
+var (
+ pkg = flag.String("pkg", "", "output package")
+ output = flag.String("output", "", "output file")
+ outputTest = flag.String("output_test", "", "output file for tests")
+ imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code")
+ declarationPkg = flag.String("declarationPkg", "", "import path of target declaring the types we're generating on")
+)
+
+// main parses the command line, then builds and runs a gomarshal generator
+// over the input source files named by the positional arguments. It exits
+// with status 1 after printing the usage text when no input file or no -pkg
+// is given, and panics if the generator cannot be created or fails to run.
+func main() {
+	// exitWithUsage prints the usage text, followed by msg unless it is empty,
+	// and terminates the process with status 1.
+	exitWithUsage := func(msg string) {
+		flag.Usage()
+		if msg != "" {
+			fmt.Fprint(os.Stderr, msg)
+		}
+		os.Exit(1)
+	}
+
+	flag.Usage = func() {
+		fmt.Fprintf(os.Stderr, "Usage: %s <input go src files>\n", os.Args[0])
+		flag.PrintDefaults()
+	}
+	flag.Parse()
+
+	if flag.NArg() == 0 {
+		exitWithUsage("")
+	}
+	if *pkg == "" {
+		exitWithUsage("Flag -pkg must be provided.\n")
+	}
+
+	// Note: strings.Split(s, sep) returns s if sep doesn't exist in s. Thus an
+	// empty imports list is left as a nil slice, to avoid emitting an empty
+	// string as an import.
+	var extraImports []string
+	if *imports != "" {
+		extraImports = strings.Split(*imports, ",")
+	}
+
+	generator, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, *declarationPkg, extraImports)
+	if err != nil {
+		panic(err)
+	}
+
+	err = generator.Run()
+	if err != nil {
+		panic(err)
+	}
+}
diff --git a/tools/go_marshal/marshal/BUILD b/tools/go_marshal/marshal/BUILD
new file mode 100644
index 000000000..47dda97a1
--- /dev/null
+++ b/tools/go_marshal/marshal/BUILD
@@ -0,0 +1,14 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+# Go library built from marshal.go, importable as
+# gvisor.dev/gvisor/tools/go_marshal/marshal. Visibility is limited to the
+# //:sandbox package group.
+go_library(
+ name = "marshal",
+ srcs = [
+ "marshal.go",
+ ],
+ importpath = "gvisor.dev/gvisor/tools/go_marshal/marshal",
+ visibility = [
+ "//:sandbox",
+ ],
+)
diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go
new file mode 100644
index 000000000..a313a27ed
--- /dev/null
+++ b/tools/go_marshal/marshal/marshal.go
@@ -0,0 +1,60 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package marshal defines the Marshallable interface for
+// serializing/deserializing go data structures to/from memory, according to the
+// Linux ABI.
+//
+// Implementations of this interface are typically automatically generated by
+// tools/go_marshal. See the go_marshal README for details.
+package marshal
+
+// Marshallable represents a type that can be marshalled to and from memory.
+//
+// The *Bytes methods are the safe, field-by-field path. The *Unsafe methods
+// are a fast path that is only valid for types whose Packed method returns
+// true, and are expected to fall back to the *Bytes methods otherwise.
+type Marshallable interface {
+ // SizeBytes is the size, in bytes, of the memory representation of a type
+ // in marshalled form.
+ SizeBytes() int
+
+ // MarshalBytes serializes a copy of a type to dst. dst must be at least
+ // SizeBytes() long.
+ MarshalBytes(dst []byte)
+
+ // UnmarshalBytes deserializes a type from src. src must be at least
+ // SizeBytes() long.
+ UnmarshalBytes(src []byte)
+
+ // Packed returns true if the marshalled size of the type is the same as the
+ // size it occupies in memory. This happens when the type has no fields
+ // starting at unaligned addresses (should always be true by default for ABI
+ // structs, verified by automatically generated tests when using
+ // go_marshal), and has no fields marked `marshal:"unaligned"`.
+ Packed() bool
+
+ // MarshalUnsafe serializes a type by bulk copying its in-memory
+ // representation to the dst buffer. This is only safe to do when the type
+ // has no implicit padding, see Marshallable.Packed. When Packed would
+ // return false, MarshalUnsafe should fall back to the safer but slower
+ // MarshalBytes.
+ MarshalUnsafe(dst []byte)
+
+ // UnmarshalUnsafe deserializes a type directly to the underlying memory
+ // allocated for the object by the runtime.
+ //
+ // This allows much faster unmarshalling of types which have no implicit
+ // padding, see Marshallable.Packed. When Packed would return false,
+ // UnmarshalUnsafe should fall back to the safer but slower unmarshal
+ // mechanism implemented in UnmarshalBytes (usually by calling
+ // UnmarshalBytes directly).
+ UnmarshalUnsafe(src []byte)
+}
diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD
new file mode 100644
index 000000000..fa82f8e9b
--- /dev/null
+++ b/tools/go_marshal/test/BUILD
@@ -0,0 +1,31 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
+# NOTE(review): go_library is deliberately loaded from the go_marshal defs
+# rather than from rules_go; presumably that wrapper also runs the go_marshal
+# generator over srcs -- confirm in //tools/go_marshal:defs.bzl.
+load("//tools/go_marshal:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+package_group(
+    name = "gomarshal_test",
+    packages = [
+        "//tools/go_marshal/test/...",
+    ],
+)
+
+go_test(
+    name = "benchmark_test",
+    srcs = ["benchmark_test.go"],
+    deps = [
+        ":test",
+        "//pkg/binary",
+        "//pkg/sentry/usermem",
+        "//tools/go_marshal/analysis",
+    ],
+)
+
+go_library(
+    name = "test",
+    testonly = 1,
+    srcs = ["test.go"],
+    importpath = "gvisor.dev/gvisor/tools/go_marshal/test",
+    deps = ["//tools/go_marshal/test/external"],
+)
diff --git a/tools/go_marshal/test/benchmark_test.go b/tools/go_marshal/test/benchmark_test.go
new file mode 100644
index 000000000..e70db06d8
--- /dev/null
+++ b/tools/go_marshal/test/benchmark_test.go
@@ -0,0 +1,178 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package benchmark_test
+
+import (
+ "bytes"
+ encbin "encoding/binary"
+ "fmt"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/analysis"
+ test "gvisor.dev/gvisor/tools/go_marshal/test"
+)
+
+// Marshalling using the standard encoding/binary package.
+func BenchmarkEncodingBinary(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ size := encbin.Size(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := bytes.NewBuffer(make([]byte, size))
+ buf.Reset()
+ if err := encbin.Write(buf, usermem.ByteOrder, &s1); err != nil {
+ b.Error("Write:", err)
+ }
+ if err := encbin.Read(buf, usermem.ByteOrder, &s2); err != nil {
+ b.Error("Read:", err)
+ }
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling using the sentry's binary.Marshal.
+func BenchmarkBinary(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ size := binary.Size(s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, 0, size)
+ buf = binary.Marshal(buf, usermem.ByteOrder, &s1)
+ binary.Unmarshal(buf, usermem.ByteOrder, &s2)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling field-by-field with manually-written code.
+func BenchmarkMarshalManual(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, 0, s1.SizeBytes())
+
+ // Marshal
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Dev)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Ino)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Nlink)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.Mode)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.UID)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.GID)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, 0)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Rdev)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Size))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blksize))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blocks))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Sec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Nsec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Sec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Nsec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Sec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Nsec))
+
+ // Unmarshal
+ s2.Dev = usermem.ByteOrder.Uint64(buf[0:8])
+ s2.Ino = usermem.ByteOrder.Uint64(buf[8:16])
+ s2.Nlink = usermem.ByteOrder.Uint64(buf[16:24])
+ s2.Mode = usermem.ByteOrder.Uint32(buf[24:28])
+ s2.UID = usermem.ByteOrder.Uint32(buf[28:32])
+ s2.GID = usermem.ByteOrder.Uint32(buf[32:36])
+ // Padding: buf[36:40]
+ s2.Rdev = usermem.ByteOrder.Uint64(buf[40:48])
+ s2.Size = int64(usermem.ByteOrder.Uint64(buf[48:56]))
+ s2.Blksize = int64(usermem.ByteOrder.Uint64(buf[56:64]))
+ s2.Blocks = int64(usermem.ByteOrder.Uint64(buf[64:72]))
+ s2.ATime.Sec = int64(usermem.ByteOrder.Uint64(buf[72:80]))
+ s2.ATime.Nsec = int64(usermem.ByteOrder.Uint64(buf[80:88]))
+ s2.MTime.Sec = int64(usermem.ByteOrder.Uint64(buf[88:96]))
+ s2.MTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[96:104]))
+ s2.CTime.Sec = int64(usermem.ByteOrder.Uint64(buf[104:112]))
+ s2.CTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[112:120]))
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling with the go_marshal safe API.
+func BenchmarkGoMarshalSafe(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, s1.SizeBytes())
+ s1.MarshalBytes(buf)
+ s2.UnmarshalBytes(buf)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling with the go_marshal unsafe API.
+func BenchmarkGoMarshalUnsafe(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, s1.SizeBytes())
+ s1.MarshalUnsafe(buf)
+ s2.UnmarshalUnsafe(buf)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
diff --git a/tools/go_marshal/test/external/BUILD b/tools/go_marshal/test/external/BUILD
new file mode 100644
index 000000000..8fb43179b
--- /dev/null
+++ b/tools/go_marshal/test/external/BUILD
@@ -0,0 +1,11 @@
+package(licenses = ["notice"])
+
+load("//tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "external",
+ testonly = 1,
+ srcs = ["external.go"],
+ importpath = "gvisor.dev/gvisor/tools/go_marshal/test/external",
+ visibility = ["//tools/go_marshal/test:gomarshal_test"],
+)
diff --git a/runsc/test/root/root.go b/tools/go_marshal/test/external/external.go
index 349c752cc..4be3722f3 100644
--- a/runsc/test/root/root.go
+++ b/tools/go_marshal/test/external/external.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,5 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package root is empty. See chroot_test.go for description.
-package root
+// Package external defines types we can import for testing.
+package external
+
+// External is a public Marshallable type for use in testing.
+//
+// +marshal
+type External struct {
+ j int64
+}
diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go
new file mode 100644
index 000000000..8de02d707
--- /dev/null
+++ b/tools/go_marshal/test/test.go
@@ -0,0 +1,105 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package test contains data structures for testing the go_marshal tool.
+package test
+
+import (
+ // We're intentionally using a package name alias here, even though it's not
+ // necessary, to test the code generator's ability to handle package aliases.
+ ex "gvisor.dev/gvisor/tools/go_marshal/test/external"
+)
+
+// Type1 is a test data type.
+//
+// +marshal
+type Type1 struct {
+ a Type2
+ x, y int64 // Multiple field names.
+ b byte `marshal:"unaligned"` // Short field.
+ c uint64
+ _ uint32 // Unnamed scalar field.
+ _ [6]byte // Unnamed vector field, typical padding.
+ _ [2]byte
+ xs [8]int32
+ as [10]Type2 `marshal:"unaligned"` // Array of Marshallable objects.
+ ss Type3
+}
+
+// Type2 is a test data type.
+//
+// +marshal
+type Type2 struct {
+ n int64
+ c byte
+ _ [7]byte
+ m int64
+ a int64
+}
+
+// Type3 is a test data type.
+//
+// +marshal
+type Type3 struct {
+ s int64
+ x ex.External // Type defined in another package.
+}
+
+// Type4 is a test data type.
+//
+// +marshal
+type Type4 struct {
+ c byte
+ x int64 `marshal:"unaligned"`
+ d byte
+ _ [7]byte
+}
+
+// Type5 is a test data type.
+//
+// +marshal
+type Type5 struct {
+ n int64
+ t Type4
+ m int64
+}
+
+// Timespec represents struct timespec in <time.h>.
+//
+// +marshal
+type Timespec struct {
+ Sec int64
+ Nsec int64
+}
+
+// Stat represents struct stat.
+//
+// +marshal
+type Stat struct {
+ Dev uint64
+ Ino uint64
+ Nlink uint64
+ Mode uint32
+ UID uint32
+ GID uint32
+ _ int32
+ Rdev uint64
+ Size int64
+ Blksize int64
+ Blocks int64
+ ATime Timespec
+ MTime Timespec
+ CTime Timespec
+ _ [3]int64
+}
diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl
index aeba197e2..3ce36c1c8 100644
--- a/tools/go_stateify/defs.bzl
+++ b/tools/go_stateify/defs.bzl
@@ -35,7 +35,7 @@ go_library(
)
"""
-load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library", _go_test = "go_test")
+load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library")
def _go_stateify_impl(ctx):
"""Implementation for the stateify tool."""
@@ -60,28 +60,57 @@ def _go_stateify_impl(ctx):
executable = ctx.executable._tool,
)
-# Generates save and restore logic from a set of Go files.
-#
-# Args:
-# name: the name of the rule.
-# srcs: the input source files. These files should include all structs in the package that need to be saved.
-# imports: an optional list of extra non-aliased, Go-style absolute import paths.
-# out: the name of the generated file output. This must not conflict with any other files and must be added to the srcs of the relevant go_library.
-# package: the package name for the input sources.
go_stateify = rule(
implementation = _go_stateify_impl,
+ doc = "Generates save and restore logic from a set of Go files.",
attrs = {
- "srcs": attr.label_list(mandatory = True, allow_files = True),
- "imports": attr.string_list(mandatory = False),
- "package": attr.string(mandatory = True),
- "out": attr.output(mandatory = True),
- "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_stateify:stateify")),
+ "srcs": attr.label_list(
+ doc = """
+The input source files. These files should include all structs in the package
+that need to be saved.
+""",
+ mandatory = True,
+ allow_files = True,
+ ),
+ "imports": attr.string_list(
+ doc = """
+An optional list of extra non-aliased, Go-style absolute import paths required
+for stateified types.
+""",
+ mandatory = False,
+ ),
+ "package": attr.string(
+ doc = "The package name for the input sources.",
+ mandatory = True,
+ ),
+ "out": attr.output(
+ doc = """
+The name of the generated file output. This must not conflict with any other
+files and must be added to the srcs of the relevant go_library.
+""",
+ mandatory = True,
+ ),
+ "_tool": attr.label(
+ executable = True,
+ cfg = "host",
+ default = Label("//tools/go_stateify:stateify"),
+ ),
"_statepkg": attr.string(default = "gvisor.dev/gvisor/pkg/state"),
},
)
def go_library(name, srcs, deps = [], imports = [], **kwargs):
- """wraps the standard go_library and does stateification."""
+ """Standard go_library wrapper which generates state source files.
+
+ Args:
+ name: the name of the go_library rule.
+ srcs: sources of the go_library. Each will be processed for stateify
+ annotations.
+ deps: dependencies for the go_library.
+ imports: an optional list of extra non-aliased, Go-style absolute import
+ paths required for stateified types.
+ **kwargs: passed to go_library.
+ """
if "encode_unsafe.go" not in srcs and (name + "_state_autogen.go") not in srcs:
# Only do stateification for non-state packages without manual autogen.
go_stateify(
@@ -105,9 +134,3 @@ def go_library(name, srcs, deps = [], imports = [], **kwargs):
deps = all_deps,
**kwargs
)
-
-def go_test(**kwargs):
- """Wraps the standard go_test."""
- _go_test(
- **kwargs
- )
diff --git a/tools/image_build.sh b/tools/image_build.sh
new file mode 100755
index 000000000..9b20a740d
--- /dev/null
+++ b/tools/image_build.sh
@@ -0,0 +1,98 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script is responsible for building a new GCP image that: 1) has nested
+# virtualization enabled, and 2) has been completely set up with the
+# image_setup.sh script. This script should be idempotent, as we memoize the
+# setup script with a hash and check for that name.
+#
+# The GCP project name should be defined via a gcloud config.
+
+set -xeo pipefail
+
+# Parameters.
+declare -r ZONE=${ZONE:-us-central1-f}
+declare -r USERNAME=${USERNAME:-test}
+declare -r IMAGE_PROJECT=${IMAGE_PROJECT:-ubuntu-os-cloud}
+declare -r IMAGE_FAMILY=${IMAGE_FAMILY:-ubuntu-1604-lts}
+
+# Random names.
+declare -r DISK_NAME=$(mktemp -u disk-XXXXXX | tr A-Z a-z)
+declare -r SNAPSHOT_NAME=$(mktemp -u snapshot-XXXXXX | tr A-Z a-z)
+declare -r INSTANCE_NAME=$(mktemp -u build-XXXXXX | tr A-Z a-z)
+
+# Hashes inputs.
+declare -r SETUP_BLOB=$(echo ${ZONE} ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && sha256sum "$@")
+declare -r SETUP_HASH=$(echo ${SETUP_BLOB} | sha256sum - | cut -d' ' -f1 | cut -c 1-16)
+declare -r IMAGE_NAME=${IMAGE_NAME:-image-}${SETUP_HASH}
+
+# Does the image already exist? Skip the build.
+declare -r existing=$(gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)")
+if ! [[ -z "${existing}" ]]; then
+ echo "${existing}"
+ exit 0
+fi
+
+# Set the zone for all actions.
+gcloud config set compute/zone "${ZONE}"
+
+# Start a unique instance. Note that this instance will have a unique persistent
+# disk as its boot disk with the same name as the instance.
+gcloud compute instances create \
+ --quiet \
+ --image-project "${IMAGE_PROJECT}" \
+ --image-family "${IMAGE_FAMILY}" \
+ --boot-disk-size "200GB" \
+ "${INSTANCE_NAME}"
+function cleanup {
+ gcloud compute instances delete --quiet "${INSTANCE_NAME}"
+}
+trap cleanup EXIT
+
+# Wait for the instance to become available.
+declare attempts=0  # Number of failed connection attempts so far.
+while [[ "${attempts}" -lt 30 ]]; do
+  if gcloud compute ssh "${USERNAME}"@"${INSTANCE_NAME}" -- true; then
+    break  # Reachable; attempts stays below 30 so the check below passes.
+  fi
+  attempts=$((attempts+1))  # Count only failures; a 30th-try success is not an error.
+done
+if [[ "${attempts}" -ge 30 ]]; then
+  echo "too many attempts: failed"
+  exit 1
+fi
+
+# Run the install scripts provided.
+for arg; do
+ gcloud compute ssh "${USERNAME}"@"${INSTANCE_NAME}" -- sudo bash - <"${arg}"
+done
+
+# Stop the instance; required before creating an image.
+gcloud compute instances stop --quiet "${INSTANCE_NAME}"
+
+# Create a snapshot of the instance disk.
+gcloud compute disks snapshot \
+ --quiet \
+ --zone="${ZONE}" \
+ --snapshot-names="${SNAPSHOT_NAME}" \
+ "${INSTANCE_NAME}"
+
+# Create the disk image.
+gcloud compute images create \
+ --quiet \
+ --source-snapshot="${SNAPSHOT_NAME}" \
+ --licenses="https://www.googleapis.com/compute/v1/projects/vm-options/global/licenses/enable-vmx" \
+ "${IMAGE_NAME}"
diff --git a/tools/make_repository.sh b/tools/make_repository.sh
new file mode 100755
index 000000000..071f72b74
--- /dev/null
+++ b/tools/make_repository.sh
@@ -0,0 +1,79 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Parse arguments. We require more than three arguments: the private keyring,
+# the e-mail associated with the signer, the component, and the list of packages.
+if [ "$#" -le 3 ]; then
+ echo "usage: $0 <private-key> <signer-email> <component> <packages...>"
+ exit 1
+fi
+declare -r private_key=$(readlink -e "$1")
+declare -r signer="$2"
+declare -r component="$3"
+shift; shift; shift
+
+# Verbose from this point.
+set -xeo pipefail
+
+# Create a temporary working directory. We don't remove this, as we ultimately
+# print this result and allow the caller to copy wherever they would like.
+declare -r tmpdir=$(mktemp -d /tmp/repoXXXXXX)
+
+# Create a temporary keyring, and ensure it is cleaned up.
+declare -r keyring=$(mktemp /tmp/keyringXXXXXX.gpg)
+cleanup() {
+ rm -f "${keyring}"
+}
+trap cleanup EXIT
+gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}" >&2
+
+# Copy the packages, and ensure permissions are correct.
+for pkg in "$@"; do
+ name=$(basename "${pkg}" .deb)
+ name=$(basename "${name}" .changes)
+ arch=${name##*_}
+ if [[ "${name}" == "${arch}" ]]; then
+ continue # Not a regular package.
+ fi
+ mkdir -p "${tmpdir}"/"${component}"/binary-"${arch}"
+ cp -a "${pkg}" "${tmpdir}"/"${component}"/binary-"${arch}"
+done
+find "${tmpdir}" -type f -exec chmod 0644 {} \;
+
+# Ensure there are no symlinks hanging around; these may be remnants of the
+# build process. They may be useful for other things, but we are going to build
+# an index of the actual packages here.
+find "${tmpdir}" -type l -exec rm -f {} \;
+
+# Sign all packages.
+for file in "${tmpdir}"/"${component}"/binary-*/*.deb; do
+ dpkg-sig -g "--no-default-keyring --keyring ${keyring}" --sign builder "${file}" >&2
+done
+
+# Build the package list.
+for dir in "${tmpdir}"/"${component}"/binary-*; do
+ (cd "${dir}" && apt-ftparchive packages . | gzip > Packages.gz)
+done
+
+# Build the release list.
+(cd "${tmpdir}" && apt-ftparchive release . > Release)
+
+# Sign the release.
+(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" --clearsign -o InRelease Release >&2)
+(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" -abs -o Release.gpg Release >&2)
+
+# Show the results.
+echo "${tmpdir}"
diff --git a/tools/run_build.sh b/tools/run_build.sh
deleted file mode 100755
index 7f6ada480..000000000
--- a/tools/run_build.sh
+++ /dev/null
@@ -1,49 +0,0 @@
-#!/bin/bash
-
-# Copyright 2018 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Fail on any error.
-set -e
-# Display commands to stderr.
-set -x
-
-# Install the latest version of Bazel and log the version.
-(which use_bazel.sh && use_bazel.sh latest) || which bazel
-bazel version
-
-# Switch into the workspace.
-if [[ -v KOKORO_GIT_COMMIT ]] && [[ -d git/repo ]]; then
- cd git/repo
-elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then
- cd github/repo
-fi
-
-# Build runsc.
-bazel build -c opt --strip=never //runsc
-
-# Move the runsc binary into "latest" directory, and also a directory with the
-# current date.
-if [[ -v KOKORO_ARTIFACTS_DIR ]]; then
- latest_dir="${KOKORO_ARTIFACTS_DIR}"/latest
- today_dir="${KOKORO_ARTIFACTS_DIR}"/"$(date -Idate)"
- runsc="bazel-bin/runsc/linux_amd64_pure/runsc"
-
- mkdir -p "${latest_dir}" "${today_dir}"
- cp "${runsc}" "${latest_dir}"
- cp "${runsc}" "${today_dir}"
-
- sha512sum "${latest_dir}"/runsc | awk '{print $1 " runsc"}' > "${latest_dir}"/runsc.sha512
- cp "${latest_dir}"/runsc.sha512 "${today_dir}"/runsc.sha512
-fi
diff --git a/tools/run_tests.sh b/tools/run_tests.sh
deleted file mode 100755
index 3e7c4394d..000000000
--- a/tools/run_tests.sh
+++ /dev/null
@@ -1,302 +0,0 @@
-#!/bin/bash
-
-# Copyright 2018 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Fail on any error. Treat unset variables as error. Print commands as executed.
-set -eux
-
-###################
-# GLOBAL ENV VARS #
-###################
-
-if [[ -v KOKORO_GIT_COMMIT ]] && [[ -d git/repo ]]; then
- readonly WORKSPACE_DIR="${PWD}/git/repo"
-elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then
- readonly WORKSPACE_DIR="${PWD}/github/repo"
-else
- readonly WORKSPACE_DIR="${PWD}"
-fi
-
-# Used to configure RBE.
-readonly CLOUD_PROJECT_ID="gvisor-rbe"
-readonly RBE_PROJECT_ID="projects/${CLOUD_PROJECT_ID}/instances/default_instance"
-
-# Random runtime name to avoid collisions.
-readonly RUNTIME="runsc_test_$((RANDOM))"
-
-# Packages that will be built and tested.
-readonly BUILD_PACKAGES=("//...")
-readonly TEST_PACKAGES=("//pkg/..." "//runsc/..." "//tools/...")
-
-#######################
-# BAZEL CONFIGURATION #
-#######################
-
-# Install the latest version of Bazel and log the version.
-(which use_bazel.sh && use_bazel.sh 0.28.0) || which bazel
-bazel version
-
-# Load the kvm module.
-sudo -n -E modprobe kvm
-
-# General Bazel build/test flags.
-BAZEL_BUILD_FLAGS=(
- "--show_timestamps"
- "--test_output=errors"
- "--keep_going"
- "--verbose_failures=true"
-)
-
-# Bazel build/test for RBE, a super-set of BAZEL_BUILD_FLAGS.
-BAZEL_BUILD_RBE_FLAGS=(
- "${BAZEL_BUILD_FLAGS[@]}"
- "--config=remote"
- "--project_id=${CLOUD_PROJECT_ID}"
- "--remote_instance_name=${RBE_PROJECT_ID}"
-)
-if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]]; then
- BAZEL_BUILD_RBE_FLAGS=(
- "${BAZEL_BUILD_RBE_FLAGS[@]}"
- "--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}"
- )
-fi
-
-####################
-# Helper Functions #
-####################
-
-sanity_checks() {
- cd ${WORKSPACE_DIR}
- bazel run //:gazelle -- update-repos -from_file=go.mod
- git diff --exit-code WORKSPACE
-}
-
-build_everything() {
- FLAVOR="${1}"
-
- cd ${WORKSPACE_DIR}
- bazel build \
- -c "${FLAVOR}" "${BAZEL_BUILD_RBE_FLAGS[@]}" \
- "${BUILD_PACKAGES[@]}"
-}
-
-build_runsc_debian() {
- cd ${WORKSPACE_DIR}
-
- # TODO(b/135475885): pkg_deb is incompatible with Python3.
- # https://github.com/bazelbuild/bazel/issues/8443
- bazel build --host_force_python=py2 runsc:runsc-debian
-}
-
-# Run simple tests runs the tests that require no special setup or
-# configuration.
-run_simple_tests() {
- cd ${WORKSPACE_DIR}
- bazel test \
- "${BAZEL_BUILD_FLAGS[@]}" \
- "${TEST_PACKAGES[@]}"
-}
-
-install_runtime() {
- cd ${WORKSPACE_DIR}
- sudo -n ${WORKSPACE_DIR}/runsc/test/install.sh --runtime ${RUNTIME}
-}
-
-install_helper() {
- PACKAGE="${1}"
- TAG="${2}"
- GOPATH="${3}"
-
- # Clone the repository.
- mkdir -p "${GOPATH}"/src/$(dirname "${PACKAGE}") && \
- git clone https://"${PACKAGE}" "${GOPATH}"/src/"${PACKAGE}"
-
- # Checkout and build the repository.
- (cd "${GOPATH}"/src/"${PACKAGE}" && \
- git checkout "${TAG}" && \
- GOPATH="${GOPATH}" make && \
- sudo -n -E env GOPATH="${GOPATH}" make install)
-}
-
-# Install dependencies for the crictl tests.
-install_crictl_test_deps() {
- sudo -n -E apt-get update
- sudo -n -E apt-get install -y btrfs-tools libseccomp-dev
-
- # Install containerd & cri-tools.
- GOPATH=$(mktemp -d --tmpdir gopathXXXXX)
- install_helper github.com/containerd/containerd v1.2.2 "${GOPATH}"
- install_helper github.com/kubernetes-sigs/cri-tools v1.11.0 "${GOPATH}"
-
- # Install gvisor-containerd-shim.
- local latest=/tmp/gvisor-containerd-shim-latest
- local shim_path=/tmp/gvisor-containerd-shim
- wget --no-verbose https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim/latest -O ${latest}
- wget --no-verbose https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path}
- chmod +x ${shim_path}
- sudo -n -E mv ${shim_path} /usr/local/bin
-
- # Configure containerd-shim.
- local shim_config_path=/etc/containerd
- local shim_config_tmp_path=/tmp/gvisor-containerd-shim.toml
- sudo -n -E mkdir -p ${shim_config_path}
- cat > ${shim_config_tmp_path} <<-EOF
- runc_shim = "/usr/local/bin/containerd-shim"
-
- [runsc_config]
- debug = "true"
- debug-log = "/tmp/runsc-logs/"
- strace = "true"
- file-access = "shared"
-EOF
- sudo mv ${shim_config_tmp_path} ${shim_config_path}
-
- # Configure CNI.
- (cd "${GOPATH}" && sudo -n -E env PATH="${PATH}" GOPATH="${GOPATH}" \
- src/github.com/containerd/containerd/script/setup/install-cni)
-}
-
-# Run the tests that require docker.
-run_docker_tests() {
- cd ${WORKSPACE_DIR}
-
- # Run tests with a default runtime (runc).
- bazel test \
- "${BAZEL_BUILD_FLAGS[@]}" \
- --test_env=RUNSC_RUNTIME="" \
- --test_output=all \
- //runsc/test/image:image_test
-
- # These names are used to exclude tests not supported in certain
- # configuration, e.g. save/restore not supported with hostnet.
- # Run runsc tests with docker that are tagged manual.
- #
- # The --nocache_test_results option is used here to eliminate cached results
- # from the previous run for the runc runtime.
- bazel test \
- "${BAZEL_BUILD_FLAGS[@]}" \
- --test_env=RUNSC_RUNTIME="${RUNTIME}" \
- --test_output=all \
- --nocache_test_results \
- --test_output=streamed \
- //runsc/test/integration:integration_test \
- //runsc/test/integration:integration_test_hostnet \
- //runsc/test/integration:integration_test_overlay \
- //runsc/test/integration:integration_test_kvm \
- //runsc/test/image:image_test \
- //runsc/test/image:image_test_overlay \
- //runsc/test/image:image_test_hostnet \
- //runsc/test/image:image_test_kvm
-}
-
-# Run the tests that require root.
-run_root_tests() {
- cd ${WORKSPACE_DIR}
- bazel build //runsc/test/root:root_test
- local root_test=$(find -L ./bazel-bin/ -executable -type f -name root_test | grep __main__)
- if [[ ! -f "${root_test}" ]]; then
- echo "root_test executable not found"
- exit 1
- fi
- sudo -n -E RUNSC_RUNTIME="${RUNTIME}" RUNSC_EXEC=/tmp/"${RUNTIME}"/runsc ${root_test}
-}
-
-# Run syscall unit tests.
-run_syscall_tests() {
- cd ${WORKSPACE_DIR}
- bazel test "${BAZEL_BUILD_RBE_FLAGS[@]}" \
- --test_tag_filters=runsc_ptrace //test/syscalls/...
-}
-
-run_runsc_do_tests() {
- local runsc=$(find bazel-bin/runsc -type f -executable -name "runsc" | head -n1)
-
- # run runsc do without root privileges.
- ${runsc} --rootless do true
- ${runsc} --rootless --network=none do true
-
- # run runsc do with root privileges.
- sudo -n -E ${runsc} do true
-}
-
-# Find and rename all test xml and log files so that Sponge can pick them up.
-# XML files must be named sponge_log.xml, and log files must be named
-# sponge_log.log. We move all such files into KOKORO_ARTIFACTS_DIR, in a
-# subdirectory named with the test name.
-upload_test_artifacts() {
- # Skip if no kokoro directory.
- [[ -v KOKORO_ARTIFACTS_DIR ]] || return
-
- cd ${WORKSPACE_DIR}
- find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" |
- tar --create --files-from - --transform 's/test\./sponge_log./' |
- tar --extract --directory ${KOKORO_ARTIFACTS_DIR}
- if [[ -d "/tmp/${RUNTIME}/logs" ]]; then
- tar --create --gzip "--file=${KOKORO_ARTIFACTS_DIR}/runsc-logs.tar.gz" -C /tmp/ ${RUNTIME}/logs
- fi
-}
-
-# Finish runs at exit, even in the event of an error, and uploads all test
-# artifacts.
-finish() {
- # Grab the last exit code, we will return it.
- local exit_code=${?}
- upload_test_artifacts
- exit ${exit_code}
-}
-
-# Run bazel in a docker container
-build_in_docker() {
- cd ${WORKSPACE_DIR}
- bazel clean
- bazel shutdown
- make
- make runsc
- make bazel-shutdown
-}
-
-########
-# MAIN #
-########
-
-main() {
- # Register finish to run at exit.
- trap finish EXIT
-
- # Build and run the simple tests.
- sanity_checks
- build_everything opt
- run_simple_tests
-
- # So far so good. Install more deps and run the integration tests.
- install_runtime
- install_crictl_test_deps
- run_docker_tests
- run_root_tests
-
- run_syscall_tests
- run_runsc_do_tests
-
- build_runsc_debian
-
- # Build other flavors too.
- build_everything dbg
-
- build_in_docker
- # No need to call "finish" here, it will happen at exit.
-}
-
-# Kick it off.
-main
diff --git a/tools/workspace_status.sh b/tools/workspace_status.sh
index 64a905fc9..fb09ff331 100755
--- a/tools/workspace_status.sh
+++ b/tools/workspace_status.sh
@@ -14,4 +14,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-echo VERSION $(git describe --always --tags --abbrev=12 --dirty)
+# The STABLE_ prefix will trigger a re-link if it changes.
+echo STABLE_VERSION $(git describe --always --tags --abbrev=12 --dirty)