summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.travis.yml4
-rw-r--r--BUILD23
-rw-r--r--CONTRIBUTING.md3
-rw-r--r--Dockerfile2
-rw-r--r--WORKSPACE131
-rw-r--r--benchmarks/harness/machine_producers/gcloud_producer.py8
-rw-r--r--benchmarks/workloads/ruby/Gemfile.lock22
-rw-r--r--benchmarks/workloads/ruby_template/Gemfile.lock6
-rw-r--r--go.mod33
-rw-r--r--go.sum33
-rw-r--r--kokoro/benchmark_tests.cfg8
-rw-r--r--kokoro/iptables_tests.cfg2
-rw-r--r--kokoro/packetimpact_tests.cfg9
-rw-r--r--pkg/abi/linux/elf.go3
-rw-r--r--pkg/abi/linux/file.go8
-rw-r--r--pkg/abi/linux/netfilter.go68
-rw-r--r--pkg/bits/bits_template.go8
-rw-r--r--pkg/bits/uint64_test.go18
-rw-r--r--pkg/buffer/BUILD43
-rw-r--r--pkg/buffer/buffer.go94
-rw-r--r--pkg/buffer/safemem.go129
-rw-r--r--pkg/buffer/safemem_test.go170
-rw-r--r--pkg/buffer/view.go390
-rw-r--r--pkg/buffer/view_test.go467
-rw-r--r--pkg/buffer/view_unsafe.go25
-rw-r--r--pkg/gate/gate_test.go3
-rw-r--r--pkg/ilist/list.go41
-rw-r--r--pkg/rand/rand_linux.go17
-rw-r--r--pkg/safemem/seq_test.go21
-rw-r--r--pkg/safemem/seq_unsafe.go3
-rw-r--r--pkg/seccomp/seccomp_test.go2
-rw-r--r--pkg/sentry/arch/arch_aarch64.go39
-rw-r--r--pkg/sentry/arch/arch_arm64.go30
-rw-r--r--pkg/sentry/arch/signal_arm64.go21
-rw-r--r--pkg/sentry/arch/syscalls_arm64.go10
-rw-r--r--pkg/sentry/control/pprof.go34
-rw-r--r--pkg/sentry/fs/dev/BUILD1
-rw-r--r--pkg/sentry/fs/dev/dev.go7
-rw-r--r--pkg/sentry/fs/dev/net_tun.go7
-rw-r--r--pkg/sentry/fs/dirent.go133
-rw-r--r--pkg/sentry/fs/dirent_cache.go2
-rw-r--r--pkg/sentry/fs/file_overlay_test.go84
-rw-r--r--pkg/sentry/fs/fsutil/inode.go4
-rw-r--r--pkg/sentry/fs/host/BUILD4
-rw-r--r--pkg/sentry/fs/host/control.go2
-rw-r--r--pkg/sentry/fs/host/descriptor.go37
-rw-r--r--pkg/sentry/fs/host/descriptor_state.go2
-rw-r--r--pkg/sentry/fs/host/descriptor_test.go4
-rw-r--r--pkg/sentry/fs/host/file.go4
-rw-r--r--pkg/sentry/fs/host/fs.go339
-rw-r--r--pkg/sentry/fs/host/fs_test.go380
-rw-r--r--pkg/sentry/fs/host/host.go59
-rw-r--r--pkg/sentry/fs/host/inode.go141
-rw-r--r--pkg/sentry/fs/host/inode_state.go32
-rw-r--r--pkg/sentry/fs/host/inode_test.go66
-rw-r--r--pkg/sentry/fs/host/ioctl_unsafe.go4
-rw-r--r--pkg/sentry/fs/host/tty.go5
-rw-r--r--pkg/sentry/fs/host/util.go88
-rw-r--r--pkg/sentry/fs/host/util_unsafe.go41
-rw-r--r--pkg/sentry/fs/inotify.go5
-rw-r--r--pkg/sentry/fs/mounts.go13
-rw-r--r--pkg/sentry/fs/proc/net.go53
-rw-r--r--pkg/sentry/fs/proc/proc.go2
-rw-r--r--pkg/sentry/fs/proc/task.go130
-rw-r--r--pkg/sentry/fsbridge/vfs.go2
-rw-r--r--pkg/sentry/fsimpl/ext/filesystem.go10
-rw-r--r--pkg/sentry/fsimpl/ext/inode.go2
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go49
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go52
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go14
-rw-r--r--pkg/sentry/fsimpl/gofer/special_file.go8
-rw-r--r--pkg/sentry/fsimpl/host/BUILD34
-rw-r--r--pkg/sentry/fsimpl/host/host.go616
-rw-r--r--pkg/sentry/fsimpl/host/ioctl_unsafe.go56
-rw-r--r--pkg/sentry/fsimpl/host/tty.go379
-rw-r--r--pkg/sentry/fsimpl/host/util.go66
-rw-r--r--pkg/sentry/fsimpl/host/util_unsafe.go (renamed from pkg/sentry/kernel/pipe/buffer_test.go)26
-rw-r--r--pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go9
-rw-r--r--pkg/sentry/fsimpl/kernfs/fd_impl_util.go23
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go22
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go87
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go15
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs_test.go2
-rw-r--r--pkg/sentry/fsimpl/kernfs/symlink.go7
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD6
-rw-r--r--pkg/sentry/fsimpl/proc/subtasks.go25
-rw-r--r--pkg/sentry/fsimpl/proc/task.go79
-rw-r--r--pkg/sentry/fsimpl/proc/task_fds.go287
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go296
-rw-r--r--pkg/sentry/fsimpl/proc/task_net.go (renamed from pkg/sentry/fsimpl/proc/tasks_net.go)15
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go63
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_files.go85
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_test.go122
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go8
-rw-r--r--pkg/sentry/fsimpl/testutil/BUILD2
-rw-r--r--pkg/sentry/fsimpl/testutil/kernel.go27
-rw-r--r--pkg/sentry/fsimpl/tmpfs/device_file.go10
-rw-r--r--pkg/sentry/fsimpl/tmpfs/directory.go7
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go36
-rw-r--r--pkg/sentry/fsimpl/tmpfs/named_pipe.go2
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go13
-rw-r--r--pkg/sentry/fsimpl/tmpfs/symlink.go3
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go43
-rw-r--r--pkg/sentry/inet/namespace.go5
-rw-r--r--pkg/sentry/kernel/epoll/epoll.go2
-rw-r--r--pkg/sentry/kernel/epoll/epoll_state.go13
-rw-r--r--pkg/sentry/kernel/fd_table.go24
-rw-r--r--pkg/sentry/kernel/kernel.go44
-rw-r--r--pkg/sentry/kernel/pipe/BUILD18
-rw-r--r--pkg/sentry/kernel/pipe/buffer.go115
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go118
-rw-r--r--pkg/sentry/kernel/pipe/pipe_util.go25
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore.go1
-rw-r--r--pkg/sentry/kernel/sessions.go2
-rw-r--r--pkg/sentry/kernel/task.go16
-rw-r--r--pkg/sentry/kernel/task_clone.go3
-rw-r--r--pkg/sentry/kernel/task_run.go41
-rw-r--r--pkg/sentry/kernel/thread_group.go7
-rw-r--r--pkg/sentry/platform/kvm/bluepill.go6
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go6
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.s8
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go25
-rw-r--r--pkg/sentry/platform/kvm/kvm.go32
-rw-r--r--pkg/sentry/platform/kvm/kvm_amd64.go32
-rw-r--r--pkg/sentry/platform/kvm/kvm_arm64.go32
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go24
-rw-r--r--pkg/sentry/platform/ptrace/BUILD1
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_amd64.go14
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go62
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go12
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_amd64.go80
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_arm64.go11
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux.go65
-rw-r--r--pkg/sentry/platform/ring0/aarch64.go27
-rw-r--r--pkg/sentry/platform/ring0/pagetables/BUILD2
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pcids.go (renamed from pkg/sentry/platform/ring0/pagetables/pcids_x86.go)2
-rw-r--r--pkg/sentry/sighandling/sighandling.go5
-rw-r--r--pkg/sentry/socket/netfilter/BUILD1
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go14
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go209
-rw-r--r--pkg/sentry/socket/netfilter/targets.go11
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go11
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go13
-rw-r--r--pkg/sentry/socket/netstack/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go70
-rw-r--r--pkg/sentry/socket/netstack/stack.go76
-rw-r--r--pkg/sentry/syscalls/linux/BUILD2
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go16
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat.go23
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat_amd64.go45
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat_arm64.go45
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD3
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/filesystem.go2
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/getdents.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/setstat.go15
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat.go95
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat_amd64.go46
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat_arm64.go46
-rw-r--r--pkg/sentry/vfs/BUILD5
-rw-r--r--pkg/sentry/vfs/anonfs.go25
-rw-r--r--pkg/sentry/vfs/file_description.go3
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go9
-rw-r--r--pkg/sentry/vfs/filesystem.go8
-rw-r--r--pkg/sentry/vfs/mount.go42
-rw-r--r--pkg/sentry/vfs/options.go16
-rw-r--r--pkg/sentry/vfs/permissions.go54
-rw-r--r--pkg/sentry/vfs/resolving_path.go16
-rw-r--r--pkg/sentry/vfs/vfs.go24
-rw-r--r--pkg/sentry/watchdog/watchdog.go6
-rw-r--r--pkg/sync/aliases.go5
-rw-r--r--pkg/tcpip/BUILD2
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD1
-rw-r--r--pkg/tcpip/checker/checker.go106
-rw-r--r--pkg/tcpip/header/ipv4.go5
-rw-r--r--pkg/tcpip/header/tcp.go42
-rw-r--r--pkg/tcpip/iptables/BUILD18
-rw-r--r--pkg/tcpip/iptables/targets.go65
-rw-r--r--pkg/tcpip/link/channel/channel.go14
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go20
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go86
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_unsafe.go9
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go3
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go4
-rw-r--r--pkg/tcpip/link/loopback/loopback.go8
-rw-r--r--pkg/tcpip/link/muxed/injectable.go6
-rw-r--r--pkg/tcpip/link/muxed/injectable_test.go4
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go6
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go26
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go10
-rw-r--r--pkg/tcpip/link/tun/device.go12
-rw-r--r--pkg/tcpip/link/waitable/waitable.go6
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go18
-rw-r--r--pkg/tcpip/network/arp/arp.go12
-rw-r--r--pkg/tcpip/network/arp/arp_test.go2
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go8
-rw-r--r--pkg/tcpip/network/ip_test.go24
-rw-r--r--pkg/tcpip/network/ipv4/BUILD1
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go9
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go21
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go18
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go10
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go14
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go10
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go4
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go8
-rw-r--r--pkg/tcpip/stack/BUILD12
-rw-r--r--pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go39
-rw-r--r--pkg/tcpip/stack/forwarder.go131
-rw-r--r--pkg/tcpip/stack/forwarder_test.go635
-rw-r--r--pkg/tcpip/stack/iptables.go (renamed from pkg/tcpip/iptables/iptables.go)66
-rw-r--r--pkg/tcpip/stack/iptables_targets.go144
-rw-r--r--pkg/tcpip/stack/iptables_types.go (renamed from pkg/tcpip/iptables/types.go)18
-rw-r--r--pkg/tcpip/stack/ndp.go329
-rw-r--r--pkg/tcpip/stack/ndp_test.go207
-rw-r--r--pkg/tcpip/stack/nic.go185
-rw-r--r--pkg/tcpip/stack/nic_test.go3
-rw-r--r--pkg/tcpip/stack/packet_buffer.go (renamed from pkg/tcpip/packet_buffer.go)6
-rw-r--r--pkg/tcpip/stack/packet_buffer_state.go (renamed from pkg/tcpip/packet_buffer_state.go)2
-rw-r--r--pkg/tcpip/stack/rand.go40
-rw-r--r--pkg/tcpip/stack/registration.go35
-rw-r--r--pkg/tcpip/stack/route.go6
-rw-r--r--pkg/tcpip/stack/stack.go72
-rw-r--r--pkg/tcpip/stack/stack_test.go546
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go281
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go16
-rw-r--r--pkg/tcpip/stack/transport_test.go27
-rw-r--r--pkg/tcpip/transport/icmp/BUILD1
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go34
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go2
-rw-r--r--pkg/tcpip/transport/packet/BUILD1
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go30
-rw-r--r--pkg/tcpip/transport/raw/BUILD1
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go9
-rw-r--r--pkg/tcpip/transport/tcp/BUILD3
-rw-r--r--pkg/tcpip/transport/tcp/accept.go155
-rw-r--r--pkg/tcpip/transport/tcp/connect.go225
-rw-r--r--pkg/tcpip/transport/tcp/connect_unsafe.go30
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go11
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go623
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go8
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go52
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go10
-rw-r--r--pkg/tcpip/transport/tcp/segment.go9
-rw-r--r--pkg/tcpip/transport/tcp/segment_heap.go17
-rw-r--r--pkg/tcpip/transport/tcp/segment_queue.go8
-rw-r--r--pkg/tcpip/transport/tcp/snd.go11
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go42
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go11
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go39
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go3
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go4
-rw-r--r--pkg/tcpip/transport/udp/protocol.go6
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go4
-rw-r--r--runsc/boot/config.go4
-rw-r--r--runsc/boot/controller.go13
-rw-r--r--runsc/boot/filter/config.go9
-rw-r--r--runsc/cmd/boot.go75
-rw-r--r--runsc/cmd/chroot.go2
-rw-r--r--runsc/cmd/debug.go64
-rw-r--r--runsc/cmd/gofer.go2
-rw-r--r--runsc/container/console_test.go4
-rw-r--r--runsc/container/container.go39
-rw-r--r--runsc/container/container_test.go19
-rw-r--r--runsc/main.go37
-rw-r--r--runsc/sandbox/sandbox.go96
-rw-r--r--runsc/specutils/namespace.go3
-rw-r--r--runsc/testutil/testutil.go23
-rw-r--r--scripts/benchmark.sh20
-rwxr-xr-xscripts/build.sh2
-rwxr-xr-xscripts/common.sh14
-rwxr-xr-xscripts/dev.sh1
-rwxr-xr-xscripts/iptables_tests.sh11
-rwxr-xr-xscripts/packetimpact_tests.sh20
-rwxr-xr-xscripts/release.sh13
-rw-r--r--test/e2e/exec_test.go22
-rw-r--r--test/e2e/integration_test.go6
-rw-r--r--test/iptables/filter_input.go92
-rw-r--r--test/iptables/filter_output.go66
-rw-r--r--test/iptables/iptables_test.go104
-rw-r--r--test/iptables/iptables_util.go38
-rw-r--r--test/iptables/nat.go253
-rw-r--r--test/packetdrill/BUILD44
-rw-r--r--test/packetdrill/Dockerfile8
-rw-r--r--test/packetdrill/accept_ack_drop.pkt27
-rw-r--r--test/packetdrill/defs.bzl10
-rw-r--r--test/packetdrill/fin_wait2_timeout.pkt2
-rw-r--r--test/packetdrill/linux/tcp_user_timeout.pkt39
-rw-r--r--test/packetdrill/listen_close_before_handshake_complete.pkt31
-rw-r--r--test/packetdrill/netstack/tcp_user_timeout.pkt38
-rw-r--r--test/packetdrill/no_rst_to_rst.pkt36
-rwxr-xr-xtest/packetdrill/packetdrill_test.sh4
-rw-r--r--test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt9
-rw-r--r--test/packetdrill/sanity_test.pkt7
-rw-r--r--test/packetdrill/tcp_defer_accept.pkt48
-rw-r--r--test/packetdrill/tcp_defer_accept_timeout.pkt48
-rw-r--r--test/packetimpact/README.md531
-rw-r--r--test/packetimpact/dut/BUILD18
-rw-r--r--test/packetimpact/dut/posix_server.cc229
-rw-r--r--test/packetimpact/proto/BUILD12
-rw-r--r--test/packetimpact/proto/posix_server.proto150
-rw-r--r--test/packetimpact/testbench/BUILD31
-rw-r--r--test/packetimpact/testbench/connections.go245
-rw-r--r--test/packetimpact/testbench/dut.go363
-rw-r--r--test/packetimpact/testbench/dut_client.go28
-rw-r--r--test/packetimpact/testbench/layers.go507
-rw-r--r--test/packetimpact/testbench/rawsockets.go151
-rw-r--r--test/packetimpact/tests/BUILD21
-rw-r--r--test/packetimpact/tests/Dockerfile5
-rw-r--r--test/packetimpact/tests/defs.bzl106
-rw-r--r--test/packetimpact/tests/fin_wait2_timeout_test.go68
-rwxr-xr-xtest/packetimpact/tests/test_runner.sh246
-rw-r--r--test/perf/BUILD2
-rw-r--r--test/perf/linux/futex_benchmark.cc144
-rw-r--r--test/perf/linux/signal_benchmark.cc10
-rw-r--r--test/root/BUILD3
-rw-r--r--test/root/runsc_test.go151
-rw-r--r--test/runner/gtest/gtest.go7
-rw-r--r--test/runner/runner.go27
-rw-r--r--test/syscalls/BUILD13
-rw-r--r--test/syscalls/linux/BUILD52
-rw-r--r--test/syscalls/linux/bad.cc2
-rw-r--r--test/syscalls/linux/dev.cc7
-rw-r--r--test/syscalls/linux/mkdir.cc2
-rw-r--r--test/syscalls/linux/network_namespace.cc87
-rw-r--r--test/syscalls/linux/open_create.cc2
-rw-r--r--test/syscalls/linux/packet_socket.cc138
-rw-r--r--test/syscalls/linux/proc.cc21
-rw-r--r--test/syscalls/linux/proc_net.cc78
-rw-r--r--test/syscalls/linux/proc_pid_oomscore.cc72
-rw-r--r--test/syscalls/linux/seccomp.cc42
-rw-r--r--test/syscalls/linux/socket_generic.cc96
-rw-r--r--test/syscalls/linux/stat.cc69
-rw-r--r--test/syscalls/linux/sticky.cc16
-rw-r--r--test/syscalls/linux/tcp_socket.cc29
-rw-r--r--test/syscalls/linux/time.cc2
-rw-r--r--test/syscalls/linux/tuntap.cc112
-rw-r--r--test/syscalls/linux/tuntap_hostinet.cc38
-rw-r--r--test/syscalls/linux/vsyscall.cc2
-rw-r--r--test/util/BUILD6
-rw-r--r--test/util/temp_umask.h (renamed from test/syscalls/linux/temp_umask.h)6
-rw-r--r--tools/BUILD2
-rw-r--r--tools/bazeldefs/defs.bzl88
-rw-r--r--tools/defs.bzl58
-rw-r--r--tools/go_marshal/gomarshal/BUILD3
-rw-r--r--tools/go_marshal/gomarshal/generator.go19
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go665
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go183
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go229
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_struct.go450
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go2
-rw-r--r--tools/go_marshal/gomarshal/util.go41
-rw-r--r--tools/go_marshal/test/test.go5
-rwxr-xr-xtools/go_mod.sh29
-rw-r--r--tools/images/BUILD2
-rwxr-xr-xtools/images/ubuntu1604/10_core.sh15
-rwxr-xr-xtools/images/ubuntu1604/20_bazel.sh12
-rwxr-xr-xtools/images/ubuntu1604/25_docker.sh33
-rwxr-xr-xtools/images/ubuntu1604/30_containerd.sh12
-rwxr-xr-xtools/images/ubuntu1604/40_kokoro.sh17
-rwxr-xr-xtools/installers/master.sh7
-rw-r--r--tools/nogo.js7
-rw-r--r--tools/nogo.json95
-rwxr-xr-xtools/tag_release.sh12
-rw-r--r--vdso/vdso.cc4
366 files changed, 14914 insertions, 5216 deletions
diff --git a/.travis.yml b/.travis.yml
index a2a260538..acbd3d61b 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -17,3 +17,7 @@ matrix:
script:
- uname -a
- make DOCKER_RUN_OPTIONS="" BAZEL_OPTIONS="build runsc:runsc" bazel && $RUNSC_PATH --alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do ls
+branches:
+ except:
+ # Skip copybara branches.
+ - /^test\/cl.*$/
diff --git a/BUILD b/BUILD
index 5fd929378..a709a9816 100644
--- a/BUILD
+++ b/BUILD
@@ -49,10 +49,31 @@ gazelle(name = "gazelle")
# live in the tools subdirectory (unless they are standard).
nogo(
name = "nogo",
- config = "//tools:nogo.js",
+ config = "//tools:nogo.json",
visibility = ["//visibility:public"],
deps = [
"//tools/checkunsafe",
+ "@org_golang_x_tools//go/analysis/passes/asmdecl:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/assign:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/atomic:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/atomicalign:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/bools:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/buildtag:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/cgocall:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/copylock:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/deepequalerrors:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/loopclosure:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/lostcancel:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/nilfunc:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/nilness:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/printf:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/shift:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/stdmethods:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/structtag:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/tests:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/unmarshal:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/unsafeptr:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/unusedresult:go_tool_library",
],
)
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 71650a4b8..ad8e710da 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -32,6 +32,9 @@ will need to be added to the appropriate `BUILD` files, and the `:gopath` target
will need to be re-run to generate appropriate symlinks in the `GOPATH`
directory tree.
+Dependencies can be added by using `go mod get`. In order to keep the
+`WORKSPACE` file in sync, run `tools/go_mod.sh` in place of `go mod`.
+
### Coding Guidelines
All Go code should conform to the [Go style guidelines][gostyle]. C++ code
diff --git a/Dockerfile b/Dockerfile
index 2bfdfec6c..0fac71710 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,7 +2,7 @@ FROM fedora:31
RUN dnf install -y dnf-plugins-core && dnf copr enable -y vbatts/bazel
-RUN dnf install -y bazel2 git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static
+RUN dnf install -y bazel2 git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static libstdc++-static patch
RUN pip install pycparser
diff --git a/WORKSPACE b/WORKSPACE
index a15238a2e..62dfb9dc6 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -4,10 +4,10 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
# Load go bazel rules and gazelle.
http_archive(
name = "io_bazel_rules_go",
- sha256 = "f99a9d76e972e0c8f935b2fe6d0d9d778f67c760c6d2400e23fc2e469016e2bd",
+ sha256 = "94f90feaa65c9cdc840cd21f67d967870b5943d684966a47569da8073e42063d",
urls = [
- "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/v0.21.2/rules_go-v0.21.2.tar.gz",
- "https://github.com/bazelbuild/rules_go/releases/download/v0.21.2/rules_go-v0.21.2.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.22.0/rules_go-v0.22.0.tar.gz",
+ "https://github.com/bazelbuild/rules_go/releases/download/v0.22.0/rules_go-v0.22.0.tar.gz",
],
)
@@ -20,12 +20,12 @@ http_archive(
],
)
-load("@io_bazel_rules_go//go:deps.bzl", "go_rules_dependencies", "go_register_toolchains")
+load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies")
go_rules_dependencies()
go_register_toolchains(
- go_version = "1.13.7",
+ go_version = "1.14",
nogo = "@//:nogo",
)
@@ -43,8 +43,8 @@ gazelle_dependencies()
go_repository(
name = "org_golang_x_sys",
importpath = "golang.org/x/sys",
- sum = "h1:72l8qCJ1nGxMGH26QVBVIxKd/D34cfGt0OvrPtpemyY=",
- version = "v0.0.0-20191220220014-0732a990476f",
+ sum = "h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So=",
+ version = "v0.0.0-20200302150141-5c8b2ff67527",
)
# Load C++ rules.
@@ -68,16 +68,19 @@ http_archive(
"https://github.com/bazelbuild/rules_proto/archive/97d8af4dc474595af3900dd85cb3a29ad28cc313.tar.gz",
],
)
+
load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains")
+
rules_proto_dependencies()
+
rules_proto_toolchains()
# Load python dependencies.
git_repository(
name = "rules_python",
- commit = "94677401bc56ed5d756f50b441a6a5c7f735a6d4",
+ commit = "abc4869e02fe9b3866942e89f07b7341f830e805",
remote = "https://github.com/bazelbuild/rules_python.git",
- shallow_since = "1573842889 -0500",
+ shallow_since = "1583341286 -0500",
)
load("@rules_python//python:pip.bzl", "pip_import")
@@ -96,11 +99,11 @@ pip_install()
# See releases at https://releases.bazel.build/bazel-toolchains.html
http_archive(
name = "bazel_toolchains",
- sha256 = "a653c9d318e42b14c0ccd7ac50c4a2a276c0db1e39743ab88b5aa2f0bc9cf607",
- strip_prefix = "bazel-toolchains-2.0.2",
+ sha256 = "b5a8039df7119d618402472f3adff8a1bd0ae9d5e253f53fcc4c47122e91a3d2",
+ strip_prefix = "bazel-toolchains-2.1.1",
urls = [
- "https://github.com/bazelbuild/bazel-toolchains/releases/download/2.0.2/bazel-toolchains-2.0.2.tar.gz",
- "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/2.0.2.tar.gz",
+ "https://github.com/bazelbuild/bazel-toolchains/releases/download/2.1.1/bazel-toolchains-2.1.1.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/2.1.1.tar.gz",
],
)
@@ -146,9 +149,9 @@ load(
# This container is built from the Dockerfile in test/iptables/runner.
container_pull(
name = "iptables-test",
+ digest = "sha256:a137d692a2eb9fc7bf95c5f4a568da090e2c31098e93634421ed88f3a3f1db65",
registry = "gcr.io",
repository = "gvisor-presubmit/iptables-test",
- digest = "sha256:a137d692a2eb9fc7bf95c5f4a568da090e2c31098e93634421ed88f3a3f1db65",
)
load(
@@ -158,6 +161,20 @@ load(
_go_image_repos()
+# Load C++ grpc rules.
+http_archive(
+ name = "com_github_grpc_grpc",
+ sha256 = "2fcb7f1ab160d6fd3aaade64520be3e5446fc4c6fa7ba6581afdc4e26094bd81",
+ strip_prefix = "grpc-1.26.0",
+ urls = [
+ "https://github.com/grpc/grpc/archive/v1.26.0.tar.gz",
+ ],
+)
+load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
+grpc_deps()
+load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps")
+grpc_extra_deps()
+
# External repositories, in sorted order.
go_repository(
name = "com_github_cenkalti_backoff",
@@ -202,6 +219,20 @@ go_repository(
)
go_repository(
+ name = "com_github_imdario_mergo",
+ importpath = "github.com/imdario/mergo",
+ version = "v0.3.8",
+ sum = "h1:CGgOkSJeqMRmt0D9XLWExdT4m4F1vd3FV3VPt+0VxkQ=",
+)
+
+go_repository(
+ name = "com_github_kr_pretty",
+ importpath = "github.com/kr/pretty",
+ sum = "h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=",
+ version = "v0.2.0",
+)
+
+go_repository(
name = "com_github_kr_pty",
importpath = "github.com/kr/pty",
sum = "h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw=",
@@ -209,6 +240,19 @@ go_repository(
)
go_repository(
+ name = "com_github_kr_text",
+ importpath = "github.com/kr/text",
+ sum = "h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=",
+ version = "v0.1.0",
+)
+
+go_repository(
+ name = "com_github_mohae_deepcopy",
+ importpath = "github.com/mohae/deepcopy",
+ commit = "c48cc78d482608239f6c4c92a4abd87eb8761c90",
+)
+
+go_repository(
name = "com_github_opencontainers_runtime-spec",
importpath = "github.com/opencontainers/runtime-spec",
sum = "h1:d9F+LNYwMyi3BDN4GzZdaSiq4otb8duVEWyZjeUtOQI=",
@@ -237,30 +281,39 @@ go_repository(
)
go_repository(
- name = "org_golang_x_crypto",
- importpath = "golang.org/x/crypto",
- sum = "h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=",
- version = "v0.0.0-20190308221718-c2843e01d9a2",
+ name = "org_golang_google_grpc",
+ build_file_proto_mode = "disable",
+ importpath = "google.golang.org/grpc",
+ sum = "h1:zvIju4sqAGvwKspUQOhwnpcqSbzi7/H6QomNNjTL4sk=",
+ version = "v1.27.1",
)
go_repository(
- name = "org_golang_x_net",
- importpath = "golang.org/x/net",
- sum = "h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628=",
- version = "v0.0.0-20190311183353-d8887717615a",
+ name = "in_gopkg_check_v1",
+ importpath = "gopkg.in/check.v1",
+ sum = "h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=",
+ version = "v1.0.0-20190902080502-41f04d3bba15",
)
go_repository(
- name = "org_golang_x_text",
- importpath = "golang.org/x/text",
- sum = "h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=",
- version = "v0.3.0",
+ name = "org_golang_x_crypto",
+ importpath = "golang.org/x/crypto",
+ sum = "h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8=",
+ version = "v0.0.0-20191011191535-87dc89f01550",
)
go_repository(
- name = "org_golang_x_tools",
- commit = "36563e24a262",
- importpath = "golang.org/x/tools",
+ name = "org_golang_x_mod",
+ importpath = "golang.org/x/mod",
+ sum = "h1:p1YOIz9H/mGN8k1XkaV5VFAq9+zhN9Obefv439UwRhI=",
+ version = "v0.2.1-0.20200224194123-e5e73c1b9c72",
+)
+
+go_repository(
+ name = "org_golang_x_net",
+ importpath = "golang.org/x/net",
+ sum = "h1:R/3boaszxrf1GEUWTVDzSKVwLmSJpwZ1yqXm8j0v2QI=",
+ version = "v0.0.0-20190620200207-3b0461eec859",
)
go_repository(
@@ -271,15 +324,31 @@ go_repository(
)
go_repository(
+ name = "org_golang_x_text",
+ importpath = "golang.org/x/text",
+ sum = "h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=",
+ version = "v0.3.0",
+)
+
+go_repository(
name = "org_golang_x_time",
- commit = "c4c64cad1fd0a1a8dab2523e04e61d35308e131e",
importpath = "golang.org/x/time",
+ sum = "h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=",
+ version = "v0.0.0-20191024005414-555d28b269f0",
)
go_repository(
name = "org_golang_x_tools",
- commit = "aa82965741a9fecd12b026fbb3d3c6ed3231b8f8",
importpath = "golang.org/x/tools",
+ sum = "h1:aZzprAO9/8oim3qStq3wc1Xuxx4QmAGriC4VU4ojemQ=",
+ version = "v0.0.0-20191119224855-298f0cb1881e",
+)
+
+go_repository(
+ name = "org_golang_x_xerrors",
+ importpath = "golang.org/x/xerrors",
+ sum = "h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=",
+ version = "v0.0.0-20191204190536-9bdfabe68543",
)
go_repository(
diff --git a/benchmarks/harness/machine_producers/gcloud_producer.py b/benchmarks/harness/machine_producers/gcloud_producer.py
index 513d16e4f..1a624df2e 100644
--- a/benchmarks/harness/machine_producers/gcloud_producer.py
+++ b/benchmarks/harness/machine_producers/gcloud_producer.py
@@ -168,7 +168,9 @@ class GCloudProducer(machine_producer.MachineProducer):
cmd.append("--zone=" + self.zone)
cmd.append("--machine-type=" + self.machine_type)
res = self._run_command(cmd)
- return json.loads(res.stdout)
+ data = res.stdout
+ data = str(data, "utf-8") if isinstance(data, (bytes, bytearray)) else data
+ return json.loads(data)
def _add_ssh_key_to_instances(self, names: List[str]) -> None:
"""Adds ssh key to instances by calling gcloud ssh command.
@@ -186,11 +188,11 @@ class GCloudProducer(machine_producer.MachineProducer):
TimeoutError: when 3 unsuccessful tries to ssh into the host return 255.
"""
for name in names:
- cmd = "gcloud compute ssh {name}".format(name=name).split(" ")
+ cmd = "gcloud compute ssh {user}@{name}".format(
+ user=self.ssh_user, name=name).split(" ")
cmd.append("--ssh-key-file={key}".format(key=self.ssh_key_file))
cmd.append("--zone={zone}".format(zone=self.zone))
cmd.append("--command=uname")
- cmd.append("--ssh-key-expire-after=60m")
timeout = datetime.timedelta(seconds=5 * 60)
start = datetime.datetime.now()
while datetime.datetime.now() <= timeout + start:
diff --git a/benchmarks/workloads/ruby/Gemfile.lock b/benchmarks/workloads/ruby/Gemfile.lock
index b44817bd3..ea9f0ea85 100644
--- a/benchmarks/workloads/ruby/Gemfile.lock
+++ b/benchmarks/workloads/ruby/Gemfile.lock
@@ -1,28 +1,41 @@
GEM
remote: https://rubygems.org/
specs:
+ activemerchant (1.105.0)
+ activesupport (>= 4.2)
+ builder (>= 2.1.2, < 4.0.0)
+ i18n (>= 0.6.9)
+ nokogiri (~> 1.4)
activesupport (5.2.3)
concurrent-ruby (~> 1.0, >= 1.0.2)
i18n (>= 0.7, < 2)
minitest (~> 5.1)
tzinfo (~> 1.1)
+ bcrypt (3.1.13)
+ builder (3.2.4)
cassandra-driver (3.2.3)
ione (~> 1.2)
concurrent-ruby (1.1.5)
+ ffi (1.12.2)
i18n (1.6.0)
concurrent-ruby (~> 1.0)
ione (1.2.4)
+ mini_portile2 (2.4.0)
minitest (5.11.3)
mustermann (1.0.3)
+ nokogiri (1.10.8)
+ mini_portile2 (~> 2.4.0)
pdf-core (0.7.0)
prawn (2.2.2)
pdf-core (~> 0.7.0)
ttfunk (~> 1.5)
- puma (3.12.1)
- rack (2.0.7)
+ puma (3.12.4)
+ rack (2.2.2)
rack-protection (2.0.5)
rack
- rake (12.3.2)
+ rake (12.3.3)
+ rbnacl (7.1.1)
+ ffi
redis (4.1.1)
ruby-fann (1.2.6)
sinatra (2.0.5)
@@ -43,9 +56,12 @@ PLATFORMS
ruby
DEPENDENCIES
+ activemerchant
+ bcrypt
cassandra-driver
puma
rake
+ rbnacl
redis
ruby-fann
sinatra
diff --git a/benchmarks/workloads/ruby_template/Gemfile.lock b/benchmarks/workloads/ruby_template/Gemfile.lock
index dd8d56fb7..f637b6081 100644
--- a/benchmarks/workloads/ruby_template/Gemfile.lock
+++ b/benchmarks/workloads/ruby_template/Gemfile.lock
@@ -2,25 +2,25 @@ GEM
remote: https://rubygems.org/
specs:
mustermann (1.0.3)
- puma (3.12.0)
+ puma (3.12.4)
rack (2.0.6)
rack-protection (2.0.5)
rack
+ redis (4.1.0)
sinatra (2.0.5)
mustermann (~> 1.0)
rack (~> 2.0)
rack-protection (= 2.0.5)
tilt (~> 2.0)
tilt (2.0.9)
- redis (4.1.0)
PLATFORMS
ruby
DEPENDENCIES
puma
- sinatra
redis
+ sinatra
BUNDLED WITH
1.17.1 \ No newline at end of file
diff --git a/go.mod b/go.mod
index c4687ed02..434fa713f 100644
--- a/go.mod
+++ b/go.mod
@@ -1,23 +1,20 @@
module gvisor.dev/gvisor
-go 1.13
+go 1.14
require (
- github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422
- github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079
- github.com/golang/mock v1.3.1
- github.com/golang/protobuf v1.3.1
- github.com/google/btree v1.0.0
- github.com/google/go-cmp v0.2.0
- github.com/google/go-github/v28 v28.1.1
- github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8
- github.com/google/uuid v0.0.0-20171129191014-dec09d789f3d
- github.com/kr/pty v1.1.1
- github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78
- github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2
- github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e
- github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936
- golang.org/x/net v0.0.0-20190311183353-d8887717615a
- golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6
- golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a
+ github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422
+ github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079
+ github.com/golang/protobuf v1.3.1
+ github.com/google/btree v1.0.0
+ github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8
+ github.com/kr/pretty v0.2.0 // indirect
+ github.com/kr/pty v1.1.1
+ github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78
+ github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2
+ github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e
+ github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 // indirect
+ golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527
+ golang.org/x/time v0.0.0-20191024005414-555d28b269f0
+ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
)
diff --git a/go.sum b/go.sum
index 434770beb..c44a17c71 100644
--- a/go.sum
+++ b/go.sum
@@ -1,21 +1,32 @@
+github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422 h1:+FKjzBIdfBHYDvxCv+djmDJdes/AoDtg8gpcxowBlF8=
github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422/go.mod h1:b6Nc7NRH5C4aCISLry0tLnTjcuTEvoiqcWDdsU0sOGM=
+github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 h1:JFTFz3HZTGmgMz4E1TabNBNJljROSYgja1b4l50FNVs=
github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
-github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
+github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
+github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
-github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
-github.com/google/go-github/v28 v28.1.1/go.mod h1:bsqJWQX05omyWVmc00nEUql9mhQyv38lDZ8kPZcQVoM=
+github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8 h1:GZGUPQiZfYrd9uOqyqwbQcHPkz/EZJVkZB1MkaO9UBI=
github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
-github.com/google/uuid v0.0.0-20171129191014-dec09d789f3d/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=
+github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
+github.com/kr/pty v1.1.1 h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
+github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
+github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78 h1:d9F+LNYwMyi3BDN4GzZdaSiq4otb8duVEWyZjeUtOQI=
github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
+github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 h1:b6uOv7YOFK0TYG7HtkIgExQo+2RdLuwRft63jn2HWj8=
github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww=
+github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e h1:/Tdc23Arz1OtdIsBY2utWepGRQ9fEAJlhkdoLzWMK8Q=
github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
+github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 h1:J9gO8RJCAFlln1jsvRba/CWVUnMHwObklfxxjErl1uk=
github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI=
-golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
-golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So=
+golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
+golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
+gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
diff --git a/kokoro/benchmark_tests.cfg b/kokoro/benchmark_tests.cfg
index c48518a05..f85cc9681 100644
--- a/kokoro/benchmark_tests.cfg
+++ b/kokoro/benchmark_tests.cfg
@@ -5,14 +5,14 @@ before_action {
fetch_keystore {
keystore_resource {
keystore_config_id : 73898
- keyname : 'kokoro-rbe-service-account'
+ keyname : 'gvisor-benchmarks-service-account'
},
}
}
env_vars {
key : 'PROJECT'
- value : 'gvisor-kokoro-testing'
+ value : 'gvisor-benchmarks'
}
env_vars {
@@ -21,6 +21,6 @@ env_vars {
}
env_vars {
- key : 'KOKORO_SERVICE_ACCOUNT'
- value : '73898_kokoro-rbe-service-account'
+ key : 'GCLOUD_CREDENTIALS'
+ value : '73898_gvisor-benchmarks-service-account'
}
diff --git a/kokoro/iptables_tests.cfg b/kokoro/iptables_tests.cfg
index 7af20629a..a30d82591 100644
--- a/kokoro/iptables_tests.cfg
+++ b/kokoro/iptables_tests.cfg
@@ -1,4 +1,4 @@
-build_file: "repo/scripts/iptables_test.sh"
+build_file: "repo/scripts/iptables_tests.sh"
action {
define_artifacts {
diff --git a/kokoro/packetimpact_tests.cfg b/kokoro/packetimpact_tests.cfg
new file mode 100644
index 000000000..db86b52d5
--- /dev/null
+++ b/kokoro/packetimpact_tests.cfg
@@ -0,0 +1,9 @@
+build_file: "repo/scripts/packetimpact_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ }
+}
diff --git a/pkg/abi/linux/elf.go b/pkg/abi/linux/elf.go
index 40f0459a0..7c9a02f20 100644
--- a/pkg/abi/linux/elf.go
+++ b/pkg/abi/linux/elf.go
@@ -102,4 +102,7 @@ const (
// NT_X86_XSTATE is for x86 extended state using xsave.
NT_X86_XSTATE = 0x202
+
+ // NT_ARM_TLS is for ARM TLS register.
+ NT_ARM_TLS = 0x401
)
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
index e229ac21c..055ac1d7c 100644
--- a/pkg/abi/linux/file.go
+++ b/pkg/abi/linux/file.go
@@ -266,6 +266,9 @@ type Statx struct {
DevMinor uint32
}
+// SizeOfStatx is the size of a Statx struct.
+var SizeOfStatx = binary.Size(Statx{})
+
// FileMode represents a mode_t.
type FileMode uint16
@@ -284,6 +287,11 @@ func (m FileMode) ExtraBits() FileMode {
return m &^ (PermissionsMask | FileTypeMask)
}
+// IsDir returns true if file type represents a directory.
+func (m FileMode) IsDir() bool {
+ return m.FileType() == S_IFDIR
+}
+
// String returns a string representation of m.
func (m FileMode) String() string {
var s []string
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index bd2e13ba1..80dc09aa9 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -158,10 +158,32 @@ type IPTIP struct {
// Flags define matching behavior for the IP header.
Flags uint8
- // InverseFlags invert the meaning of fields in struct IPTIP.
+ // InverseFlags invert the meaning of fields in struct IPTIP. See the
+ // IPT_INV_* flags.
InverseFlags uint8
}
+// Flags in IPTIP.InverseFlags. Corresponding constants are in
+// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+const (
+ // Invert the meaning of InputInterface.
+ IPT_INV_VIA_IN = 0x01
+ // Invert the meaning of OutputInterface.
+ IPT_INV_VIA_OUT = 0x02
+ // Unclear what this is, as no references to it exist in the kernel.
+ IPT_INV_TOS = 0x04
+ // Invert the meaning of Src.
+ IPT_INV_SRCIP = 0x08
+ // Invert the meaning of Dst.
+ IPT_INV_DSTIP = 0x10
+ // Invert the meaning of the IPT_F_FRAG flag.
+ IPT_INV_FRAG = 0x20
+ // Invert the meaning of the Protocol field.
+ IPT_INV_PROTO = 0x40
+ // Enable all flags.
+ IPT_INV_MASK = 0x7F
+)
+
// SizeOfIPTIP is the size of an IPTIP.
const SizeOfIPTIP = 84
@@ -253,6 +275,50 @@ type XTErrorTarget struct {
// SizeOfXTErrorTarget is the size of an XTErrorTarget.
const SizeOfXTErrorTarget = 64
+// Flag values for NfNATIPV4Range. The values indicate whether to map
+// protocol specific part(ports) or IPs. It corresponds to values in
+// include/uapi/linux/netfilter/nf_nat.h.
+const (
+ NF_NAT_RANGE_MAP_IPS = 1 << 0
+ NF_NAT_RANGE_PROTO_SPECIFIED = 1 << 1
+ NF_NAT_RANGE_PROTO_RANDOM = 1 << 2
+ NF_NAT_RANGE_PERSISTENT = 1 << 3
+ NF_NAT_RANGE_PROTO_RANDOM_FULLY = 1 << 4
+ NF_NAT_RANGE_PROTO_RANDOM_ALL = (NF_NAT_RANGE_PROTO_RANDOM | NF_NAT_RANGE_PROTO_RANDOM_FULLY)
+ NF_NAT_RANGE_MASK = (NF_NAT_RANGE_MAP_IPS |
+ NF_NAT_RANGE_PROTO_SPECIFIED | NF_NAT_RANGE_PROTO_RANDOM |
+ NF_NAT_RANGE_PERSISTENT | NF_NAT_RANGE_PROTO_RANDOM_FULLY)
+)
+
+// NfNATIPV4Range corresponds to struct nf_nat_ipv4_range
+// in include/uapi/linux/netfilter/nf_nat.h. The fields are in
+// network byte order.
+type NfNATIPV4Range struct {
+ Flags uint32
+ MinIP [4]byte
+ MaxIP [4]byte
+ MinPort uint16
+ MaxPort uint16
+}
+
+// NfNATIPV4MultiRangeCompat corresponds to struct
+// nf_nat_ipv4_multi_range_compat in include/uapi/linux/netfilter/nf_nat.h.
+type NfNATIPV4MultiRangeCompat struct {
+ RangeSize uint32
+ RangeIPV4 NfNATIPV4Range
+}
+
+// XTRedirectTarget triggers a redirect when reached.
+// Adding 4 bytes of padding to make the struct 8 byte aligned.
+type XTRedirectTarget struct {
+ Target XTEntryTarget
+ NfRange NfNATIPV4MultiRangeCompat
+ _ [4]byte
+}
+
+// SizeOfXTRedirectTarget is the size of an XTRedirectTarget.
+const SizeOfXTRedirectTarget = 56
+
// IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds
// to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h.
type IPTGetinfo struct {
diff --git a/pkg/bits/bits_template.go b/pkg/bits/bits_template.go
index 93a435b80..998645388 100644
--- a/pkg/bits/bits_template.go
+++ b/pkg/bits/bits_template.go
@@ -42,3 +42,11 @@ func Mask(is ...int) T {
func MaskOf(i int) T {
return T(1) << T(i)
}
+
+// IsPowerOfTwo returns true if v is power of 2.
+func IsPowerOfTwo(v T) bool {
+ if v == 0 {
+ return false
+ }
+ return v&(v-1) == 0
+}
diff --git a/pkg/bits/uint64_test.go b/pkg/bits/uint64_test.go
index 1b018d808..193d1ebcd 100644
--- a/pkg/bits/uint64_test.go
+++ b/pkg/bits/uint64_test.go
@@ -114,3 +114,21 @@ func TestIsOn(t *testing.T) {
}
}
}
+
+func TestIsPowerOfTwo(t *testing.T) {
+ for _, tc := range []struct {
+ v uint64
+ want bool
+ }{
+ {v: 0, want: false},
+ {v: 1, want: true},
+ {v: 2, want: true},
+ {v: 3, want: false},
+ {v: 4, want: true},
+ {v: 5, want: false},
+ } {
+ if got := IsPowerOfTwo64(tc.v); got != tc.want {
+ t.Errorf("IsPowerOfTwo(%d) = %t, want: %t", tc.v, got, tc.want)
+ }
+ }
+}
diff --git a/pkg/buffer/BUILD b/pkg/buffer/BUILD
new file mode 100644
index 000000000..dcd086298
--- /dev/null
+++ b/pkg/buffer/BUILD
@@ -0,0 +1,43 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "buffer_list",
+ out = "buffer_list.go",
+ package = "buffer",
+ prefix = "buffer",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*buffer",
+ "Linker": "*buffer",
+ },
+)
+
+go_library(
+ name = "buffer",
+ srcs = [
+ "buffer.go",
+ "buffer_list.go",
+ "safemem.go",
+ "view.go",
+ "view_unsafe.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/log",
+ "//pkg/safemem",
+ ],
+)
+
+go_test(
+ name = "buffer_test",
+ size = "small",
+ srcs = [
+ "safemem_test.go",
+ "view_test.go",
+ ],
+ library = ":buffer",
+ deps = ["//pkg/safemem"],
+)
diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go
new file mode 100644
index 000000000..c6d089fd9
--- /dev/null
+++ b/pkg/buffer/buffer.go
@@ -0,0 +1,94 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package buffer provides the implementation of a buffer view.
+//
+// A view is an flexible buffer, backed by a pool, supporting the safecopy
+// operations natively as well as the ability to grow via either prepend or
+// append, as well as shrink.
+package buffer
+
+import (
+ "sync"
+)
+
+const bufferSize = 8144 // See below.
+
+// buffer encapsulates a queueable byte buffer.
+//
+// Note that the total size is slightly less than two pages. This is done
+// intentionally to ensure that the buffer object aligns with runtime
+// internals. We have no hard size or alignment requirements. This two page
+// size will effectively minimize internal fragmentation, but still have a
+// large enough chunk to limit excessive segmentation.
+//
+// +stateify savable
+type buffer struct {
+ data [bufferSize]byte
+ read int
+ write int
+ bufferEntry
+}
+
+// reset resets internal data.
+//
+// This must be called before returning the buffer to the pool.
+func (b *buffer) Reset() {
+ b.read = 0
+ b.write = 0
+}
+
+// Full indicates the buffer is full.
+//
+// This indicates there is no capacity left to write.
+func (b *buffer) Full() bool {
+ return b.write == len(b.data)
+}
+
+// ReadSize returns the number of bytes available for reading.
+func (b *buffer) ReadSize() int {
+ return b.write - b.read
+}
+
+// ReadMove advances the read index by the given amount.
+func (b *buffer) ReadMove(n int) {
+ b.read += n
+}
+
+// ReadSlice returns the read slice for this buffer.
+func (b *buffer) ReadSlice() []byte {
+ return b.data[b.read:b.write]
+}
+
+// WriteSize returns the number of bytes available for writing.
+func (b *buffer) WriteSize() int {
+ return len(b.data) - b.write
+}
+
+// WriteMove advances the write index by the given amount.
+func (b *buffer) WriteMove(n int) {
+ b.write += n
+}
+
+// WriteSlice returns the write slice for this buffer.
+func (b *buffer) WriteSlice() []byte {
+ return b.data[b.write:]
+}
+
+// bufferPool is a pool for buffers.
+var bufferPool = sync.Pool{
+ New: func() interface{} {
+ return new(buffer)
+ },
+}
diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go
new file mode 100644
index 000000000..0e5b86344
--- /dev/null
+++ b/pkg/buffer/safemem.go
@@ -0,0 +1,129 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "gvisor.dev/gvisor/pkg/safemem"
+)
+
+// WriteBlock returns this buffer as a write Block.
+func (b *buffer) WriteBlock() safemem.Block {
+ return safemem.BlockFromSafeSlice(b.WriteSlice())
+}
+
+// ReadBlock returns this buffer as a read Block.
+func (b *buffer) ReadBlock() safemem.Block {
+ return safemem.BlockFromSafeSlice(b.ReadSlice())
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+//
+// This will advance the write index.
+func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ need := int(srcs.NumBytes())
+ if need == 0 {
+ return 0, nil
+ }
+
+ var (
+ dst safemem.BlockSeq
+ blocks []safemem.Block
+ )
+
+ // Need at least one buffer.
+ firstBuf := v.data.Back()
+ if firstBuf == nil {
+ firstBuf = bufferPool.Get().(*buffer)
+ v.data.PushBack(firstBuf)
+ }
+
+ // Does the last block have sufficient capacity alone?
+ if l := firstBuf.WriteSize(); l >= need {
+ dst = safemem.BlockSeqOf(firstBuf.WriteBlock())
+ } else {
+ // Append blocks until sufficient.
+ need -= l
+ blocks = append(blocks, firstBuf.WriteBlock())
+ for need > 0 {
+ emptyBuf := bufferPool.Get().(*buffer)
+ v.data.PushBack(emptyBuf)
+ need -= emptyBuf.WriteSize()
+ blocks = append(blocks, emptyBuf.WriteBlock())
+ }
+ dst = safemem.BlockSeqFromSlice(blocks)
+ }
+
+ // Perform the copy.
+ n, err := safemem.CopySeq(dst, srcs)
+ v.size += int64(n)
+
+ // Update all indices.
+ for left := int(n); left > 0; firstBuf = firstBuf.Next() {
+ if l := firstBuf.WriteSize(); left >= l {
+ firstBuf.WriteMove(l) // Whole block.
+ left -= l
+ } else {
+ firstBuf.WriteMove(left) // Partial block.
+ left = 0
+ }
+ }
+
+ return n, err
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+//
+// This will not advance the read index; the caller should follow
+// this call with a call to TrimFront in order to remove the read
+// data from the buffer. This is done to support pipe sematics.
+func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ need := int(dsts.NumBytes())
+ if need == 0 {
+ return 0, nil
+ }
+
+ var (
+ src safemem.BlockSeq
+ blocks []safemem.Block
+ )
+
+ firstBuf := v.data.Front()
+ if firstBuf == nil {
+ return 0, nil // No EOF.
+ }
+
+ // Is all the data in a single block?
+ if l := firstBuf.ReadSize(); l >= need {
+ src = safemem.BlockSeqOf(firstBuf.ReadBlock())
+ } else {
+ // Build a list of all the buffers.
+ need -= l
+ blocks = append(blocks, firstBuf.ReadBlock())
+ for buf := firstBuf.Next(); buf != nil && need > 0; buf = buf.Next() {
+ need -= buf.ReadSize()
+ blocks = append(blocks, buf.ReadBlock())
+ }
+ src = safemem.BlockSeqFromSlice(blocks)
+ }
+
+ // Perform the copy.
+ n, err := safemem.CopySeq(dsts, src)
+
+ // See above: we would normally advance the read index here, but we
+ // don't do that in order to support pipe semantics. We rely on a
+ // separate call to TrimFront() in this case.
+
+ return n, err
+}
diff --git a/pkg/buffer/safemem_test.go b/pkg/buffer/safemem_test.go
new file mode 100644
index 000000000..47f357e0c
--- /dev/null
+++ b/pkg/buffer/safemem_test.go
@@ -0,0 +1,170 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "bytes"
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/safemem"
+)
+
+func TestSafemem(t *testing.T) {
+ testCases := []struct {
+ name string
+ input string
+ output string
+ readLen int
+ op func(*View)
+ }{
+ // Basic coverage.
+ {
+ name: "short",
+ input: "010",
+ output: "010",
+ },
+ {
+ name: "long",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "0" + strings.Repeat("1", bufferSize) + "0",
+ },
+ {
+ name: "short-read",
+ input: "0",
+ readLen: 100, // > size.
+ output: "0",
+ },
+ {
+ name: "zero-read",
+ input: "0",
+ output: "",
+ },
+ {
+ name: "read-empty",
+ input: "",
+ readLen: 1, // > size.
+ output: "",
+ },
+
+ // Ensure offsets work.
+ {
+ name: "offsets-short",
+ input: "012",
+ output: "2",
+ op: func(v *View) {
+ v.TrimFront(2)
+ },
+ },
+ {
+ name: "offsets-long0",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: strings.Repeat("1", bufferSize) + "0",
+ op: func(v *View) {
+ v.TrimFront(1)
+ },
+ },
+ {
+ name: "offsets-long1",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: strings.Repeat("1", bufferSize-1) + "0",
+ op: func(v *View) {
+ v.TrimFront(2)
+ },
+ },
+ {
+ name: "offsets-long2",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "10",
+ op: func(v *View) {
+ v.TrimFront(bufferSize)
+ },
+ },
+
+ // Ensure truncation works.
+ {
+ name: "truncate-short",
+ input: "012",
+ output: "01",
+ op: func(v *View) {
+ v.Truncate(2)
+ },
+ },
+ {
+ name: "truncate-long0",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "0" + strings.Repeat("1", bufferSize),
+ op: func(v *View) {
+ v.Truncate(bufferSize + 1)
+ },
+ },
+ {
+ name: "truncate-long1",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "0" + strings.Repeat("1", bufferSize-1),
+ op: func(v *View) {
+ v.Truncate(bufferSize)
+ },
+ },
+ {
+ name: "truncate-long2",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "01",
+ op: func(v *View) {
+ v.Truncate(2)
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ // Construct the new view.
+ var view View
+ bs := safemem.BlockSeqOf(safemem.BlockFromSafeSlice([]byte(tc.input)))
+ n, err := view.WriteFromBlocks(bs)
+ if err != nil {
+ t.Errorf("expected err nil, got %v", err)
+ }
+ if n != uint64(len(tc.input)) {
+ t.Errorf("expected %d bytes, got %d", len(tc.input), n)
+ }
+
+ // Run the operation.
+ if tc.op != nil {
+ tc.op(&view)
+ }
+
+ // Read and validate.
+ readLen := tc.readLen
+ if readLen == 0 {
+ readLen = len(tc.output) // Default.
+ }
+ out := make([]byte, readLen)
+ bs = safemem.BlockSeqOf(safemem.BlockFromSafeSlice(out))
+ n, err = view.ReadToBlocks(bs)
+ if err != nil {
+ t.Errorf("expected nil, got %v", err)
+ }
+ if n != uint64(len(tc.output)) {
+ t.Errorf("expected %d bytes, got %d", len(tc.output), n)
+ }
+
+ // Ensure the contents are correct.
+ if !bytes.Equal(out[:n], []byte(tc.output[:n])) {
+ t.Errorf("contents are wrong: expected %q, got %q", tc.output, string(out))
+ }
+ })
+ }
+}
diff --git a/pkg/buffer/view.go b/pkg/buffer/view.go
new file mode 100644
index 000000000..e6901eadb
--- /dev/null
+++ b/pkg/buffer/view.go
@@ -0,0 +1,390 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "fmt"
+ "io"
+)
+
+// View is a non-linear buffer.
+//
+// All methods are thread compatible.
+//
+// +stateify savable
+type View struct {
+ data bufferList
+ size int64
+}
+
+// TrimFront removes the first count bytes from the buffer.
+func (v *View) TrimFront(count int64) {
+ if count >= v.size {
+ v.advanceRead(v.size)
+ } else {
+ v.advanceRead(count)
+ }
+}
+
+// ReadAt implements io.ReaderAt.ReadAt.
+func (v *View) ReadAt(p []byte, offset int64) (int, error) {
+ var (
+ skipped int64
+ done int64
+ )
+ for buf := v.data.Front(); buf != nil && done < int64(len(p)); buf = buf.Next() {
+ needToSkip := int(offset - skipped)
+ if sz := buf.ReadSize(); sz <= needToSkip {
+ skipped += int64(sz)
+ continue
+ }
+
+ // Actually read data.
+ n := copy(p[done:], buf.ReadSlice()[needToSkip:])
+ skipped += int64(needToSkip)
+ done += int64(n)
+ }
+ if int(done) < len(p) || offset+done == v.size {
+ return int(done), io.EOF
+ }
+ return int(done), nil
+}
+
+// advanceRead advances the view's read index.
+//
+// Precondition: there must be sufficient bytes in the buffer.
+func (v *View) advanceRead(count int64) {
+ for buf := v.data.Front(); buf != nil && count > 0; {
+ sz := int64(buf.ReadSize())
+ if sz > count {
+ // There is still data for reading.
+ buf.ReadMove(int(count))
+ v.size -= count
+ count = 0
+ break
+ }
+
+ // Consume the whole buffer.
+ oldBuf := buf
+ buf = buf.Next() // Iterate.
+ v.data.Remove(oldBuf)
+ oldBuf.Reset()
+ bufferPool.Put(oldBuf)
+
+ // Update counts.
+ count -= sz
+ v.size -= sz
+ }
+ if count > 0 {
+ panic(fmt.Sprintf("advanceRead still has %d bytes remaining", count))
+ }
+}
+
+// Truncate truncates the view to the given bytes.
+//
+// This will not grow the view, only shrink it. If a length is passed that is
+// greater than the current size of the view, then nothing will happen.
+//
+// Precondition: length must be >= 0.
+func (v *View) Truncate(length int64) {
+ if length < 0 {
+ panic("negative length provided")
+ }
+ if length >= v.size {
+ return // Nothing to do.
+ }
+ for buf := v.data.Back(); buf != nil && v.size > length; buf = v.data.Back() {
+ sz := int64(buf.ReadSize())
+ if after := v.size - sz; after < length {
+ // Truncate the buffer locally.
+ left := (length - after)
+ buf.write = buf.read + int(left)
+ v.size = length
+ break
+ }
+
+ // Drop the buffer completely; see above.
+ v.data.Remove(buf)
+ buf.Reset()
+ bufferPool.Put(buf)
+ v.size -= sz
+ }
+}
+
+// Grow grows the given view to the number of bytes, which will be appended. If
+// zero is true, all these bytes will be zero. If zero is false, then this is
+// the caller's responsibility.
+//
+// Precondition: length must be >= 0.
+func (v *View) Grow(length int64, zero bool) {
+ if length < 0 {
+ panic("negative length provided")
+ }
+ for v.size < length {
+ buf := v.data.Back()
+
+ // Is there some space in the last buffer?
+ if buf == nil || buf.Full() {
+ buf = bufferPool.Get().(*buffer)
+ v.data.PushBack(buf)
+ }
+
+ // Write up to length bytes.
+ sz := buf.WriteSize()
+ if int64(sz) > length-v.size {
+ sz = int(length - v.size)
+ }
+
+ // Zero the written section; note that this pattern is
+ // specifically recognized and optimized by the compiler.
+ if zero {
+ for i := buf.write; i < buf.write+sz; i++ {
+ buf.data[i] = 0
+ }
+ }
+
+ // Advance the index.
+ buf.WriteMove(sz)
+ v.size += int64(sz)
+ }
+}
+
+// Prepend prepends the given data.
+func (v *View) Prepend(data []byte) {
+ // Is there any space in the first buffer?
+ if buf := v.data.Front(); buf != nil && buf.read > 0 {
+ // Fill up before the first write.
+ avail := buf.read
+ bStart := 0
+ dStart := len(data) - avail
+ if avail > len(data) {
+ bStart = avail - len(data)
+ dStart = 0
+ }
+ n := copy(buf.data[bStart:], data[dStart:])
+ data = data[:dStart]
+ v.size += int64(n)
+ buf.read -= n
+ }
+
+ for len(data) > 0 {
+ // Do we need an empty buffer?
+ buf := bufferPool.Get().(*buffer)
+ v.data.PushFront(buf)
+
+ // The buffer is empty; copy last chunk.
+ avail := len(buf.data)
+ bStart := 0
+ dStart := len(data) - avail
+ if avail > len(data) {
+ bStart = avail - len(data)
+ dStart = 0
+ }
+
+ // We have to put the data at the end of the current
+ // buffer in order to ensure that the next prepend will
+ // correctly fill up the beginning of this buffer.
+ n := copy(buf.data[bStart:], data[dStart:])
+ data = data[:dStart]
+ v.size += int64(n)
+ buf.read = len(buf.data) - n
+ buf.write = len(buf.data)
+ }
+}
+
+// Append appends the given data.
+func (v *View) Append(data []byte) {
+ for done := 0; done < len(data); {
+ buf := v.data.Back()
+
+ // Ensure there's a buffer with space.
+ if buf == nil || buf.Full() {
+ buf = bufferPool.Get().(*buffer)
+ v.data.PushBack(buf)
+ }
+
+ // Copy in to the given buffer.
+ n := copy(buf.WriteSlice(), data[done:])
+ done += n
+ buf.WriteMove(n)
+ v.size += int64(n)
+ }
+}
+
+// Flatten returns a flattened copy of this data.
+//
+// This method should not be used in any performance-sensitive paths. It may
+// allocate a fresh byte slice sufficiently large to contain all the data in
+// the buffer. This is principally for debugging.
+//
+// N.B. Tee data still belongs to this view, as if there is a single buffer
+// present, then it will be returned directly. This should be used for
+// temporary use only, and a reference to the given slice should not be held.
+func (v *View) Flatten() []byte {
+ if buf := v.data.Front(); buf == nil {
+ return nil // No data at all.
+ } else if buf.Next() == nil {
+ return buf.ReadSlice() // Only one buffer.
+ }
+ data := make([]byte, 0, v.size) // Need to flatten.
+ for buf := v.data.Front(); buf != nil; buf = buf.Next() {
+ // Copy to the allocated slice.
+ data = append(data, buf.ReadSlice()...)
+ }
+ return data
+}
+
+// Size indicates the total amount of data available in this view.
+func (v *View) Size() int64 {
+ return v.size
+}
+
+// Copy makes a strict copy of this view.
+func (v *View) Copy() (other View) {
+ for buf := v.data.Front(); buf != nil; buf = buf.Next() {
+ other.Append(buf.ReadSlice())
+ }
+ return
+}
+
+// Apply applies the given function across all valid data.
+func (v *View) Apply(fn func([]byte)) {
+ for buf := v.data.Front(); buf != nil; buf = buf.Next() {
+ fn(buf.ReadSlice())
+ }
+}
+
+// Merge merges the provided View with this one.
+//
+// The other view will be appended to v, and other will be empty after this
+// operation completes.
+func (v *View) Merge(other *View) {
+ // Copy over all buffers.
+ for buf := other.data.Front(); buf != nil; buf = other.data.Front() {
+ other.data.Remove(buf)
+ v.data.PushBack(buf)
+ }
+
+ // Adjust sizes.
+ v.size += other.size
+ other.size = 0
+}
+
+// WriteFromReader writes to the buffer from an io.Reader.
+//
+// A minimum read size equal to unsafe.Sizeof(unintptr) is enforced,
+// provided that count is greater than or equal to unsafe.Sizeof(uintptr).
+func (v *View) WriteFromReader(r io.Reader, count int64) (int64, error) {
+ var (
+ done int64
+ n int
+ err error
+ )
+ for done < count {
+ buf := v.data.Back()
+
+ // Ensure we have an empty buffer.
+ if buf == nil || buf.Full() {
+ buf = bufferPool.Get().(*buffer)
+ v.data.PushBack(buf)
+ }
+
+ // Is this less than the minimum batch?
+ if buf.WriteSize() < minBatch && (count-done) >= int64(minBatch) {
+ tmp := make([]byte, minBatch)
+ n, err = r.Read(tmp)
+ v.Append(tmp[:n])
+ done += int64(n)
+ if err != nil {
+ break
+ }
+ continue
+ }
+
+ // Limit the read, if necessary.
+ sz := buf.WriteSize()
+ if left := count - done; int64(sz) > left {
+ sz = int(left)
+ }
+
+ // Pass the relevant portion of the buffer.
+ n, err = r.Read(buf.WriteSlice()[:sz])
+ buf.WriteMove(n)
+ done += int64(n)
+ v.size += int64(n)
+ if err == io.EOF {
+ err = nil // Short write allowed.
+ break
+ } else if err != nil {
+ break
+ }
+ }
+ return done, err
+}
+
+// ReadToWriter reads from the buffer into an io.Writer.
+//
+// N.B. This does not consume the bytes read. TrimFront should
+// be called appropriately after this call in order to do so.
+//
+// A minimum write size equal to unsafe.Sizeof(unintptr) is enforced,
+// provided that count is greater than or equal to unsafe.Sizeof(uintptr).
+func (v *View) ReadToWriter(w io.Writer, count int64) (int64, error) {
+ var (
+ done int64
+ n int
+ err error
+ )
+ offset := 0 // Spill-over for batching.
+ for buf := v.data.Front(); buf != nil && done < count; buf = buf.Next() {
+ // Has this been consumed? Skip it.
+ sz := buf.ReadSize()
+ if sz <= offset {
+ offset -= sz
+ continue
+ }
+ sz -= offset
+
+ // Is this less than the minimum batch?
+ left := count - done
+ if sz < minBatch && left >= int64(minBatch) && (v.size-done) >= int64(minBatch) {
+ tmp := make([]byte, minBatch)
+ n, err = v.ReadAt(tmp, done)
+ w.Write(tmp[:n])
+ done += int64(n)
+ offset = n - sz // Reset below.
+ if err != nil {
+ break
+ }
+ continue
+ }
+
+ // Limit the write if necessary.
+ if int64(sz) >= left {
+ sz = int(left)
+ }
+
+ // Perform the actual write.
+ n, err = w.Write(buf.ReadSlice()[offset : offset+sz])
+ done += int64(n)
+ if err != nil {
+ break
+ }
+
+ // Reset spill-over.
+ offset = 0
+ }
+ return done, err
+}
diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go
new file mode 100644
index 000000000..3db1bc6ee
--- /dev/null
+++ b/pkg/buffer/view_test.go
@@ -0,0 +1,467 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "bytes"
+ "io"
+ "strings"
+ "testing"
+)
+
+func fillAppend(v *View, data []byte) {
+ v.Append(data)
+}
+
+func fillAppendEnd(v *View, data []byte) {
+ v.Grow(bufferSize-1, false)
+ v.Append(data)
+ v.TrimFront(bufferSize - 1)
+}
+
+func fillWriteFromReader(v *View, data []byte) {
+ b := bytes.NewBuffer(data)
+ v.WriteFromReader(b, int64(len(data)))
+}
+
+func fillWriteFromReaderEnd(v *View, data []byte) {
+ v.Grow(bufferSize-1, false)
+ b := bytes.NewBuffer(data)
+ v.WriteFromReader(b, int64(len(data)))
+ v.TrimFront(bufferSize - 1)
+}
+
+var fillFuncs = map[string]func(*View, []byte){
+ "append": fillAppend,
+ "appendEnd": fillAppendEnd,
+ "writeFromReader": fillWriteFromReader,
+ "writeFromReaderEnd": fillWriteFromReaderEnd,
+}
+
+func testReadAt(t *testing.T, v *View, offset int64, n int, wantStr string, wantErr error) {
+ t.Helper()
+ d := make([]byte, n)
+ n, err := v.ReadAt(d, offset)
+ if n != len(wantStr) {
+ t.Errorf("got %d, want %d", n, len(wantStr))
+ }
+ if err != wantErr {
+ t.Errorf("got err %v, want %v", err, wantErr)
+ }
+ if !bytes.Equal(d[:n], []byte(wantStr)) {
+ t.Errorf("got %q, want %q", string(d[:n]), wantStr)
+ }
+}
+
+func TestView(t *testing.T) {
+ testCases := []struct {
+ name string
+ input string
+ output string
+ op func(*testing.T, *View)
+ }{
+ // Preconditions.
+ {
+ name: "truncate-check",
+ input: "hello",
+ output: "hello", // Not touched.
+ op: func(t *testing.T, v *View) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Truncate(-1) did not panic")
+ }
+ }()
+ v.Truncate(-1)
+ },
+ },
+ {
+ name: "grow-check",
+ input: "hello",
+ output: "hello", // Not touched.
+ op: func(t *testing.T, v *View) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Grow(-1) did not panic")
+ }
+ }()
+ v.Grow(-1, false)
+ },
+ },
+ {
+ name: "advance-check",
+ input: "hello",
+ output: "", // Consumed.
+ op: func(t *testing.T, v *View) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("advanceRead(Size()+1) did not panic")
+ }
+ }()
+ v.advanceRead(v.Size() + 1)
+ },
+ },
+
+ // Prepend.
+ {
+ name: "prepend",
+ input: "world",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ v.Prepend([]byte("hello "))
+ },
+ },
+ {
+ name: "prepend-backfill-full",
+ input: "hello world",
+ output: "jello world",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(1)
+ v.Prepend([]byte("j"))
+ },
+ },
+ {
+ name: "prepend-backfill-under",
+ input: "hello world",
+ output: "hola world",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(5)
+ v.Prepend([]byte("hola"))
+ },
+ },
+ {
+ name: "prepend-backfill-over",
+ input: "hello world",
+ output: "smello world",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(1)
+ v.Prepend([]byte("sm"))
+ },
+ },
+ {
+ name: "prepend-fill",
+ input: strings.Repeat("1", bufferSize-1),
+ output: "0" + strings.Repeat("1", bufferSize-1),
+ op: func(t *testing.T, v *View) {
+ v.Prepend([]byte("0"))
+ },
+ },
+ {
+ name: "prepend-overflow",
+ input: strings.Repeat("1", bufferSize),
+ output: "0" + strings.Repeat("1", bufferSize),
+ op: func(t *testing.T, v *View) {
+ v.Prepend([]byte("0"))
+ },
+ },
+ {
+ name: "prepend-multiple-buffers",
+ input: strings.Repeat("1", bufferSize-1),
+ output: strings.Repeat("0", bufferSize*3) + strings.Repeat("1", bufferSize-1),
+ op: func(t *testing.T, v *View) {
+ v.Prepend([]byte(strings.Repeat("0", bufferSize*3)))
+ },
+ },
+
+ // Append and write.
+ {
+ name: "append",
+ input: "hello",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ v.Append([]byte(" world"))
+ },
+ },
+ {
+ name: "append-fill",
+ input: strings.Repeat("1", bufferSize-1),
+ output: strings.Repeat("1", bufferSize-1) + "0",
+ op: func(t *testing.T, v *View) {
+ v.Append([]byte("0"))
+ },
+ },
+ {
+ name: "append-overflow",
+ input: strings.Repeat("1", bufferSize),
+ output: strings.Repeat("1", bufferSize) + "0",
+ op: func(t *testing.T, v *View) {
+ v.Append([]byte("0"))
+ },
+ },
+ {
+ name: "append-multiple-buffers",
+ input: strings.Repeat("1", bufferSize-1),
+ output: strings.Repeat("1", bufferSize-1) + strings.Repeat("0", bufferSize*3),
+ op: func(t *testing.T, v *View) {
+ v.Append([]byte(strings.Repeat("0", bufferSize*3)))
+ },
+ },
+
+ // Truncate.
+ {
+ name: "truncate",
+ input: "hello world",
+ output: "hello",
+ op: func(t *testing.T, v *View) {
+ v.Truncate(5)
+ },
+ },
+ {
+ name: "truncate-noop",
+ input: "hello world",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ v.Truncate(v.Size() + 1)
+ },
+ },
+ {
+ name: "truncate-multiple-buffers",
+ input: strings.Repeat("1", bufferSize*2),
+ output: strings.Repeat("1", bufferSize*2-1),
+ op: func(t *testing.T, v *View) {
+ v.Truncate(bufferSize*2 - 1)
+ },
+ },
+ {
+ name: "truncate-multiple-buffers-to-one",
+ input: strings.Repeat("1", bufferSize*2),
+ output: "11111",
+ op: func(t *testing.T, v *View) {
+ v.Truncate(5)
+ },
+ },
+
+ // TrimFront.
+ {
+ name: "trim",
+ input: "hello world",
+ output: "world",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(6)
+ },
+ },
+ {
+ name: "trim-too-large",
+ input: "hello world",
+ output: "",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(v.Size() + 1)
+ },
+ },
+ {
+ name: "trim-multiple-buffers",
+ input: strings.Repeat("1", bufferSize*2),
+ output: strings.Repeat("1", bufferSize*2-1),
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(1)
+ },
+ },
+ {
+ name: "trim-multiple-buffers-to-one-buffer",
+ input: strings.Repeat("1", bufferSize*2),
+ output: "1",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(bufferSize*2 - 1)
+ },
+ },
+
+ // Grow.
+ {
+ name: "grow",
+ input: "hello world",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ v.Grow(1, true)
+ },
+ },
+ {
+ name: "grow-from-zero",
+ output: strings.Repeat("\x00", 1024),
+ op: func(t *testing.T, v *View) {
+ v.Grow(1024, true)
+ },
+ },
+ {
+ name: "grow-from-non-zero",
+ input: strings.Repeat("1", bufferSize),
+ output: strings.Repeat("1", bufferSize) + strings.Repeat("\x00", bufferSize),
+ op: func(t *testing.T, v *View) {
+ v.Grow(bufferSize*2, true)
+ },
+ },
+
+ // Copy.
+ {
+ name: "copy",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) {
+ other := v.Copy()
+ bs := other.Flatten()
+ want := []byte("hello")
+ if !bytes.Equal(bs, want) {
+ t.Errorf("expected %v, got %v", want, bs)
+ }
+ },
+ },
+ {
+ name: "copy-large",
+ input: strings.Repeat("1", bufferSize+1),
+ output: strings.Repeat("1", bufferSize+1),
+ op: func(t *testing.T, v *View) {
+ other := v.Copy()
+ bs := other.Flatten()
+ want := []byte(strings.Repeat("1", bufferSize+1))
+ if !bytes.Equal(bs, want) {
+ t.Errorf("expected %v, got %v", want, bs)
+ }
+ },
+ },
+
+ // Merge.
+ {
+ name: "merge",
+ input: "hello",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ var other View
+ other.Append([]byte(" world"))
+ v.Merge(&other)
+ if sz := other.Size(); sz != 0 {
+ t.Errorf("expected 0, got %d", sz)
+ }
+ },
+ },
+ {
+ name: "merge-large",
+ input: strings.Repeat("1", bufferSize+1),
+ output: strings.Repeat("1", bufferSize+1) + strings.Repeat("0", bufferSize+1),
+ op: func(t *testing.T, v *View) {
+ var other View
+ other.Append([]byte(strings.Repeat("0", bufferSize+1)))
+ v.Merge(&other)
+ if sz := other.Size(); sz != 0 {
+ t.Errorf("expected 0, got %d", sz)
+ }
+ },
+ },
+
+ // ReadAt.
+ {
+ name: "readat",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 6, "hello", io.EOF) },
+ },
+ {
+ name: "readat-long",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 8, "hello", io.EOF) },
+ },
+ {
+ name: "readat-short",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 3, "hel", nil) },
+ },
+ {
+ name: "readat-offset",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 3, "llo", io.EOF) },
+ },
+ {
+ name: "readat-long-offset",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 8, "llo", io.EOF) },
+ },
+ {
+ name: "readat-short-offset",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 2, "ll", nil) },
+ },
+ {
+ name: "readat-skip-all",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 1, "", io.EOF) },
+ },
+ {
+ name: "readat-second-buffer",
+ input: strings.Repeat("0", bufferSize+1) + "12",
+ output: strings.Repeat("0", bufferSize+1) + "12",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 1, "1", nil) },
+ },
+ {
+ name: "readat-second-buffer-end",
+ input: strings.Repeat("0", bufferSize+1) + "12",
+ output: strings.Repeat("0", bufferSize+1) + "12",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 2, "12", io.EOF) },
+ },
+ }
+
+ for _, tc := range testCases {
+ for fillName, fn := range fillFuncs {
+ t.Run(fillName+"/"+tc.name, func(t *testing.T) {
+ // Construct & fill the view.
+ var view View
+ fn(&view, []byte(tc.input))
+
+ // Run the operation.
+ if tc.op != nil {
+ tc.op(t, &view)
+ }
+
+ // Flatten and validate.
+ out := view.Flatten()
+ if !bytes.Equal([]byte(tc.output), out) {
+ t.Errorf("expected %q, got %q", tc.output, string(out))
+ }
+
+ // Ensure the size is correct.
+ if len(out) != int(view.Size()) {
+ t.Errorf("size is wrong: expected %d, got %d", len(out), view.Size())
+ }
+
+ // Calculate contents via apply.
+ var appliedOut []byte
+ view.Apply(func(b []byte) {
+ appliedOut = append(appliedOut, b...)
+ })
+ if len(appliedOut) != len(out) {
+ t.Errorf("expected %d, got %d", len(out), len(appliedOut))
+ }
+ if !bytes.Equal(appliedOut, out) {
+ t.Errorf("expected %v, got %v", out, appliedOut)
+ }
+
+ // Calculate contents via ReadToWriter.
+ var b bytes.Buffer
+ n, err := view.ReadToWriter(&b, int64(len(out)))
+ if n != int64(len(out)) {
+ t.Errorf("expected %d, got %d", len(out), n)
+ }
+ if err != nil {
+ t.Errorf("expected nil, got %v", err)
+ }
+ if !bytes.Equal(b.Bytes(), out) {
+ t.Errorf("expected %v, got %v", out, b.Bytes())
+ }
+ })
+ }
+ }
+}
diff --git a/pkg/buffer/view_unsafe.go b/pkg/buffer/view_unsafe.go
new file mode 100644
index 000000000..d1ef39b26
--- /dev/null
+++ b/pkg/buffer/view_unsafe.go
@@ -0,0 +1,25 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "unsafe"
+)
+
+// minBatch is the smallest Read or Write operation that the
+// WriteFromReader and ReadToWriter functions will use.
+//
+// This is defined as the size of a native pointer.
+const minBatch = int(unsafe.Sizeof(uintptr(0)))
diff --git a/pkg/gate/gate_test.go b/pkg/gate/gate_test.go
index 850693df8..316015e06 100644
--- a/pkg/gate/gate_test.go
+++ b/pkg/gate/gate_test.go
@@ -15,6 +15,7 @@
package gate_test
import (
+ "runtime"
"testing"
"time"
@@ -165,6 +166,8 @@ func worker(g *gate.Gate, done *sync.WaitGroup) {
if !g.Enter() {
break
}
+ // Golang before v1.14 doesn't preempt busyloops.
+ runtime.Gosched()
g.Leave()
}
done.Done()
diff --git a/pkg/ilist/list.go b/pkg/ilist/list.go
index 019caadca..8f93e4d6d 100644
--- a/pkg/ilist/list.go
+++ b/pkg/ilist/list.go
@@ -88,8 +88,9 @@ func (l *List) Back() Element {
// PushFront inserts the element e at the front of list l.
func (l *List) PushFront(e Element) {
- ElementMapper{}.linkerFor(e).SetNext(l.head)
- ElementMapper{}.linkerFor(e).SetPrev(nil)
+ linker := ElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
if l.head != nil {
ElementMapper{}.linkerFor(l.head).SetPrev(e)
@@ -102,8 +103,9 @@ func (l *List) PushFront(e Element) {
// PushBack inserts the element e at the back of list l.
func (l *List) PushBack(e Element) {
- ElementMapper{}.linkerFor(e).SetNext(nil)
- ElementMapper{}.linkerFor(e).SetPrev(l.tail)
+ linker := ElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
if l.tail != nil {
ElementMapper{}.linkerFor(l.tail).SetNext(e)
@@ -132,10 +134,14 @@ func (l *List) PushBackList(m *List) {
// InsertAfter inserts e after b.
func (l *List) InsertAfter(b, e Element) {
- a := ElementMapper{}.linkerFor(b).Next()
- ElementMapper{}.linkerFor(e).SetNext(a)
- ElementMapper{}.linkerFor(e).SetPrev(b)
- ElementMapper{}.linkerFor(b).SetNext(e)
+ bLinker := ElementMapper{}.linkerFor(b)
+ eLinker := ElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
if a != nil {
ElementMapper{}.linkerFor(a).SetPrev(e)
@@ -146,10 +152,13 @@ func (l *List) InsertAfter(b, e Element) {
// InsertBefore inserts e before a.
func (l *List) InsertBefore(a, e Element) {
- b := ElementMapper{}.linkerFor(a).Prev()
- ElementMapper{}.linkerFor(e).SetNext(a)
- ElementMapper{}.linkerFor(e).SetPrev(b)
- ElementMapper{}.linkerFor(a).SetPrev(e)
+ aLinker := ElementMapper{}.linkerFor(a)
+ eLinker := ElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
if b != nil {
ElementMapper{}.linkerFor(b).SetNext(e)
@@ -160,8 +169,9 @@ func (l *List) InsertBefore(a, e Element) {
// Remove removes e from l.
func (l *List) Remove(e Element) {
- prev := ElementMapper{}.linkerFor(e).Prev()
- next := ElementMapper{}.linkerFor(e).Next()
+ linker := ElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
if prev != nil {
ElementMapper{}.linkerFor(prev).SetNext(next)
@@ -174,6 +184,9 @@ func (l *List) Remove(e Element) {
} else {
l.tail = prev
}
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
}
// Entry is a default implementation of Linker. Users can add anonymous fields
diff --git a/pkg/rand/rand_linux.go b/pkg/rand/rand_linux.go
index 0bdad5fad..fa6a21026 100644
--- a/pkg/rand/rand_linux.go
+++ b/pkg/rand/rand_linux.go
@@ -17,6 +17,7 @@
package rand
import (
+ "bufio"
"crypto/rand"
"io"
@@ -45,8 +46,22 @@ func (r *reader) Read(p []byte) (int, error) {
return rand.Read(p)
}
+// bufferedReader implements a threadsafe buffered io.Reader.
+type bufferedReader struct {
+ mu sync.Mutex
+ r *bufio.Reader
+}
+
+// Read implements io.Reader.Read.
+func (b *bufferedReader) Read(p []byte) (int, error) {
+ b.mu.Lock()
+ n, err := b.r.Read(p)
+ b.mu.Unlock()
+ return n, err
+}
+
// Reader is the default reader.
-var Reader io.Reader = &reader{}
+var Reader io.Reader = &bufferedReader{r: bufio.NewReader(&reader{})}
// Read reads from the default reader.
func Read(b []byte) (int, error) {
diff --git a/pkg/safemem/seq_test.go b/pkg/safemem/seq_test.go
index eba4bb535..de34005e9 100644
--- a/pkg/safemem/seq_test.go
+++ b/pkg/safemem/seq_test.go
@@ -20,6 +20,27 @@ import (
"testing"
)
+func TestBlockSeqOfEmptyBlock(t *testing.T) {
+ bs := BlockSeqOf(Block{})
+ if !bs.IsEmpty() {
+ t.Errorf("BlockSeqOf(Block{}).IsEmpty(): got false, wanted true; BlockSeq is %v", bs)
+ }
+}
+
+func TestBlockSeqOfNonemptyBlock(t *testing.T) {
+ b := BlockFromSafeSlice(make([]byte, 1))
+ bs := BlockSeqOf(b)
+ if bs.IsEmpty() {
+ t.Fatalf("BlockSeqOf(non-empty Block).IsEmpty(): got true, wanted false; BlockSeq is %v", bs)
+ }
+ if head := bs.Head(); head != b {
+ t.Fatalf("BlockSeqOf(non-empty Block).Head(): got %v, wanted %v", head, b)
+ }
+ if tail := bs.Tail(); !tail.IsEmpty() {
+ t.Fatalf("BlockSeqOf(non-empty Block).Tail().IsEmpty(): got false, wanted true: tail is %v", tail)
+ }
+}
+
type blockSeqTest struct {
desc string
diff --git a/pkg/safemem/seq_unsafe.go b/pkg/safemem/seq_unsafe.go
index dcdfc9600..f5f0574f8 100644
--- a/pkg/safemem/seq_unsafe.go
+++ b/pkg/safemem/seq_unsafe.go
@@ -56,6 +56,9 @@ type BlockSeq struct {
// BlockSeqOf returns a BlockSeq representing the single Block b.
func BlockSeqOf(b Block) BlockSeq {
+ if b.length == 0 {
+ return BlockSeq{}
+ }
bs := BlockSeq{
data: b.start,
length: -1,
diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go
index da5a5e4b2..88766f33b 100644
--- a/pkg/seccomp/seccomp_test.go
+++ b/pkg/seccomp/seccomp_test.go
@@ -451,7 +451,7 @@ func TestRandom(t *testing.T) {
}
}
- fmt.Printf("Testing filters: %v", syscallRules)
+ t.Logf("Testing filters: %v", syscallRules)
instrs, err := BuildProgram([]RuleSet{
RuleSet{
Rules: syscallRules,
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
index d7794db63..c29e1b841 100644
--- a/pkg/sentry/arch/arch_aarch64.go
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -32,6 +32,12 @@ import (
const (
// SyscallWidth is the width of insturctions.
SyscallWidth = 4
+
+ // fpsimdMagic is the magic number which is used in fpsimd_context.
+ fpsimdMagic = 0x46508001
+
+ // fpsimdContextSize is the size of fpsimd_context.
+ fpsimdContextSize = 0x210
)
// ARMTrapFlag is the mask for the trap flag.
@@ -40,24 +46,24 @@ const ARMTrapFlag = uint64(1) << 21
// aarch64FPState is aarch64 floating point state.
type aarch64FPState []byte
-// initAarch64FPState (defined in asm files) sets up initial state.
-func initAarch64FPState(data *FloatingPointData) {
- // TODO(gvisor.dev/issue/1238): floating-point is not supported.
+// initAarch64FPState sets up initial state.
+func initAarch64FPState(data aarch64FPState) {
+ binary.LittleEndian.PutUint32(data, fpsimdMagic)
+ binary.LittleEndian.PutUint32(data[4:], fpsimdContextSize)
}
func newAarch64FPStateSlice() []byte {
- return alignedBytes(4096, 32)[:4096]
+ return alignedBytes(4096, 16)[:fpsimdContextSize]
}
// newAarch64FPState returns an initialized floating point state.
//
// The returned state is large enough to store all floating point state
// supported by host, even if the app won't use much of it due to a restricted
-// FeatureSet. Since they may still be able to see state not advertised by
-// CPUID we must ensure it does not contain any sentry state.
+// FeatureSet.
func newAarch64FPState() aarch64FPState {
f := aarch64FPState(newAarch64FPStateSlice())
- initAarch64FPState(f.FloatingPointData())
+ initAarch64FPState(f)
return f
}
@@ -89,8 +95,14 @@ type State struct {
// Our floating point state.
aarch64FPState `state:"wait"`
+ // TLS pointer
+ TPValue uint64
+
// FeatureSet is a pointer to the currently active feature set.
FeatureSet *cpuid.FeatureSet
+
+ // OrigR0 stores the value of register R0.
+ OrigR0 uint64
}
// Proto returns a protobuf representation of the system registers in State.
@@ -136,10 +148,12 @@ func (s State) Proto() *rpb.Registers {
// Fork creates and returns an identical copy of the state.
func (s *State) Fork() State {
- // TODO(gvisor.dev/issue/1238): floating-point is not supported.
return State{
- Regs: s.Regs,
- FeatureSet: s.FeatureSet,
+ Regs: s.Regs,
+ aarch64FPState: s.aarch64FPState.fork(),
+ TPValue: s.TPValue,
+ FeatureSet: s.FeatureSet,
+ OrigR0: s.OrigR0,
}
}
@@ -249,6 +263,7 @@ func (s *State) PtraceSetFPRegs(src io.Reader) (int, error) {
const (
_NT_PRSTATUS = 1
_NT_PRFPREG = 2
+ _NT_ARM_TLS = 0x401
)
// PtraceGetRegSet implements Context.PtraceGetRegSet.
@@ -288,8 +303,10 @@ func New(arch Arch, fs *cpuid.FeatureSet) Context {
case ARM64:
return &context64{
State{
- FeatureSet: fs,
+ aarch64FPState: newAarch64FPState(),
+ FeatureSet: fs,
},
+ []aarch64FPState(nil),
}
}
panic(fmt.Sprintf("unknown architecture %v", arch))
diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go
index ac98897b5..db99c5acb 100644
--- a/pkg/sentry/arch/arch_arm64.go
+++ b/pkg/sentry/arch/arch_arm64.go
@@ -53,6 +53,11 @@ const (
preferredPIELoadAddr usermem.Addr = maxAddr64 / 6 * 5
)
+var (
+ // CPUIDInstruction doesn't exist on ARM64.
+ CPUIDInstruction = []byte{}
+)
+
// These constants are selected as heuristics to help make the Platform's
// potentially limited address space conform as closely to Linux as possible.
const (
@@ -68,6 +73,7 @@ const (
// context64 represents an ARM64 context.
type context64 struct {
State
+ sigFPState []aarch64FPState // fpstate to be restored on sigreturn.
}
// Arch implements Context.Arch.
@@ -75,10 +81,19 @@ func (c *context64) Arch() Arch {
return ARM64
}
+func (c *context64) copySigFPState() []aarch64FPState {
+ var sigfps []aarch64FPState
+ for _, s := range c.sigFPState {
+ sigfps = append(sigfps, s.fork())
+ }
+ return sigfps
+}
+
// Fork returns an exact copy of this context.
func (c *context64) Fork() Context {
return &context64{
- State: c.State.Fork(),
+ State: c.State.Fork(),
+ sigFPState: c.copySigFPState(),
}
}
@@ -125,16 +140,17 @@ func (c *context64) SetStack(value uintptr) {
// TLS returns the current TLS pointer.
func (c *context64) TLS() uintptr {
- // TODO(gvisor.dev/issue/1238): TLS is not supported.
- // MRS_TPIDR_EL0
- return 0
+ return uintptr(c.TPValue)
}
// SetTLS sets the current TLS pointer. Returns false if value is invalid.
func (c *context64) SetTLS(value uintptr) bool {
- // TODO(gvisor.dev/issue/1238): TLS is not supported.
- // MSR_TPIDR_EL0
- return false
+ if value >= uintptr(maxAddr64) {
+ return false
+ }
+
+ c.TPValue = uint64(value)
+ return true
}
// SetOldRSeqInterruptedIP implements Context.SetOldRSeqInterruptedIP.
diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go
index 4f4cc46a8..0c1db4b13 100644
--- a/pkg/sentry/arch/signal_arm64.go
+++ b/pkg/sentry/arch/signal_arm64.go
@@ -30,14 +30,29 @@ type SignalContext64 struct {
Sp uint64
Pc uint64
Pstate uint64
- _pad [8]byte // __attribute__((__aligned__(16)))
- Reserved [4096]uint8
+ _pad [8]byte // __attribute__((__aligned__(16)))
+ Fpsimd64 FpsimdContext // size = 528
+ Reserved [3568]uint8
+}
+
+type aarch64Ctx struct {
+ Magic uint32
+ Size uint32
+}
+
+// FpsimdContext is equivalent to struct fpsimd_context on arm64
+// (arch/arm64/include/uapi/asm/sigcontext.h).
+type FpsimdContext struct {
+ Head aarch64Ctx
+ Fpsr uint32
+ Fpcr uint32
+ Vregs [64]uint64 // actually [32]uint128
}
// UContext64 is equivalent to ucontext on arm64(arch/arm64/include/uapi/asm/ucontext.h).
type UContext64 struct {
Flags uint64
- Link *UContext64
+ Link uint64
Stack SignalStack
Sigset linux.SignalSet
// glibc uses a 1024-bit sigset_t
diff --git a/pkg/sentry/arch/syscalls_arm64.go b/pkg/sentry/arch/syscalls_arm64.go
index 00d5ef461..dc13b6124 100644
--- a/pkg/sentry/arch/syscalls_arm64.go
+++ b/pkg/sentry/arch/syscalls_arm64.go
@@ -50,13 +50,21 @@ func (c *context64) SyscallArgs() SyscallArguments {
}
// RestartSyscall implements Context.RestartSyscall.
+// Prepare for system call restart, OrigR0 will be restored to R0.
+// Please see the linux code as reference:
+// arch/arm64/kernel/signal.c:do_signal()
func (c *context64) RestartSyscall() {
c.Regs.Pc -= SyscallWidth
- c.Regs.Regs[8] = uint64(restartSyscallNr)
+ // R0 will be backed up into OrigR0 when entering doSyscall().
+ // Please see the linux code as reference:
+ // arch/arm64/kernel/syscall.c:el0_svc_common().
+ // Here we restore it back.
+ c.Regs.Regs[0] = uint64(c.OrigR0)
}
// RestartSyscallWithRestartBlock implements Context.RestartSyscallWithRestartBlock.
func (c *context64) RestartSyscallWithRestartBlock() {
c.Regs.Pc -= SyscallWidth
+ c.Regs.Regs[0] = uint64(c.OrigR0)
c.Regs.Regs[8] = uint64(restartSyscallNr)
}
diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go
index 151808911..663e51989 100644
--- a/pkg/sentry/control/pprof.go
+++ b/pkg/sentry/control/pprof.go
@@ -117,9 +117,9 @@ func (p *Profile) HeapProfile(o *ProfileOpts, _ *struct{}) error {
return nil
}
-// Goroutine is an RPC stub which dumps out the stack trace for all running
-// goroutines.
-func (p *Profile) Goroutine(o *ProfileOpts, _ *struct{}) error {
+// GoroutineProfile is an RPC stub which dumps out the stack trace for all
+// running goroutines.
+func (p *Profile) GoroutineProfile(o *ProfileOpts, _ *struct{}) error {
if len(o.FilePayload.Files) < 1 {
return errNoOutput
}
@@ -131,6 +131,34 @@ func (p *Profile) Goroutine(o *ProfileOpts, _ *struct{}) error {
return nil
}
+// BlockProfile is an RPC stub which dumps out the stack trace that led to
+// blocking on synchronization primitives.
+func (p *Profile) BlockProfile(o *ProfileOpts, _ *struct{}) error {
+ if len(o.FilePayload.Files) < 1 {
+ return errNoOutput
+ }
+ output := o.FilePayload.Files[0]
+ defer output.Close()
+ if err := pprof.Lookup("block").WriteTo(output, 0); err != nil {
+ return err
+ }
+ return nil
+}
+
+// MutexProfile is an RPC stub which dumps out the stack trace of holders of
+// contended mutexes.
+func (p *Profile) MutexProfile(o *ProfileOpts, _ *struct{}) error {
+ if len(o.FilePayload.Files) < 1 {
+ return errNoOutput
+ }
+ output := o.FilePayload.Files[0]
+ defer output.Close()
+ if err := pprof.Lookup("mutex").WriteTo(output, 0); err != nil {
+ return err
+ }
+ return nil
+}
+
// StartTrace is an RPC stub which starts collection of an execution trace.
func (p *Profile) StartTrace(o *ProfileOpts, _ *struct{}) error {
if len(o.FilePayload.Files) < 1 {
diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD
index 9b6bb26d0..9379a4d7b 100644
--- a/pkg/sentry/fs/dev/BUILD
+++ b/pkg/sentry/fs/dev/BUILD
@@ -26,6 +26,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/ramfs",
"//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/memmap",
"//pkg/sentry/mm",
diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go
index 7e66c29b0..acbd401a0 100644
--- a/pkg/sentry/fs/dev/dev.go
+++ b/pkg/sentry/fs/dev/dev.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -124,10 +125,12 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
"ptmx": newSymlink(ctx, "pts/ptmx", msrc),
"tty": newCharacterDevice(ctx, newTTYDevice(ctx, fs.RootOwner, 0666), msrc, ttyDevMajor, ttyDevMinor),
+ }
- "net": newDirectory(ctx, map[string]*fs.Inode{
+ if isNetTunSupported(inet.StackFromContext(ctx)) {
+ contents["net"] = newDirectory(ctx, map[string]*fs.Inode{
"tun": newCharacterDevice(ctx, newNetTunDevice(ctx, fs.RootOwner, 0666), msrc, netTunDevMajor, netTunDevMinor),
- }, msrc),
+ }, msrc)
}
iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go
index 755644488..dc7ad075a 100644
--- a/pkg/sentry/fs/dev/net_tun.go
+++ b/pkg/sentry/fs/dev/net_tun.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
"gvisor.dev/gvisor/pkg/syserror"
@@ -168,3 +169,9 @@ func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.Eve
func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) {
fops.device.EventUnregister(e)
}
+
+// isNetTunSupported returns whether /dev/net/tun device is supported for s.
+func isNetTunSupported(s inet.Stack) bool {
+ _, ok := s.(*netstack.Stack)
+ return ok
+}
diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go
index e0b32e1c1..0266a5287 100644
--- a/pkg/sentry/fs/dirent.go
+++ b/pkg/sentry/fs/dirent.go
@@ -17,7 +17,6 @@ package fs
import (
"fmt"
"path"
- "sort"
"sync/atomic"
"syscall"
@@ -121,9 +120,6 @@ type Dirent struct {
// deleted may be set atomically when removed.
deleted int32
- // frozen indicates this entry can't walk to unknown nodes.
- frozen bool
-
// mounted is true if Dirent is a mount point, similar to include/linux/dcache.h:DCACHE_MOUNTED.
mounted bool
@@ -253,8 +249,7 @@ func (d *Dirent) IsNegative() bool {
return d.Inode == nil
}
-// hashChild will hash child into the children list of its new parent d, carrying over
-// any "frozen" state from d.
+// hashChild will hash child into the children list of its new parent d.
//
// Returns (*WeakRef, true) if hashing child caused a Dirent to be unhashed. The caller must
// validate the returned unhashed weak reference. Common cases:
@@ -282,9 +277,6 @@ func (d *Dirent) hashChild(child *Dirent) (*refs.WeakRef, bool) {
d.IncRef()
}
- // Carry over parent's frozen state.
- child.frozen = d.frozen
-
return d.hashChildParentSet(child)
}
@@ -400,38 +392,6 @@ func (d *Dirent) MountRoot() *Dirent {
return mountRoot
}
-// Freeze prevents this dirent from walking to more nodes. Freeze is applied
-// recursively to all children.
-//
-// If this particular Dirent represents a Virtual node, then Walks and Creates
-// may proceed as before.
-//
-// Freeze can only be called before the application starts running, otherwise
-// the root it might be out of sync with the application root if modified by
-// sys_chroot.
-func (d *Dirent) Freeze() {
- d.mu.Lock()
- defer d.mu.Unlock()
- if d.frozen {
- // Already frozen.
- return
- }
- d.frozen = true
-
- // Take a reference when freezing.
- for _, w := range d.children {
- if child := w.Get(); child != nil {
- // NOTE: We would normally drop the reference here. But
- // instead we're hanging on to it.
- ch := child.(*Dirent)
- ch.Freeze()
- }
- }
-
- // Drop all expired weak references.
- d.flush()
-}
-
// descendantOf returns true if the receiver dirent is equal to, or a
// descendant of, the argument dirent.
//
@@ -524,11 +484,6 @@ func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnl
w.Drop()
}
- // Are we allowed to do the lookup?
- if d.frozen && !d.Inode.IsVirtual() {
- return nil, syscall.ENOENT
- }
-
// Slow path: load the InodeOperations into memory. Since this is a hot path and the lookup may be
// expensive, if possible release the lock and re-acquire it.
if walkMayUnlock {
@@ -659,11 +614,6 @@ func (d *Dirent) Create(ctx context.Context, root *Dirent, name string, flags Fi
return nil, syscall.EEXIST
}
- // Are we frozen?
- if d.frozen && !d.Inode.IsVirtual() {
- return nil, syscall.ENOENT
- }
-
// Try the create. We need to trust the file system to return EEXIST (or something
// that will translate to EEXIST) if name already exists.
file, err := d.Inode.Create(ctx, d, name, flags, perms)
@@ -727,11 +677,6 @@ func (d *Dirent) genericCreate(ctx context.Context, root *Dirent, name string, c
return syscall.EEXIST
}
- // Are we frozen?
- if d.frozen && !d.Inode.IsVirtual() {
- return syscall.ENOENT
- }
-
// Remove any negative Dirent. We've already asserted above with d.exists
// that the only thing remaining here can be a negative Dirent.
if w, ok := d.children[name]; ok {
@@ -862,49 +807,6 @@ func (d *Dirent) GetDotAttrs(root *Dirent) (DentAttr, DentAttr) {
return dot, dot
}
-// readdirFrozen returns readdir results based solely on the frozen children.
-func (d *Dirent) readdirFrozen(root *Dirent, offset int64, dirCtx *DirCtx) (int64, error) {
- // Collect attrs for "." and "..".
- attrs := make(map[string]DentAttr)
- names := []string{".", ".."}
- attrs["."], attrs[".."] = d.GetDotAttrs(root)
-
- // Get info from all children.
- d.mu.Lock()
- defer d.mu.Unlock()
- for name, w := range d.children {
- if child := w.Get(); child != nil {
- defer child.DecRef()
-
- // Skip negative children.
- if child.(*Dirent).IsNegative() {
- continue
- }
-
- sattr := child.(*Dirent).Inode.StableAttr
- attrs[name] = DentAttr{
- Type: sattr.Type,
- InodeID: sattr.InodeID,
- }
- names = append(names, name)
- }
- }
-
- sort.Strings(names)
-
- if int(offset) >= len(names) {
- return offset, nil
- }
- names = names[int(offset):]
- for _, name := range names {
- if err := dirCtx.DirEmit(name, attrs[name]); err != nil {
- return offset, err
- }
- offset++
- }
- return offset, nil
-}
-
// DirIterator is an open directory containing directory entries that can be read.
type DirIterator interface {
// IterateDir emits directory entries by calling dirCtx.EmitDir, beginning
@@ -964,10 +866,6 @@ func direntReaddir(ctx context.Context, d *Dirent, it DirIterator, root *Dirent,
return offset, nil
}
- if d.frozen {
- return d.readdirFrozen(root, offset, dirCtx)
- }
-
// Collect attrs for "." and "..".
dot, dotdot := d.GetDotAttrs(root)
@@ -1068,11 +966,6 @@ func (d *Dirent) mount(ctx context.Context, inode *Inode) (newChild *Dirent, err
return nil, syserror.EINVAL
}
- // Are we frozen?
- if d.parent.frozen && !d.parent.Inode.IsVirtual() {
- return nil, syserror.ENOENT
- }
-
// Dirent that'll replace d.
//
// Note that NewDirent returns with one reference taken; the reference
@@ -1101,11 +994,6 @@ func (d *Dirent) unmount(ctx context.Context, replacement *Dirent) error {
return syserror.ENOENT
}
- // Are we frozen?
- if d.parent.frozen && !d.parent.Inode.IsVirtual() {
- return syserror.ENOENT
- }
-
// Remount our former child in its place.
//
// As replacement used to be our child, it must already have the right
@@ -1135,11 +1023,6 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath
unlock := d.lockDirectory()
defer unlock()
- // Are we frozen?
- if d.frozen && !d.Inode.IsVirtual() {
- return syscall.ENOENT
- }
-
// Try to walk to the node.
child, err := d.walk(ctx, root, name, false /* may unlock */)
if err != nil {
@@ -1201,11 +1084,6 @@ func (d *Dirent) RemoveDirectory(ctx context.Context, root *Dirent, name string)
unlock := d.lockDirectory()
defer unlock()
- // Are we frozen?
- if d.frozen && !d.Inode.IsVirtual() {
- return syscall.ENOENT
- }
-
// Check for dots.
if name == "." {
// Rejected as the last component by rmdir(2).
@@ -1519,15 +1397,6 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
return err
}
- // Are we frozen?
- // TODO(jamieliu): Is this the right errno?
- if oldParent.frozen && !oldParent.Inode.IsVirtual() {
- return syscall.ENOENT
- }
- if newParent.frozen && !newParent.Inode.IsVirtual() {
- return syscall.ENOENT
- }
-
// Do we have general permission to remove from oldParent and
// create/replace in newParent?
if err := oldParent.Inode.CheckPermission(ctx, PermMask{Write: true, Execute: true}); err != nil {
diff --git a/pkg/sentry/fs/dirent_cache.go b/pkg/sentry/fs/dirent_cache.go
index 25514ace4..33de32c69 100644
--- a/pkg/sentry/fs/dirent_cache.go
+++ b/pkg/sentry/fs/dirent_cache.go
@@ -101,8 +101,6 @@ func (c *DirentCache) remove(d *Dirent) {
panic(fmt.Sprintf("trying to remove %v, which is not in the dirent cache", d))
}
c.list.Remove(d)
- d.SetPrev(nil)
- d.SetNext(nil)
d.DecRef()
c.currentSize--
if c.limit != nil {
diff --git a/pkg/sentry/fs/file_overlay_test.go b/pkg/sentry/fs/file_overlay_test.go
index a76d87e3a..1971cc680 100644
--- a/pkg/sentry/fs/file_overlay_test.go
+++ b/pkg/sentry/fs/file_overlay_test.go
@@ -175,90 +175,6 @@ func TestReaddirRevalidation(t *testing.T) {
}
}
-// TestReaddirOverlayFrozen tests that calling Readdir on an overlay file with
-// a frozen dirent tree does not make Readdir calls to the underlying files.
-// This is a regression test for b/114808269.
-func TestReaddirOverlayFrozen(t *testing.T) {
- ctx := contexttest.Context(t)
-
- // Create an overlay with two directories, each with two files.
- upper := newTestRamfsDir(ctx, []dirContent{{name: "upper-file1"}, {name: "upper-file2"}}, nil)
- lower := newTestRamfsDir(ctx, []dirContent{{name: "lower-file1"}, {name: "lower-file2"}}, nil)
- overlayInode := fs.NewTestOverlayDir(ctx, upper, lower, false)
-
- // Set that overlay as the root.
- root := fs.NewDirent(ctx, overlayInode, "root")
- ctx = &rootContext{
- Context: ctx,
- root: root,
- }
-
- // Check that calling Readdir on the root now returns all 4 files (2
- // from each layer in the overlay).
- rootFile, err := root.Inode.GetFile(ctx, root, fs.FileFlags{Read: true})
- if err != nil {
- t.Fatalf("root.Inode.GetFile failed: %v", err)
- }
- defer rootFile.DecRef()
- ser := &fs.CollectEntriesSerializer{}
- if err := rootFile.Readdir(ctx, ser); err != nil {
- t.Fatalf("rootFile.Readdir failed: %v", err)
- }
- if got, want := ser.Order, []string{".", "..", "lower-file1", "lower-file2", "upper-file1", "upper-file2"}; !reflect.DeepEqual(got, want) {
- t.Errorf("Readdir got names %v, want %v", got, want)
- }
-
- // Readdir should have been called on upper and lower.
- upperDir := upper.InodeOperations.(*dir)
- lowerDir := lower.InodeOperations.(*dir)
- if !upperDir.ReaddirCalled {
- t.Errorf("upperDir.ReaddirCalled got %v, want true", upperDir.ReaddirCalled)
- }
- if !lowerDir.ReaddirCalled {
- t.Errorf("lowerDir.ReaddirCalled got %v, want true", lowerDir.ReaddirCalled)
- }
-
- // Reset.
- upperDir.ReaddirCalled = false
- lowerDir.ReaddirCalled = false
-
- // Take references on "upper-file1" and "lower-file1", pinning them in
- // the dirent tree.
- for _, name := range []string{"upper-file1", "lower-file1"} {
- if _, err := root.Walk(ctx, root, name); err != nil {
- t.Fatalf("root.Walk(%q) failed: %v", name, err)
- }
- // Don't drop a reference on the returned dirent so that it
- // will stay in the tree.
- }
-
- // Freeze the dirent tree.
- root.Freeze()
-
- // Seek back to the beginning of the file.
- if _, err := rootFile.Seek(ctx, fs.SeekSet, 0); err != nil {
- t.Fatalf("error seeking to beginning of directory: %v", err)
- }
-
- // Calling Readdir on the root now will return only the pinned
- // children.
- ser = &fs.CollectEntriesSerializer{}
- if err := rootFile.Readdir(ctx, ser); err != nil {
- t.Fatalf("rootFile.Readdir failed: %v", err)
- }
- if got, want := ser.Order, []string{".", "..", "lower-file1", "upper-file1"}; !reflect.DeepEqual(got, want) {
- t.Errorf("Readdir got names %v, want %v", got, want)
- }
-
- // Readdir should NOT have been called on upper or lower.
- if upperDir.ReaddirCalled {
- t.Errorf("upperDir.ReaddirCalled got %v, want false", upperDir.ReaddirCalled)
- }
- if lowerDir.ReaddirCalled {
- t.Errorf("lowerDir.ReaddirCalled got %v, want false", lowerDir.ReaddirCalled)
- }
-}
-
type rootContext struct {
context.Context
root *fs.Dirent
diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go
index daecc4ffe..1922ff08c 100644
--- a/pkg/sentry/fs/fsutil/inode.go
+++ b/pkg/sentry/fs/fsutil/inode.go
@@ -259,8 +259,8 @@ func (i *InodeSimpleExtendedAttributes) ListXattr(context.Context, *fs.Inode, ui
// RemoveXattr implements fs.InodeOperations.RemoveXattr.
func (i *InodeSimpleExtendedAttributes) RemoveXattr(_ context.Context, _ *fs.Inode, name string) error {
- i.mu.RLock()
- defer i.mu.RUnlock()
+ i.mu.Lock()
+ defer i.mu.Unlock()
if _, ok := i.xattrs[name]; ok {
delete(i.xattrs, name)
return nil
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index 21003ea45..011625c80 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -10,7 +10,7 @@ go_library(
"descriptor_state.go",
"device.go",
"file.go",
- "fs.go",
+ "host.go",
"inode.go",
"inode_state.go",
"ioctl_unsafe.go",
@@ -62,14 +62,12 @@ go_test(
size = "small",
srcs = [
"descriptor_test.go",
- "fs_test.go",
"inode_test.go",
"socket_test.go",
"wait_test.go",
],
library = ":host",
deps = [
- "//pkg/context",
"//pkg/fd",
"//pkg/fdnotifier",
"//pkg/sentry/contexttest",
diff --git a/pkg/sentry/fs/host/control.go b/pkg/sentry/fs/host/control.go
index 1658979fc..cd84e1337 100644
--- a/pkg/sentry/fs/host/control.go
+++ b/pkg/sentry/fs/host/control.go
@@ -32,6 +32,8 @@ func newSCMRights(fds []int) control.SCMRights {
}
// Files implements control.SCMRights.Files.
+//
+// TODO(gvisor.dev/issue/2017): Port to VFS2.
func (c *scmRights) Files(ctx context.Context, max int) (control.RightsFiles, bool) {
n := max
var trunc bool
diff --git a/pkg/sentry/fs/host/descriptor.go b/pkg/sentry/fs/host/descriptor.go
index 2a4d1b291..cfdce6a74 100644
--- a/pkg/sentry/fs/host/descriptor.go
+++ b/pkg/sentry/fs/host/descriptor.go
@@ -16,7 +16,6 @@ package host
import (
"fmt"
- "path"
"syscall"
"gvisor.dev/gvisor/pkg/fdnotifier"
@@ -28,12 +27,9 @@ import (
//
// +stateify savable
type descriptor struct {
- // donated is true if the host fd was donated by another process.
- donated bool
-
// If origFD >= 0, it is the host fd that this file was originally created
// from, which must be available at time of restore. The FD can be closed
- // after descriptor is created. Only set if donated is true.
+ // after descriptor is created.
origFD int
// wouldBlock is true if value (below) points to a file that can
@@ -41,15 +37,13 @@ type descriptor struct {
wouldBlock bool
// value is the wrapped host fd. It is never saved or restored
- // directly. How it is restored depends on whether it was
- // donated and the fs.MountSource it was originally
- // opened/created from.
+ // directly.
value int `state:"nosave"`
}
// newDescriptor returns a wrapped host file descriptor. On success,
// the descriptor is registered for event notifications with queue.
-func newDescriptor(fd int, donated bool, saveable bool, wouldBlock bool, queue *waiter.Queue) (*descriptor, error) {
+func newDescriptor(fd int, saveable bool, wouldBlock bool, queue *waiter.Queue) (*descriptor, error) {
ownedFD := fd
origFD := -1
if saveable {
@@ -69,7 +63,6 @@ func newDescriptor(fd int, donated bool, saveable bool, wouldBlock bool, queue *
}
}
return &descriptor{
- donated: donated,
origFD: origFD,
wouldBlock: wouldBlock,
value: ownedFD,
@@ -77,25 +70,11 @@ func newDescriptor(fd int, donated bool, saveable bool, wouldBlock bool, queue *
}
// initAfterLoad initializes the value of the descriptor after Load.
-func (d *descriptor) initAfterLoad(mo *superOperations, id uint64, queue *waiter.Queue) error {
- if d.donated {
- var err error
- d.value, err = syscall.Dup(d.origFD)
- if err != nil {
- return fmt.Errorf("failed to dup restored fd %d: %v", d.origFD, err)
- }
- } else {
- name, ok := mo.inodeMappings[id]
- if !ok {
- return fmt.Errorf("failed to find path for inode number %d", id)
- }
- fullpath := path.Join(mo.root, name)
-
- var err error
- d.value, err = open(nil, fullpath)
- if err != nil {
- return fmt.Errorf("failed to open %q: %v", fullpath, err)
- }
+func (d *descriptor) initAfterLoad(id uint64, queue *waiter.Queue) error {
+ var err error
+ d.value, err = syscall.Dup(d.origFD)
+ if err != nil {
+ return fmt.Errorf("failed to dup restored fd %d: %v", d.origFD, err)
}
if d.wouldBlock {
if err := syscall.SetNonblock(d.value, true); err != nil {
diff --git a/pkg/sentry/fs/host/descriptor_state.go b/pkg/sentry/fs/host/descriptor_state.go
index 8167390a9..e880582ab 100644
--- a/pkg/sentry/fs/host/descriptor_state.go
+++ b/pkg/sentry/fs/host/descriptor_state.go
@@ -16,7 +16,7 @@ package host
// beforeSave is invoked by stateify.
func (d *descriptor) beforeSave() {
- if d.donated && d.origFD < 0 {
+ if d.origFD < 0 {
panic("donated file descriptor cannot be saved")
}
}
diff --git a/pkg/sentry/fs/host/descriptor_test.go b/pkg/sentry/fs/host/descriptor_test.go
index 4205981f5..d8e4605b6 100644
--- a/pkg/sentry/fs/host/descriptor_test.go
+++ b/pkg/sentry/fs/host/descriptor_test.go
@@ -47,10 +47,10 @@ func TestDescriptorRelease(t *testing.T) {
// FD ownership is transferred to the descritor.
queue := &waiter.Queue{}
- d, err := newDescriptor(fd, false /* donated*/, tc.saveable, tc.wouldBlock, queue)
+ d, err := newDescriptor(fd, tc.saveable, tc.wouldBlock, queue)
if err != nil {
syscall.Close(fd)
- t.Fatalf("newDescriptor(%d, %t, false, %t, queue) failed, err: %v", fd, tc.saveable, tc.wouldBlock, err)
+ t.Fatalf("newDescriptor(%d, %t, %t, queue) failed, err: %v", fd, tc.saveable, tc.wouldBlock, err)
}
if tc.saveable {
if d.origFD < 0 {
diff --git a/pkg/sentry/fs/host/file.go b/pkg/sentry/fs/host/file.go
index e08f56d04..034862694 100644
--- a/pkg/sentry/fs/host/file.go
+++ b/pkg/sentry/fs/host/file.go
@@ -101,8 +101,8 @@ func newFileFromDonatedFD(ctx context.Context, donated int, mounter fs.FileOwner
})
return s, nil
default:
- msrc := newMountSource(ctx, "/", mounter, &Filesystem{}, fs.MountSourceFlags{}, false /* dontTranslateOwnership */)
- inode, err := newInode(ctx, msrc, donated, saveable, true /* donated */)
+ msrc := fs.NewNonCachingMountSource(ctx, &filesystem{}, fs.MountSourceFlags{})
+ inode, err := newInode(ctx, msrc, donated, saveable)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/fs/host/fs.go b/pkg/sentry/fs/host/fs.go
deleted file mode 100644
index d3e8e3a36..000000000
--- a/pkg/sentry/fs/host/fs.go
+++ /dev/null
@@ -1,339 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package host implements an fs.Filesystem for files backed by host
-// file descriptors.
-package host
-
-import (
- "fmt"
- "path"
- "path/filepath"
- "strconv"
- "strings"
-
- "gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/fs"
-)
-
-// FilesystemName is the name under which Filesystem is registered.
-const FilesystemName = "whitelistfs"
-
-const (
- // whitelistKey is the mount option containing a comma-separated list
- // of host paths to whitelist.
- whitelistKey = "whitelist"
-
- // rootPathKey is the mount option containing the root path of the
- // mount.
- rootPathKey = "root"
-
- // dontTranslateOwnershipKey is the key to superOperations.dontTranslateOwnership.
- dontTranslateOwnershipKey = "dont_translate_ownership"
-)
-
-// maxTraversals determines link traversals in building the whitelist.
-const maxTraversals = 10
-
-// Filesystem is a pseudo file system that is only available during the setup
-// to lock down the configurations. This filesystem should only be mounted at root.
-//
-// Think twice before exposing this to applications.
-//
-// +stateify savable
-type Filesystem struct {
- // whitelist is a set of host paths to whitelist.
- paths []string
-}
-
-var _ fs.Filesystem = (*Filesystem)(nil)
-
-// Name is the identifier of this file system.
-func (*Filesystem) Name() string {
- return FilesystemName
-}
-
-// AllowUserMount prohibits users from using mount(2) with this file system.
-func (*Filesystem) AllowUserMount() bool {
- return false
-}
-
-// AllowUserList allows this filesystem to be listed in /proc/filesystems.
-func (*Filesystem) AllowUserList() bool {
- return true
-}
-
-// Flags returns that there is nothing special about this file system.
-func (*Filesystem) Flags() fs.FilesystemFlags {
- return 0
-}
-
-// Mount returns an fs.Inode exposing the host file system. It is intended to be locked
-// down in PreExec below.
-func (f *Filesystem) Mount(ctx context.Context, _ string, flags fs.MountSourceFlags, data string, _ interface{}) (*fs.Inode, error) {
- // Parse generic comma-separated key=value options.
- options := fs.GenericMountSourceOptions(data)
-
- // Grab the whitelist if one was specified.
- // TODO(edahlgren/mpratt/hzy): require another option "testonly" in order to allow
- // no whitelist.
- if wl, ok := options[whitelistKey]; ok {
- f.paths = strings.Split(wl, "|")
- delete(options, whitelistKey)
- }
-
- // If the rootPath was set, use it. Othewise default to the root of the
- // host fs.
- rootPath := "/"
- if rp, ok := options[rootPathKey]; ok {
- rootPath = rp
- delete(options, rootPathKey)
-
- // We must relativize the whitelisted paths to the new root.
- for i, p := range f.paths {
- rel, err := filepath.Rel(rootPath, p)
- if err != nil {
- return nil, fmt.Errorf("whitelist path %q must be a child of root path %q", p, rootPath)
- }
- f.paths[i] = path.Join("/", rel)
- }
- }
- fd, err := open(nil, rootPath)
- if err != nil {
- return nil, fmt.Errorf("failed to find root: %v", err)
- }
-
- var dontTranslateOwnership bool
- if v, ok := options[dontTranslateOwnershipKey]; ok {
- b, err := strconv.ParseBool(v)
- if err != nil {
- return nil, fmt.Errorf("invalid value for %q: %v", dontTranslateOwnershipKey, err)
- }
- dontTranslateOwnership = b
- delete(options, dontTranslateOwnershipKey)
- }
-
- // Fail if the caller passed us more options than we know about.
- if len(options) > 0 {
- return nil, fmt.Errorf("unsupported mount options: %v", options)
- }
-
- // The mounting EUID/EGID will be cached by this file system. This will
- // be used to assign ownership to files that we own.
- owner := fs.FileOwnerFromContext(ctx)
-
- // Construct the host file system mount and inode.
- msrc := newMountSource(ctx, rootPath, owner, f, flags, dontTranslateOwnership)
- return newInode(ctx, msrc, fd, false /* saveable */, false /* donated */)
-}
-
-// InstallWhitelist locks down the MountNamespace to only the currently installed
-// Dirents and the given paths.
-func (f *Filesystem) InstallWhitelist(ctx context.Context, m *fs.MountNamespace) error {
- return installWhitelist(ctx, m, f.paths)
-}
-
-func installWhitelist(ctx context.Context, m *fs.MountNamespace, paths []string) error {
- if len(paths) == 0 || (len(paths) == 1 && paths[0] == "") {
- // Warning will be logged during filter installation if the empty
- // whitelist matters (allows for host file access).
- return nil
- }
-
- // Done tracks entries already added.
- done := make(map[string]bool)
- root := m.Root()
- defer root.DecRef()
-
- for i := 0; i < len(paths); i++ {
- // Make sure the path is absolute. This is a sanity check.
- if !path.IsAbs(paths[i]) {
- return fmt.Errorf("path %q is not absolute", paths[i])
- }
-
- // We need to add all the intermediate paths, in case one of
- // them is a symlink that needs to be resolved.
- for j := 1; j <= len(paths[i]); j++ {
- if j < len(paths[i]) && paths[i][j] != '/' {
- continue
- }
- current := paths[i][:j]
-
- // Lookup the given component in the tree.
- remainingTraversals := uint(maxTraversals)
- d, err := m.FindLink(ctx, root, nil, current, &remainingTraversals)
- if err != nil {
- log.Warningf("populate failed for %q: %v", current, err)
- continue
- }
-
- // It's critical that this DecRef happens after the
- // freeze below. This ensures that the dentry is in
- // place to be frozen. Otherwise, we freeze without
- // these entries.
- defer d.DecRef()
-
- // Expand the last component if necessary.
- if current == paths[i] {
- // Is it a directory or symlink?
- sattr := d.Inode.StableAttr
- if fs.IsDir(sattr) {
- for name := range childDentAttrs(ctx, d) {
- paths = append(paths, path.Join(current, name))
- }
- }
- if fs.IsSymlink(sattr) {
- // Only expand symlinks once. The
- // folder structure may contain
- // recursive symlinks and we don't want
- // to end up infinitely expanding this
- // symlink. This is safe because this
- // is the last component. If a later
- // path wants to symlink something
- // beneath this symlink that will still
- // be handled by the FindLink above.
- if done[current] {
- continue
- }
-
- s, err := d.Inode.Readlink(ctx)
- if err != nil {
- log.Warningf("readlink failed for %q: %v", current, err)
- continue
- }
- if path.IsAbs(s) {
- paths = append(paths, s)
- } else {
- target := path.Join(path.Dir(current), s)
- paths = append(paths, target)
- }
- }
- }
-
- // Only report this one once even though we may look
- // it up more than once. If we whitelist /a/b,/a then
- // /a will be "done" when it is looked up for /a/b,
- // however we still need to expand all of its contents
- // when whitelisting /a.
- if !done[current] {
- log.Debugf("whitelisted: %s", current)
- }
- done[current] = true
- }
- }
-
- // Freeze the mount tree in place. This prevents any new paths from
- // being opened and any old ones from being removed. If we do provide
- // tmpfs mounts, we'll want to freeze/thaw those separately.
- m.Freeze()
- return nil
-}
-
-func childDentAttrs(ctx context.Context, d *fs.Dirent) map[string]fs.DentAttr {
- dirname, _ := d.FullName(nil /* root */)
- dir, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true})
- if err != nil {
- log.Warningf("failed to open directory %q: %v", dirname, err)
- return nil
- }
- dir.DecRef()
- var stubSerializer fs.CollectEntriesSerializer
- if err := dir.Readdir(ctx, &stubSerializer); err != nil {
- log.Warningf("failed to iterate on host directory %q: %v", dirname, err)
- return nil
- }
- delete(stubSerializer.Entries, ".")
- delete(stubSerializer.Entries, "..")
- return stubSerializer.Entries
-}
-
-// newMountSource constructs a new host fs.MountSource
-// relative to a root path. The root should match the mount point.
-func newMountSource(ctx context.Context, root string, mounter fs.FileOwner, filesystem fs.Filesystem, flags fs.MountSourceFlags, dontTranslateOwnership bool) *fs.MountSource {
- return fs.NewMountSource(ctx, &superOperations{
- root: root,
- inodeMappings: make(map[uint64]string),
- mounter: mounter,
- dontTranslateOwnership: dontTranslateOwnership,
- }, filesystem, flags)
-}
-
-// superOperations implements fs.MountSourceOperations.
-//
-// +stateify savable
-type superOperations struct {
- fs.SimpleMountSourceOperations
-
- // root is the path of the mount point. All inode mappings
- // are relative to this root.
- root string
-
- // inodeMappings contains mappings of fs.Inodes associated
- // with this MountSource to paths under root.
- inodeMappings map[uint64]string
-
- // mounter is the cached EUID/EGID that mounted this file system.
- mounter fs.FileOwner
-
- // dontTranslateOwnership indicates whether to not translate file
- // ownership.
- //
- // By default, files/directories owned by the sandbox uses UID/GID
- // of the mounter. For files/directories that are not owned by the
- // sandbox, file UID/GID is translated to a UID/GID which cannot
- // be mapped in the sandboxed application's user namespace. The
- // UID/GID will look like the nobody UID/GID (65534) but is not
- // strictly owned by the user "nobody".
- //
- // If whitelistfs is a lower filesystem in an overlay, set
- // dont_translate_ownership=true in mount options.
- dontTranslateOwnership bool
-}
-
-var _ fs.MountSourceOperations = (*superOperations)(nil)
-
-// ResetInodeMappings implements fs.MountSourceOperations.ResetInodeMappings.
-func (m *superOperations) ResetInodeMappings() {
- m.inodeMappings = make(map[uint64]string)
-}
-
-// SaveInodeMapping implements fs.MountSourceOperations.SaveInodeMapping.
-func (m *superOperations) SaveInodeMapping(inode *fs.Inode, path string) {
- // This is very unintuitive. We *CANNOT* trust the inode's StableAttrs,
- // because overlay copyUp may have changed them out from under us.
- // So much for "immutable".
- sattr := inode.InodeOperations.(*inodeOperations).fileState.sattr
- m.inodeMappings[sattr.InodeID] = path
-}
-
-// Keep implements fs.MountSourceOperations.Keep.
-//
-// TODO(b/72455313,b/77596690): It is possible to change the permissions on a
-// host file while it is in the dirent cache (say from RO to RW), but it is not
-// possible to re-open the file with more relaxed permissions, since the host
-// FD is already open and stored in the inode.
-//
-// Using the dirent LRU cache increases the odds that this bug is encountered.
-// Since host file access is relatively fast anyways, we disable the LRU cache
-// for host fs files. Once we can properly deal with permissions changes and
-// re-opening host files, we should revisit whether or not to make use of the
-// LRU cache.
-func (*superOperations) Keep(*fs.Dirent) bool {
- return false
-}
-
-func init() {
- fs.RegisterFilesystem(&Filesystem{})
-}
diff --git a/pkg/sentry/fs/host/fs_test.go b/pkg/sentry/fs/host/fs_test.go
deleted file mode 100644
index 3111d2df9..000000000
--- a/pkg/sentry/fs/host/fs_test.go
+++ /dev/null
@@ -1,380 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package host
-
-import (
- "fmt"
- "io/ioutil"
- "os"
- "path"
- "reflect"
- "sort"
- "testing"
-
- "gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/sentry/contexttest"
- "gvisor.dev/gvisor/pkg/sentry/fs"
-)
-
-// newTestMountNamespace creates a MountNamespace with a ramfs root.
-// It returns the host folder created, which should be removed when done.
-func newTestMountNamespace(t *testing.T) (*fs.MountNamespace, string, error) {
- p, err := ioutil.TempDir("", "root")
- if err != nil {
- return nil, "", err
- }
-
- fd, err := open(nil, p)
- if err != nil {
- os.RemoveAll(p)
- return nil, "", err
- }
- ctx := contexttest.Context(t)
- root, err := newInode(ctx, newMountSource(ctx, p, fs.RootOwner, &Filesystem{}, fs.MountSourceFlags{}, false), fd, false, false)
- if err != nil {
- os.RemoveAll(p)
- return nil, "", err
- }
- mm, err := fs.NewMountNamespace(ctx, root)
- if err != nil {
- os.RemoveAll(p)
- return nil, "", err
- }
- return mm, p, nil
-}
-
-// createTestDirs populates the root with some test files and directories.
-// /a/a1.txt
-// /a/a2.txt
-// /b/b1.txt
-// /b/c/c1.txt
-// /symlinks/normal.txt
-// /symlinks/to_normal.txt -> /symlinks/normal.txt
-// /symlinks/recursive -> /symlinks
-func createTestDirs(ctx context.Context, t *testing.T, m *fs.MountNamespace) error {
- r := m.Root()
- defer r.DecRef()
-
- if err := r.CreateDirectory(ctx, r, "a", fs.FilePermsFromMode(0777)); err != nil {
- return err
- }
-
- a, err := r.Walk(ctx, r, "a")
- if err != nil {
- return err
- }
- defer a.DecRef()
-
- a1, err := a.Create(ctx, r, "a1.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666))
- if err != nil {
- return err
- }
- a1.DecRef()
-
- a2, err := a.Create(ctx, r, "a2.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666))
- if err != nil {
- return err
- }
- a2.DecRef()
-
- if err := r.CreateDirectory(ctx, r, "b", fs.FilePermsFromMode(0777)); err != nil {
- return err
- }
-
- b, err := r.Walk(ctx, r, "b")
- if err != nil {
- return err
- }
- defer b.DecRef()
-
- b1, err := b.Create(ctx, r, "b1.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666))
- if err != nil {
- return err
- }
- b1.DecRef()
-
- if err := b.CreateDirectory(ctx, r, "c", fs.FilePermsFromMode(0777)); err != nil {
- return err
- }
-
- c, err := b.Walk(ctx, r, "c")
- if err != nil {
- return err
- }
- defer c.DecRef()
-
- c1, err := c.Create(ctx, r, "c1.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666))
- if err != nil {
- return err
- }
- c1.DecRef()
-
- if err := r.CreateDirectory(ctx, r, "symlinks", fs.FilePermsFromMode(0777)); err != nil {
- return err
- }
-
- symlinks, err := r.Walk(ctx, r, "symlinks")
- if err != nil {
- return err
- }
- defer symlinks.DecRef()
-
- normal, err := symlinks.Create(ctx, r, "normal.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666))
- if err != nil {
- return err
- }
- normal.DecRef()
-
- if err := symlinks.CreateLink(ctx, r, "/symlinks/normal.txt", "to_normal.txt"); err != nil {
- return err
- }
-
- return symlinks.CreateLink(ctx, r, "/symlinks", "recursive")
-}
-
-// allPaths returns a slice of all paths of entries visible in the rootfs.
-func allPaths(ctx context.Context, t *testing.T, m *fs.MountNamespace, base string) ([]string, error) {
- var paths []string
- root := m.Root()
- defer root.DecRef()
-
- maxTraversals := uint(1)
- d, err := m.FindLink(ctx, root, nil, base, &maxTraversals)
- if err != nil {
- t.Logf("FindLink failed for %q", base)
- return paths, err
- }
- defer d.DecRef()
-
- if fs.IsDir(d.Inode.StableAttr) {
- dir, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true})
- if err != nil {
- return nil, fmt.Errorf("failed to open directory %q: %v", base, err)
- }
- iter, ok := dir.FileOperations.(fs.DirIterator)
- if !ok {
- return nil, fmt.Errorf("cannot directly iterate on host directory %q", base)
- }
- dirCtx := &fs.DirCtx{
- Serializer: noopDentrySerializer{},
- }
- if _, err := fs.DirentReaddir(ctx, d, iter, root, dirCtx, 0); err != nil {
- return nil, err
- }
- for name := range dirCtx.DentAttrs() {
- if name == "." || name == ".." {
- continue
- }
-
- fullName := path.Join(base, name)
- paths = append(paths, fullName)
-
- // Recurse.
- subpaths, err := allPaths(ctx, t, m, fullName)
- if err != nil {
- return paths, err
- }
- paths = append(paths, subpaths...)
- }
- }
-
- return paths, nil
-}
-
-type noopDentrySerializer struct{}
-
-func (noopDentrySerializer) CopyOut(string, fs.DentAttr) error {
- return nil
-}
-func (noopDentrySerializer) Written() int {
- return 4096
-}
-
-// pathsEqual returns true if the two string slices contain the same entries.
-func pathsEqual(got, want []string) bool {
- sort.Strings(got)
- sort.Strings(want)
-
- if len(got) != len(want) {
- return false
- }
-
- for i := range got {
- if got[i] != want[i] {
- return false
- }
- }
-
- return true
-}
-
-func TestWhitelist(t *testing.T) {
- for _, test := range []struct {
- // description of the test.
- desc string
- // paths are the paths to whitelist
- paths []string
- // want are all of the directory entries that should be
- // visible (nothing beyond this set should be visible).
- want []string
- }{
- {
- desc: "root",
- paths: []string{"/"},
- want: []string{"/a", "/a/a1.txt", "/a/a2.txt", "/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt", "/symlinks", "/symlinks/normal.txt", "/symlinks/to_normal.txt", "/symlinks/recursive"},
- },
- {
- desc: "top-level directories",
- paths: []string{"/a", "/b"},
- want: []string{"/a", "/a/a1.txt", "/a/a2.txt", "/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt"},
- },
- {
- desc: "nested directories (1/2)",
- paths: []string{"/b", "/b/c"},
- want: []string{"/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt"},
- },
- {
- desc: "nested directories (2/2)",
- paths: []string{"/b/c", "/b"},
- want: []string{"/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt"},
- },
- {
- desc: "single file",
- paths: []string{"/b/c/c1.txt"},
- want: []string{"/b", "/b/c", "/b/c/c1.txt"},
- },
- {
- desc: "single file and directory",
- paths: []string{"/a/a1.txt", "/b/c"},
- want: []string{"/a", "/a/a1.txt", "/b", "/b/c", "/b/c/c1.txt"},
- },
- {
- desc: "symlink",
- paths: []string{"/symlinks/to_normal.txt"},
- want: []string{"/symlinks", "/symlinks/normal.txt", "/symlinks/to_normal.txt"},
- },
- {
- desc: "recursive symlink",
- paths: []string{"/symlinks/recursive/normal.txt"},
- want: []string{"/symlinks", "/symlinks/normal.txt", "/symlinks/recursive"},
- },
- } {
- t.Run(test.desc, func(t *testing.T) {
- m, p, err := newTestMountNamespace(t)
- if err != nil {
- t.Errorf("Failed to create MountNamespace: %v", err)
- }
- defer os.RemoveAll(p)
-
- ctx := withRoot(contexttest.RootContext(t), m.Root())
- if err := createTestDirs(ctx, t, m); err != nil {
- t.Errorf("Failed to create test dirs: %v", err)
- }
-
- if err := installWhitelist(ctx, m, test.paths); err != nil {
- t.Errorf("installWhitelist(%v) err got %v want nil", test.paths, err)
- }
-
- got, err := allPaths(ctx, t, m, "/")
- if err != nil {
- t.Fatalf("Failed to lookup paths (whitelisted: %v): %v", test.paths, err)
- }
-
- if !pathsEqual(got, test.want) {
- t.Errorf("For paths %v got %v want %v", test.paths, got, test.want)
- }
- })
- }
-}
-
-func TestRootPath(t *testing.T) {
- // Create a temp dir, which will be the root of our mounted fs.
- rootPath, err := ioutil.TempDir(os.TempDir(), "root")
- if err != nil {
- t.Fatalf("TempDir failed: %v", err)
- }
- defer os.RemoveAll(rootPath)
-
- // Create two files inside the new root, one which will be whitelisted
- // and one not.
- whitelisted, err := ioutil.TempFile(rootPath, "white")
- if err != nil {
- t.Fatalf("TempFile failed: %v", err)
- }
- if _, err := ioutil.TempFile(rootPath, "black"); err != nil {
- t.Fatalf("TempFile failed: %v", err)
- }
-
- // Create a mount with a root path and single whitelisted file.
- hostFS := &Filesystem{}
- ctx := contexttest.Context(t)
- data := fmt.Sprintf("%s=%s,%s=%s", rootPathKey, rootPath, whitelistKey, whitelisted.Name())
- inode, err := hostFS.Mount(ctx, "", fs.MountSourceFlags{}, data, nil)
- if err != nil {
- t.Fatalf("Mount failed: %v", err)
- }
- mm, err := fs.NewMountNamespace(ctx, inode)
- if err != nil {
- t.Fatalf("NewMountNamespace failed: %v", err)
- }
- if err := hostFS.InstallWhitelist(ctx, mm); err != nil {
- t.Fatalf("InstallWhitelist failed: %v", err)
- }
-
- // Get the contents of the root directory.
- rootDir := mm.Root()
- rctx := withRoot(ctx, rootDir)
- f, err := rootDir.Inode.GetFile(rctx, rootDir, fs.FileFlags{})
- if err != nil {
- t.Fatalf("GetFile failed: %v", err)
- }
- c := &fs.CollectEntriesSerializer{}
- if err := f.Readdir(rctx, c); err != nil {
- t.Fatalf("Readdir failed: %v", err)
- }
-
- // We should have only our whitelisted file, plus the dots.
- want := []string{path.Base(whitelisted.Name()), ".", ".."}
- got := c.Order
- sort.Strings(want)
- sort.Strings(got)
- if !reflect.DeepEqual(got, want) {
- t.Errorf("Readdir got %v, wanted %v", got, want)
- }
-}
-
-type rootContext struct {
- context.Context
- root *fs.Dirent
-}
-
-// withRoot returns a copy of ctx with the given root.
-func withRoot(ctx context.Context, root *fs.Dirent) context.Context {
- return &rootContext{
- Context: ctx,
- root: root,
- }
-}
-
-// Value implements Context.Value.
-func (rc rootContext) Value(key interface{}) interface{} {
- switch key {
- case fs.CtxRoot:
- rc.root.IncRef()
- return rc.root
- default:
- return rc.Context.Value(key)
- }
-}
diff --git a/pkg/sentry/fs/host/host.go b/pkg/sentry/fs/host/host.go
new file mode 100644
index 000000000..081ba1dd8
--- /dev/null
+++ b/pkg/sentry/fs/host/host.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package host supports file descriptors imported directly.
+package host
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// filesystem is a host filesystem.
+//
+// +stateify savable
+type filesystem struct{}
+
+func init() {
+ fs.RegisterFilesystem(&filesystem{})
+}
+
+// FilesystemName is the name under which the filesystem is registered.
+const FilesystemName = "host"
+
+// Name is the name of the filesystem.
+func (*filesystem) Name() string {
+ return FilesystemName
+}
+
+// Mount returns an error. Mounting hostfs is not allowed.
+func (*filesystem) Mount(ctx context.Context, device string, flags fs.MountSourceFlags, data string, dataObj interface{}) (*fs.Inode, error) {
+ return nil, syserror.EPERM
+}
+
+// AllowUserMount prohibits users from using mount(2) with this file system.
+func (*filesystem) AllowUserMount() bool {
+ return false
+}
+
+// AllowUserList prohibits this filesystem to be listed in /proc/filesystems.
+func (*filesystem) AllowUserList() bool {
+ return false
+}
+
+// Flags returns that there is nothing special about this file system.
+func (*filesystem) Flags() fs.FilesystemFlags {
+ return 0
+}
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index 6fa39caab..1da3c0a17 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -17,12 +17,10 @@ package host
import (
"syscall"
- "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/secio"
- "gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -69,9 +67,6 @@ type inodeOperations struct {
//
// +stateify savable
type inodeFileState struct {
- // Common file system state.
- mops *superOperations `state:"wait"`
-
// descriptor is the backing host FD.
descriptor *descriptor `state:"wait"`
@@ -160,7 +155,7 @@ func (i *inodeFileState) unstableAttr(ctx context.Context) (fs.UnstableAttr, err
if err := syscall.Fstat(i.FD(), &s); err != nil {
return fs.UnstableAttr{}, err
}
- return unstableAttr(i.mops, &s), nil
+ return unstableAttr(&s), nil
}
// Allocate implements fsutil.CachedFileObject.Allocate.
@@ -172,7 +167,7 @@ func (i *inodeFileState) Allocate(_ context.Context, offset, length int64) error
var _ fs.InodeOperations = (*inodeOperations)(nil)
// newInode returns a new fs.Inode backed by the host FD.
-func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool, donated bool) (*fs.Inode, error) {
+func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool) (*fs.Inode, error) {
// Retrieve metadata.
var s syscall.Stat_t
err := syscall.Fstat(fd, &s)
@@ -181,24 +176,17 @@ func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool,
}
fileState := &inodeFileState{
- mops: msrc.MountSourceOperations.(*superOperations),
sattr: stableAttr(&s),
}
// Initialize the wrapped host file descriptor.
- fileState.descriptor, err = newDescriptor(
- fd,
- donated,
- saveable,
- wouldBlock(&s),
- &fileState.queue,
- )
+ fileState.descriptor, err = newDescriptor(fd, saveable, wouldBlock(&s), &fileState.queue)
if err != nil {
return nil, err
}
// Build the fs.InodeOperations.
- uattr := unstableAttr(msrc.MountSourceOperations.(*superOperations), &s)
+ uattr := unstableAttr(&s)
iops := &inodeOperations{
fileState: fileState,
cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, fsutil.CachingInodeOperationsOptions{
@@ -232,54 +220,23 @@ func (i *inodeOperations) Release(context.Context) {
// Lookup implements fs.InodeOperations.Lookup.
func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string) (*fs.Dirent, error) {
- // Get a new FD relative to i at name.
- fd, err := open(i, name)
- if err != nil {
- if err == syserror.ENOENT {
- return nil, syserror.ENOENT
- }
- return nil, err
- }
-
- inode, err := newInode(ctx, dir.MountSource, fd, false /* saveable */, false /* donated */)
- if err != nil {
- return nil, err
- }
-
- // Return the fs.Dirent.
- return fs.NewDirent(ctx, inode, name), nil
+ return nil, syserror.ENOENT
}
// Create implements fs.InodeOperations.Create.
func (i *inodeOperations) Create(ctx context.Context, dir *fs.Inode, name string, flags fs.FileFlags, perm fs.FilePermissions) (*fs.File, error) {
- // Create a file relative to i at name.
- //
- // N.B. We always open this file O_RDWR regardless of flags because a
- // future GetFile might want more access. Open allows this regardless
- // of perm.
- fd, err := openAt(i, name, syscall.O_RDWR|syscall.O_CREAT|syscall.O_EXCL, perm.LinuxMode())
- if err != nil {
- return nil, err
- }
-
- inode, err := newInode(ctx, dir.MountSource, fd, false /* saveable */, false /* donated */)
- if err != nil {
- return nil, err
- }
+ return nil, syserror.EPERM
- d := fs.NewDirent(ctx, inode, name)
- defer d.DecRef()
- return inode.GetFile(ctx, d, flags)
}
// CreateDirectory implements fs.InodeOperations.CreateDirectory.
func (i *inodeOperations) CreateDirectory(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions) error {
- return syscall.Mkdirat(i.fileState.FD(), name, uint32(perm.LinuxMode()))
+ return syserror.EPERM
}
// CreateLink implements fs.InodeOperations.CreateLink.
func (i *inodeOperations) CreateLink(ctx context.Context, dir *fs.Inode, oldname string, newname string) error {
- return createLink(i.fileState.FD(), oldname, newname)
+ return syserror.EPERM
}
// CreateHardLink implements fs.InodeOperations.CreateHardLink.
@@ -294,25 +251,17 @@ func (*inodeOperations) CreateFifo(context.Context, *fs.Inode, string, fs.FilePe
// Remove implements fs.InodeOperations.Remove.
func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string) error {
- return unlinkAt(i.fileState.FD(), name, false /* dir */)
+ return syserror.EPERM
}
// RemoveDirectory implements fs.InodeOperations.RemoveDirectory.
func (i *inodeOperations) RemoveDirectory(ctx context.Context, dir *fs.Inode, name string) error {
- return unlinkAt(i.fileState.FD(), name, true /* dir */)
+ return syserror.EPERM
}
// Rename implements fs.InodeOperations.Rename.
func (i *inodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
- op, ok := oldParent.InodeOperations.(*inodeOperations)
- if !ok {
- return syscall.EXDEV
- }
- np, ok := newParent.InodeOperations.(*inodeOperations)
- if !ok {
- return syscall.EXDEV
- }
- return syscall.Renameat(op.fileState.FD(), oldName, np.fileState.FD(), newName)
+ return syserror.EPERM
}
// Bind implements fs.InodeOperations.Bind.
@@ -461,69 +410,7 @@ func (i *inodeOperations) NotifyStatusChange(ctx context.Context) {}
// readdirAll returns all of the directory entries in i.
func (i *inodeOperations) readdirAll(d *dirInfo) (map[string]fs.DentAttr, error) {
- i.readdirMu.Lock()
- defer i.readdirMu.Unlock()
-
- fd := i.fileState.FD()
-
- // syscall.ReadDirent will use getdents, which will seek the file past
- // the last directory entry. To read the directory entries a second
- // time, we need to seek back to the beginning.
- if _, err := syscall.Seek(fd, 0, 0); err != nil {
- if err == syscall.ESPIPE {
- // All directories should be seekable. If this file
- // isn't seekable, it is not a directory and we should
- // return that more sane error.
- err = syscall.ENOTDIR
- }
- return nil, err
- }
-
- names := make([]string, 0, 100)
- for {
- // Refill the buffer if necessary
- if d.bufp >= d.nbuf {
- d.bufp = 0
- // ReadDirent will just do a sys_getdents64 to the kernel.
- n, err := syscall.ReadDirent(fd, d.buf)
- if err != nil {
- return nil, err
- }
- if n == 0 {
- break // EOF
- }
- d.nbuf = n
- }
-
- var nb int
- // Parse the dirent buffer we just get and return the directory names along
- // with the number of bytes consumed in the buffer.
- nb, _, names = syscall.ParseDirent(d.buf[d.bufp:d.nbuf], -1, names)
- d.bufp += nb
- }
-
- entries := make(map[string]fs.DentAttr)
- for _, filename := range names {
- // Lookup the type and host device and inode.
- stat, lerr := fstatat(fd, filename, linux.AT_SYMLINK_NOFOLLOW)
- if lerr == syscall.ENOENT {
- // File disappeared between readdir and lstat.
- // Just treat it as if it didn't exist.
- continue
- }
-
- // There was a serious problem, we should probably report it.
- if lerr != nil {
- return nil, lerr
- }
-
- entries[filename] = fs.DentAttr{
- Type: nodeType(&stat),
- InodeID: hostFileDevice.Map(device.MultiDeviceKey{
- Device: stat.Dev,
- Inode: stat.Ino,
- }),
- }
- }
- return entries, nil
+ // We only support non-directory file descriptors that have been
+ // imported, so just claim that this isn't a directory, even if it is.
+ return nil, syscall.ENOTDIR
}
diff --git a/pkg/sentry/fs/host/inode_state.go b/pkg/sentry/fs/host/inode_state.go
index 299e0e0b0..1adbd4562 100644
--- a/pkg/sentry/fs/host/inode_state.go
+++ b/pkg/sentry/fs/host/inode_state.go
@@ -18,29 +18,14 @@ import (
"fmt"
"syscall"
- "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
)
-// beforeSave is invoked by stateify.
-func (i *inodeFileState) beforeSave() {
- if !i.queue.IsEmpty() {
- panic("event queue must be empty")
- }
- if !i.descriptor.donated && i.sattr.Type == fs.RegularFile {
- uattr, err := i.unstableAttr(context.Background())
- if err != nil {
- panic(fs.ErrSaveRejection{fmt.Errorf("failed to get unstable atttribute of %s: %v", i.mops.inodeMappings[i.sattr.InodeID], err)})
- }
- i.savedUAttr = &uattr
- }
-}
-
// afterLoad is invoked by stateify.
func (i *inodeFileState) afterLoad() {
// Initialize the descriptor value.
- if err := i.descriptor.initAfterLoad(i.mops, i.sattr.InodeID, &i.queue); err != nil {
+ if err := i.descriptor.initAfterLoad(i.sattr.InodeID, &i.queue); err != nil {
panic(fmt.Sprintf("failed to load value of descriptor: %v", err))
}
@@ -61,19 +46,4 @@ func (i *inodeFileState) afterLoad() {
// change across save and restore, error out.
panic(fs.ErrCorruption{fmt.Errorf("host %s conflict in host device mappings: %s", key, hostFileDevice)})
}
-
- if !i.descriptor.donated && i.sattr.Type == fs.RegularFile {
- env, ok := fs.CurrentRestoreEnvironment()
- if !ok {
- panic("missing restore environment")
- }
- uattr := unstableAttr(i.mops, &s)
- if env.ValidateFileSize && uattr.Size != i.savedUAttr.Size {
- panic(fs.ErrCorruption{fmt.Errorf("file size has changed for %s: previously %d, now %d", i.mops.inodeMappings[i.sattr.InodeID], i.savedUAttr.Size, uattr.Size)})
- }
- if env.ValidateFileTimestamp && uattr.ModificationTime != i.savedUAttr.ModificationTime {
- panic(fs.ErrCorruption{fmt.Errorf("file modification time has changed for %s: previously %v, now %v", i.mops.inodeMappings[i.sattr.InodeID], i.savedUAttr.ModificationTime, uattr.ModificationTime)})
- }
- i.savedUAttr = nil
- }
}
diff --git a/pkg/sentry/fs/host/inode_test.go b/pkg/sentry/fs/host/inode_test.go
index 7221bc825..4c374681c 100644
--- a/pkg/sentry/fs/host/inode_test.go
+++ b/pkg/sentry/fs/host/inode_test.go
@@ -15,9 +15,6 @@
package host
import (
- "io/ioutil"
- "os"
- "path"
"syscall"
"testing"
@@ -25,69 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
)
-// TestMultipleReaddir verifies that multiple Readdir calls return the same
-// thing if they use different dir contexts.
-func TestMultipleReaddir(t *testing.T) {
- p, err := ioutil.TempDir("", "readdir")
- if err != nil {
- t.Fatalf("Failed to create test dir: %v", err)
- }
- defer os.RemoveAll(p)
-
- f, err := os.Create(path.Join(p, "a.txt"))
- if err != nil {
- t.Fatalf("Failed to create a.txt: %v", err)
- }
- f.Close()
-
- f, err = os.Create(path.Join(p, "b.txt"))
- if err != nil {
- t.Fatalf("Failed to create b.txt: %v", err)
- }
- f.Close()
-
- fd, err := open(nil, p)
- if err != nil {
- t.Fatalf("Failed to open %q: %v", p, err)
- }
- ctx := contexttest.Context(t)
- n, err := newInode(ctx, newMountSource(ctx, p, fs.RootOwner, &Filesystem{}, fs.MountSourceFlags{}, false), fd, false, false)
- if err != nil {
- t.Fatalf("Failed to create inode: %v", err)
- }
-
- dirent := fs.NewDirent(ctx, n, "readdir")
- openFile, err := n.GetFile(ctx, dirent, fs.FileFlags{Read: true})
- if err != nil {
- t.Fatalf("Failed to get file: %v", err)
- }
- defer openFile.DecRef()
-
- c1 := &fs.DirCtx{DirCursor: new(string)}
- if _, err := openFile.FileOperations.(*fileOperations).IterateDir(ctx, dirent, c1, 0); err != nil {
- t.Fatalf("First Readdir failed: %v", err)
- }
-
- c2 := &fs.DirCtx{DirCursor: new(string)}
- if _, err := openFile.FileOperations.(*fileOperations).IterateDir(ctx, dirent, c2, 0); err != nil {
- t.Errorf("Second Readdir failed: %v", err)
- }
-
- if _, ok := c1.DentAttrs()["a.txt"]; !ok {
- t.Errorf("want a.txt in first Readdir, got %v", c1.DentAttrs())
- }
- if _, ok := c1.DentAttrs()["b.txt"]; !ok {
- t.Errorf("want b.txt in first Readdir, got %v", c1.DentAttrs())
- }
-
- if _, ok := c2.DentAttrs()["a.txt"]; !ok {
- t.Errorf("want a.txt in second Readdir, got %v", c2.DentAttrs())
- }
- if _, ok := c2.DentAttrs()["b.txt"]; !ok {
- t.Errorf("want b.txt in second Readdir, got %v", c2.DentAttrs())
- }
-}
-
// TestCloseFD verifies fds will be closed.
func TestCloseFD(t *testing.T) {
var p [2]int
diff --git a/pkg/sentry/fs/host/ioctl_unsafe.go b/pkg/sentry/fs/host/ioctl_unsafe.go
index 271582e54..150ac8e19 100644
--- a/pkg/sentry/fs/host/ioctl_unsafe.go
+++ b/pkg/sentry/fs/host/ioctl_unsafe.go
@@ -21,6 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
)
+// LINT.IfChange
+
func ioctlGetTermios(fd int) (*linux.Termios, error) {
var t linux.Termios
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TCGETS, uintptr(unsafe.Pointer(&t)))
@@ -54,3 +56,5 @@ func ioctlSetWinsize(fd int, w *linux.Winsize) error {
}
return nil
}
+
+// LINT.ThenChange(../../fsimpl/host/ioctl_unsafe.go)
diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go
index 3f218b4a7..cb91355ab 100644
--- a/pkg/sentry/fs/host/tty.go
+++ b/pkg/sentry/fs/host/tty.go
@@ -26,6 +26,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// TTYFileOperations implements fs.FileOperations for a host file descriptor
// that wraps a TTY FD.
//
@@ -43,6 +45,7 @@ type TTYFileOperations struct {
// connected to this TTY.
fgProcessGroup *kernel.ProcessGroup
+ // termios contains the terminal attributes for this TTY.
termios linux.KernelTermios
}
@@ -357,3 +360,5 @@ func (t *TTYFileOperations) checkChange(ctx context.Context, sig linux.Signal) e
_ = pg.SendSignal(kernel.SignalInfoPriv(sig))
return kernel.ERESTARTSYS
}
+
+// LINT.ThenChange(../../fsimpl/host/tty.go)
diff --git a/pkg/sentry/fs/host/util.go b/pkg/sentry/fs/host/util.go
index e37e687c6..1b0356930 100644
--- a/pkg/sentry/fs/host/util.go
+++ b/pkg/sentry/fs/host/util.go
@@ -16,7 +16,6 @@ package host
import (
"os"
- "path"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -28,45 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
-func open(parent *inodeOperations, name string) (int, error) {
- if parent == nil && !path.IsAbs(name) {
- return -1, syserror.EINVAL
- }
- name = path.Clean(name)
-
- // Don't follow through symlinks.
- flags := syscall.O_NOFOLLOW
-
- if fd, err := openAt(parent, name, flags|syscall.O_RDWR, 0); err == nil {
- return fd, nil
- }
- // Retry as read-only.
- if fd, err := openAt(parent, name, flags|syscall.O_RDONLY, 0); err == nil {
- return fd, nil
- }
-
- // Retry as write-only.
- if fd, err := openAt(parent, name, flags|syscall.O_WRONLY, 0); err == nil {
- return fd, nil
- }
-
- // Retry as a symlink, by including O_PATH as an option.
- fd, err := openAt(parent, name, linux.O_PATH|flags, 0)
- if err == nil {
- return fd, nil
- }
-
- // Everything failed.
- return -1, err
-}
-
-func openAt(parent *inodeOperations, name string, flags int, perm linux.FileMode) (int, error) {
- if parent == nil {
- return syscall.Open(name, flags, uint32(perm))
- }
- return syscall.Openat(parent.fileState.FD(), name, flags, uint32(perm))
-}
-
func nodeType(s *syscall.Stat_t) fs.InodeType {
switch x := (s.Mode & syscall.S_IFMT); x {
case syscall.S_IFLNK:
@@ -107,51 +67,19 @@ func stableAttr(s *syscall.Stat_t) fs.StableAttr {
}
}
-func owner(mo *superOperations, s *syscall.Stat_t) fs.FileOwner {
- // User requested no translation, just return actual owner.
- if mo.dontTranslateOwnership {
- return fs.FileOwner{auth.KUID(s.Uid), auth.KGID(s.Gid)}
+func owner(s *syscall.Stat_t) fs.FileOwner {
+ return fs.FileOwner{
+ UID: auth.KUID(s.Uid),
+ GID: auth.KGID(s.Gid),
}
-
- // Show only IDs relevant to the sandboxed task. I.e. if we not own the
- // file, no sandboxed task can own the file. In that case, we
- // use OverflowID for UID, implying that the IDs are not mapped in the
- // "root" user namespace.
- //
- // E.g.
- // sandbox's host EUID/EGID is 1/1.
- // some_dir's host UID/GID is 2/1.
- // Task that mounted this fs has virtualized EUID/EGID 5/5.
- //
- // If you executed `ls -n` in the sandboxed task, it would show:
- // drwxwrxwrx [...] 65534 5 [...] some_dir
-
- // Files are owned by OverflowID by default.
- owner := fs.FileOwner{auth.KUID(auth.OverflowUID), auth.KGID(auth.OverflowGID)}
-
- // If we own file on host, let mounting task's initial EUID own
- // the file.
- if s.Uid == hostUID {
- owner.UID = mo.mounter.UID
- }
-
- // If our group matches file's group, make file's group match
- // the mounting task's initial EGID.
- for _, gid := range hostGIDs {
- if s.Gid == gid {
- owner.GID = mo.mounter.GID
- break
- }
- }
- return owner
}
-func unstableAttr(mo *superOperations, s *syscall.Stat_t) fs.UnstableAttr {
+func unstableAttr(s *syscall.Stat_t) fs.UnstableAttr {
return fs.UnstableAttr{
Size: s.Size,
Usage: s.Blocks * 512,
Perms: fs.FilePermsFromMode(linux.FileMode(s.Mode)),
- Owner: owner(mo, s),
+ Owner: owner(s),
AccessTime: ktime.FromUnix(s.Atim.Sec, s.Atim.Nsec),
ModificationTime: ktime.FromUnix(s.Mtim.Sec, s.Mtim.Nsec),
StatusChangeTime: ktime.FromUnix(s.Ctim.Sec, s.Ctim.Nsec),
@@ -165,6 +93,8 @@ type dirInfo struct {
bufp int // location of next record in buf.
}
+// LINT.IfChange
+
// isBlockError unwraps os errors and checks if they are caused by EAGAIN or
// EWOULDBLOCK. This is so they can be transformed into syserror.ErrWouldBlock.
func isBlockError(err error) bool {
@@ -177,6 +107,8 @@ func isBlockError(err error) bool {
return false
}
+// LINT.ThenChange(../../fsimpl/host/util.go)
+
func hostEffectiveKIDs() (uint32, []uint32, error) {
gids, err := os.Getgroups()
if err != nil {
diff --git a/pkg/sentry/fs/host/util_unsafe.go b/pkg/sentry/fs/host/util_unsafe.go
index 3ab36b088..23bd35d64 100644
--- a/pkg/sentry/fs/host/util_unsafe.go
+++ b/pkg/sentry/fs/host/util_unsafe.go
@@ -26,26 +26,6 @@ import (
// NulByte is a single NUL byte. It is passed to readlinkat as an empty string.
var NulByte byte = '\x00'
-func createLink(fd int, name string, linkName string) error {
- namePtr, err := syscall.BytePtrFromString(name)
- if err != nil {
- return err
- }
- linkNamePtr, err := syscall.BytePtrFromString(linkName)
- if err != nil {
- return err
- }
- _, _, errno := syscall.Syscall(
- syscall.SYS_SYMLINKAT,
- uintptr(unsafe.Pointer(namePtr)),
- uintptr(fd),
- uintptr(unsafe.Pointer(linkNamePtr)))
- if errno != 0 {
- return errno
- }
- return nil
-}
-
func readLink(fd int) (string, error) {
// Buffer sizing copied from os.Readlink.
for l := 128; ; l *= 2 {
@@ -66,27 +46,6 @@ func readLink(fd int) (string, error) {
}
}
-func unlinkAt(fd int, name string, dir bool) error {
- namePtr, err := syscall.BytePtrFromString(name)
- if err != nil {
- return err
- }
- var flags uintptr
- if dir {
- flags = linux.AT_REMOVEDIR
- }
- _, _, errno := syscall.Syscall(
- syscall.SYS_UNLINKAT,
- uintptr(fd),
- uintptr(unsafe.Pointer(namePtr)),
- flags,
- )
- if errno != 0 {
- return errno
- }
- return nil
-}
-
func timespecFromTimestamp(t ktime.Time, omit, setSysTime bool) syscall.Timespec {
if omit {
return syscall.Timespec{0, linux.UTIME_OMIT}
diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go
index 928c90aa0..e3a715c1f 100644
--- a/pkg/sentry/fs/inotify.go
+++ b/pkg/sentry/fs/inotify.go
@@ -143,7 +143,10 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i
}
var writeLen int64
- for event := i.events.Front(); event != nil; event = event.Next() {
+ for it := i.events.Front(); it != nil; {
+ event := it
+ it = it.Next()
+
// Does the buffer have enough remaining space to hold the event we're
// about to write out?
if dst.NumBytes() < int64(event.sizeOf()) {
diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go
index c7981f66e..b414ddaee 100644
--- a/pkg/sentry/fs/mounts.go
+++ b/pkg/sentry/fs/mounts.go
@@ -273,19 +273,6 @@ func (mns *MountNamespace) DecRef() {
mns.DecRefWithDestructor(mns.destroy)
}
-// Freeze freezes the entire mount tree.
-func (mns *MountNamespace) Freeze() {
- mns.mu.Lock()
- defer mns.mu.Unlock()
-
- // We only want to freeze Dirents with active references, not Dirents referenced
- // by a mount's MountSource.
- mns.flushMountSourceRefsLocked()
-
- // Freeze the entire shebang.
- mns.root.Freeze()
-}
-
// withMountLocked prevents further walks to `node`, because `node` is about to
// be a mount point.
func (mns *MountNamespace) withMountLocked(node *Dirent, fn func() error) error {
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index 95d5817ff..bd18177d4 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -40,47 +40,48 @@ import (
// LINT.IfChange
-// newNet creates a new proc net entry.
-func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSource) *fs.Inode {
+// newNetDir creates a new proc net entry.
+func newNetDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ k := t.Kernel()
+
var contents map[string]*fs.Inode
- // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
- // network namespace of the calling process. We should make this per-process,
- // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net.
- if s := p.k.RootNetworkNamespace().Stack(); s != nil {
+ if s := t.NetworkNamespace().Stack(); s != nil {
+ // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task
+ // network namespace.
contents = map[string]*fs.Inode{
- "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc),
- "snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc),
+ "dev": seqfile.NewSeqFileInode(t, &netDev{s: s}, msrc),
+ "snmp": seqfile.NewSeqFileInode(t, &netSnmp{s: s}, msrc),
// The following files are simple stubs until they are
// implemented in netstack, if the file contains a
// header the stub is just the header otherwise it is
// an empty file.
- "arp": newStaticProcInode(ctx, msrc, []byte("IP address HW type Flags HW address Mask Device\n")),
+ "arp": newStaticProcInode(t, msrc, []byte("IP address HW type Flags HW address Mask Device\n")),
- "netlink": newStaticProcInode(ctx, msrc, []byte("sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n")),
- "netstat": newStaticProcInode(ctx, msrc, []byte("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed EmbryonicRsts PruneCalled RcvPruned OfoPruned OutOfWindowIcmps LockDroppedIcmps ArpFilter TW TWRecycled TWKilled PAWSPassive PAWSActive PAWSEstab DelayedACKs DelayedACKLocked DelayedACKLost ListenOverflows ListenDrops TCPPrequeued TCPDirectCopyFromBacklog TCPDirectCopyFromPrequeue TCPPrequeueDropped TCPHPHits TCPHPHitsToUser TCPPureAcks TCPHPAcks TCPRenoRecovery TCPSackRecovery TCPSACKReneging TCPFACKReorder TCPSACKReorder TCPRenoReorder TCPTSReorder TCPFullUndo TCPPartialUndo TCPDSACKUndo TCPLossUndo TCPLostRetransmit TCPRenoFailures TCPSackFailures TCPLossFailures TCPFastRetrans TCPForwardRetrans TCPSlowStartRetrans TCPTimeouts TCPLossProbes TCPLossProbeRecovery TCPRenoRecoveryFail TCPSackRecoveryFail TCPSchedulerFailed TCPRcvCollapsed TCPDSACKOldSent TCPDSACKOfoSent TCPDSACKRecv TCPDSACKOfoRecv TCPAbortOnData TCPAbortOnClose TCPAbortOnMemory TCPAbortOnTimeout TCPAbortOnLinger TCPAbortFailed TCPMemoryPressures TCPSACKDiscard TCPDSACKIgnoredOld TCPDSACKIgnoredNoUndo TCPSpuriousRTOs TCPMD5NotFound TCPMD5Unexpected TCPMD5Failure TCPSackShifted TCPSackMerged TCPSackShiftFallback TCPBacklogDrop TCPMinTTLDrop TCPDeferAcceptDrop IPReversePathFilter TCPTimeWaitOverflow TCPReqQFullDoCookies TCPReqQFullDrop TCPRetransFail TCPRcvCoalesce TCPOFOQueue TCPOFODrop TCPOFOMerge TCPChallengeACK TCPSYNChallenge TCPFastOpenActive TCPFastOpenActiveFail TCPFastOpenPassive TCPFastOpenPassiveFail TCPFastOpenListenOverflow TCPFastOpenCookieReqd TCPSpuriousRtxHostQueues BusyPollRxPackets TCPAutoCorking TCPFromZeroWindowAdv TCPToZeroWindowAdv TCPWantZeroWindowAdv TCPSynRetrans TCPOrigDataSent TCPHystartTrainDetect TCPHystartTrainCwnd TCPHystartDelayDetect TCPHystartDelayCwnd TCPACKSkippedSynRecv TCPACKSkippedPAWS TCPACKSkippedSeq TCPACKSkippedFinWait2 TCPACKSkippedTimeWait TCPACKSkippedChallenge TCPWinProbe TCPKeepAlive TCPMTUPFail TCPMTUPSuccess\n")),
- "packet": newStaticProcInode(ctx, msrc, []byte("sk RefCnt Type Proto Iface R Rmem User Inode\n")),
- "protocols": newStaticProcInode(ctx, msrc, []byte("protocol size sockets memory press maxhdr slab module cl co di ac io in de sh ss gs se re sp bi br ha uh gp em\n")),
+ "netlink": newStaticProcInode(t, msrc, []byte("sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n")),
+ "netstat": newStaticProcInode(t, msrc, []byte("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed EmbryonicRsts PruneCalled RcvPruned OfoPruned OutOfWindowIcmps LockDroppedIcmps ArpFilter TW TWRecycled TWKilled PAWSPassive PAWSActive PAWSEstab DelayedACKs DelayedACKLocked DelayedACKLost ListenOverflows ListenDrops TCPPrequeued TCPDirectCopyFromBacklog TCPDirectCopyFromPrequeue TCPPrequeueDropped TCPHPHits TCPHPHitsToUser TCPPureAcks TCPHPAcks TCPRenoRecovery TCPSackRecovery TCPSACKReneging TCPFACKReorder TCPSACKReorder TCPRenoReorder TCPTSReorder TCPFullUndo TCPPartialUndo TCPDSACKUndo TCPLossUndo TCPLostRetransmit TCPRenoFailures TCPSackFailures TCPLossFailures TCPFastRetrans TCPForwardRetrans TCPSlowStartRetrans TCPTimeouts TCPLossProbes TCPLossProbeRecovery TCPRenoRecoveryFail TCPSackRecoveryFail TCPSchedulerFailed TCPRcvCollapsed TCPDSACKOldSent TCPDSACKOfoSent TCPDSACKRecv TCPDSACKOfoRecv TCPAbortOnData TCPAbortOnClose TCPAbortOnMemory TCPAbortOnTimeout TCPAbortOnLinger TCPAbortFailed TCPMemoryPressures TCPSACKDiscard TCPDSACKIgnoredOld TCPDSACKIgnoredNoUndo TCPSpuriousRTOs TCPMD5NotFound TCPMD5Unexpected TCPMD5Failure TCPSackShifted TCPSackMerged TCPSackShiftFallback TCPBacklogDrop TCPMinTTLDrop TCPDeferAcceptDrop IPReversePathFilter TCPTimeWaitOverflow TCPReqQFullDoCookies TCPReqQFullDrop TCPRetransFail TCPRcvCoalesce TCPOFOQueue TCPOFODrop TCPOFOMerge TCPChallengeACK TCPSYNChallenge TCPFastOpenActive TCPFastOpenActiveFail TCPFastOpenPassive TCPFastOpenPassiveFail TCPFastOpenListenOverflow TCPFastOpenCookieReqd TCPSpuriousRtxHostQueues BusyPollRxPackets TCPAutoCorking TCPFromZeroWindowAdv TCPToZeroWindowAdv TCPWantZeroWindowAdv TCPSynRetrans TCPOrigDataSent TCPHystartTrainDetect TCPHystartTrainCwnd TCPHystartDelayDetect TCPHystartDelayCwnd TCPACKSkippedSynRecv TCPACKSkippedPAWS TCPACKSkippedSeq TCPACKSkippedFinWait2 TCPACKSkippedTimeWait TCPACKSkippedChallenge TCPWinProbe TCPKeepAlive TCPMTUPFail TCPMTUPSuccess\n")),
+ "packet": newStaticProcInode(t, msrc, []byte("sk RefCnt Type Proto Iface R Rmem User Inode\n")),
+ "protocols": newStaticProcInode(t, msrc, []byte("protocol size sockets memory press maxhdr slab module cl co di ac io in de sh ss gs se re sp bi br ha uh gp em\n")),
// Linux sets psched values to: nsec per usec, psched
// tick in ns, 1000000, high res timer ticks per sec
// (ClockGetres returns 1ns resolution).
- "psched": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))),
- "ptype": newStaticProcInode(ctx, msrc, []byte("Type Device Function\n")),
- "route": seqfile.NewSeqFileInode(ctx, &netRoute{s: s}, msrc),
- "tcp": seqfile.NewSeqFileInode(ctx, &netTCP{k: k}, msrc),
- "udp": seqfile.NewSeqFileInode(ctx, &netUDP{k: k}, msrc),
- "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc),
+ "psched": newStaticProcInode(t, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))),
+ "ptype": newStaticProcInode(t, msrc, []byte("Type Device Function\n")),
+ "route": seqfile.NewSeqFileInode(t, &netRoute{s: s}, msrc),
+ "tcp": seqfile.NewSeqFileInode(t, &netTCP{k: k}, msrc),
+ "udp": seqfile.NewSeqFileInode(t, &netUDP{k: k}, msrc),
+ "unix": seqfile.NewSeqFileInode(t, &netUnix{k: k}, msrc),
}
if s.SupportsIPv6() {
- contents["if_inet6"] = seqfile.NewSeqFileInode(ctx, &ifinet6{s: s}, msrc)
- contents["ipv6_route"] = newStaticProcInode(ctx, msrc, []byte(""))
- contents["tcp6"] = seqfile.NewSeqFileInode(ctx, &netTCP6{k: k}, msrc)
- contents["udp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n"))
+ contents["if_inet6"] = seqfile.NewSeqFileInode(t, &ifinet6{s: s}, msrc)
+ contents["ipv6_route"] = newStaticProcInode(t, msrc, []byte(""))
+ contents["tcp6"] = seqfile.NewSeqFileInode(t, &netTCP6{k: k}, msrc)
+ contents["udp6"] = newStaticProcInode(t, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n"))
}
}
- d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
- return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
+ d := ramfs.NewDir(t, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(t, d, msrc, fs.SpecialDirectory, t)
}
// ifinet6 implements seqfile.SeqSource for /proc/net/if_inet6.
@@ -837,4 +838,4 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
return data, 0
}
-// LINT.ThenChange(../../fsimpl/proc/tasks_net.go)
+// LINT.ThenChange(../../fsimpl/proc/task_net.go)
diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go
index c8abb5052..c659224a7 100644
--- a/pkg/sentry/fs/proc/proc.go
+++ b/pkg/sentry/fs/proc/proc.go
@@ -70,6 +70,7 @@ func New(ctx context.Context, msrc *fs.MountSource, cgroupControllers map[string
"loadavg": seqfile.NewSeqFileInode(ctx, &loadavgData{}, msrc),
"meminfo": seqfile.NewSeqFileInode(ctx, &meminfoData{k}, msrc),
"mounts": newProcInode(ctx, ramfs.NewSymlink(ctx, fs.RootOwner, "self/mounts"), msrc, fs.Symlink, nil),
+ "net": newProcInode(ctx, ramfs.NewSymlink(ctx, fs.RootOwner, "self/net"), msrc, fs.Symlink, nil),
"self": newSelf(ctx, pidns, msrc),
"stat": seqfile.NewSeqFileInode(ctx, &statData{k}, msrc),
"thread-self": newThreadSelf(ctx, pidns, msrc),
@@ -86,7 +87,6 @@ func New(ctx context.Context, msrc *fs.MountSource, cgroupControllers map[string
}
// Add more contents that need proc to be initialized.
- p.AddChild(ctx, "net", p.newNetDir(ctx, k, msrc))
p.AddChild(ctx, "sys", p.newSysDir(ctx, msrc))
return newProcInode(ctx, p, msrc, fs.SpecialDirectory, nil), nil
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index 8ab8d8a02..d6c5dd2c1 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -72,24 +72,27 @@ var _ fs.InodeOperations = (*taskDir)(nil)
// newTaskDir creates a new proc task entry.
func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode {
contents := map[string]*fs.Inode{
- "auxv": newAuxvec(t, msrc),
- "cmdline": newExecArgInode(t, msrc, cmdlineExecArg),
- "comm": newComm(t, msrc),
- "environ": newExecArgInode(t, msrc, environExecArg),
- "exe": newExe(t, msrc),
- "fd": newFdDir(t, msrc),
- "fdinfo": newFdInfoDir(t, msrc),
- "gid_map": newGIDMap(t, msrc),
- "io": newIO(t, msrc, isThreadGroup),
- "maps": newMaps(t, msrc),
- "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc),
- "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc),
- "ns": newNamespaceDir(t, msrc),
- "smaps": newSmaps(t, msrc),
- "stat": newTaskStat(t, msrc, isThreadGroup, p.pidns),
- "statm": newStatm(t, msrc),
- "status": newStatus(t, msrc, p.pidns),
- "uid_map": newUIDMap(t, msrc),
+ "auxv": newAuxvec(t, msrc),
+ "cmdline": newExecArgInode(t, msrc, cmdlineExecArg),
+ "comm": newComm(t, msrc),
+ "environ": newExecArgInode(t, msrc, environExecArg),
+ "exe": newExe(t, msrc),
+ "fd": newFdDir(t, msrc),
+ "fdinfo": newFdInfoDir(t, msrc),
+ "gid_map": newGIDMap(t, msrc),
+ "io": newIO(t, msrc, isThreadGroup),
+ "maps": newMaps(t, msrc),
+ "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc),
+ "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc),
+ "net": newNetDir(t, msrc),
+ "ns": newNamespaceDir(t, msrc),
+ "oom_score": newOOMScore(t, msrc),
+ "oom_score_adj": newOOMScoreAdj(t, msrc),
+ "smaps": newSmaps(t, msrc),
+ "stat": newTaskStat(t, msrc, isThreadGroup, p.pidns),
+ "statm": newStatm(t, msrc),
+ "status": newStatus(t, msrc, p.pidns),
+ "uid_map": newUIDMap(t, msrc),
}
if isThreadGroup {
contents["task"] = p.newSubtasks(t, msrc)
@@ -796,4 +799,95 @@ func (f *auxvecFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequenc
return int64(n), err
}
+// newOOMScore returns a oom_score file. It is a stub that always returns 0.
+// TODO(gvisor.dev/issue/1967)
+func newOOMScore(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ return newStaticProcInode(t, msrc, []byte("0\n"))
+}
+
+// oomScoreAdj is a file containing the oom_score adjustment for a task.
+//
+// +stateify savable
+type oomScoreAdj struct {
+ fsutil.SimpleFileInode
+
+ t *kernel.Task
+}
+
+// +stateify savable
+type oomScoreAdjFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ t *kernel.Task
+}
+
+// newOOMScoreAdj returns a oom_score_adj file.
+func newOOMScoreAdj(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ i := &oomScoreAdj{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC),
+ t: t,
+ }
+ return newProcInode(t, i, msrc, fs.SpecialFile, t)
+}
+
+// Truncate implements fs.InodeOperations.Truncate. Truncate is called when
+// O_TRUNC is specified for any kind of existing Dirent but is not called via
+// (f)truncate for proc files.
+func (*oomScoreAdj) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (o *oomScoreAdj) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &oomScoreAdjFile{t: o.t}), nil
+}
+
+// Read implements fs.FileOperations.Read.
+func (f *oomScoreAdjFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if f.t.ExitState() == kernel.TaskExitDead {
+ return 0, syserror.ESRCH
+ }
+ var buf bytes.Buffer
+ fmt.Fprintf(&buf, "%d\n", f.t.OOMScoreAdj())
+ if offset >= int64(buf.Len()) {
+ return 0, io.EOF
+ }
+ n, err := dst.CopyOut(ctx, buf.Bytes()[offset:])
+ return int64(n), err
+}
+
+// Write implements fs.FileOperations.Write.
+func (f *oomScoreAdjFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Limit input size so as not to impact performance if input size is large.
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ var v int32
+ n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
+ if err != nil {
+ return 0, err
+ }
+
+ if f.t.ExitState() == kernel.TaskExitDead {
+ return 0, syserror.ESRCH
+ }
+ if err := f.t.SetOOMScoreAdj(v); err != nil {
+ return 0, err
+ }
+
+ return n, nil
+}
+
// LINT.ThenChange(../../fsimpl/proc/task.go|../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go
index 6aa17bfc1..79b808359 100644
--- a/pkg/sentry/fsbridge/vfs.go
+++ b/pkg/sentry/fsbridge/vfs.go
@@ -115,8 +115,6 @@ func NewVFSLookup(mntns *vfs.MountNamespace, root, workingDir vfs.VirtualDentry)
//
// remainingTraversals is not configurable in VFS2, all callers are using the
// default anyways.
-//
-// TODO(gvisor.dev/issue/1623): Check mount has read and exec permission.
func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) {
vfsObj := l.mntns.Root().Mount().Filesystem().VirtualFilesystem()
creds := auth.CredentialsFromContext(ctx)
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
index e05429d41..8497be615 100644
--- a/pkg/sentry/fsimpl/ext/filesystem.go
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -255,6 +256,15 @@ func (fs *filesystem) statTo(stat *linux.Statfs) {
// TODO(b/134676337): Set Statfs.Flags and Statfs.FSID.
}
+// AccessAt implements vfs.Filesystem.Impl.AccessAt.
+func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ _, inode, err := fs.walk(rp, false)
+ if err != nil {
+ return err
+ }
+ return inode.checkPermissions(rp.Credentials(), ats)
+}
+
// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
vfsd, inode, err := fs.walk(rp, false)
diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go
index 6962083f5..a39a37318 100644
--- a/pkg/sentry/fsimpl/ext/inode.go
+++ b/pkg/sentry/fsimpl/ext/inode.go
@@ -186,7 +186,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOpt
}
func (in *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
- return vfs.GenericCheckPermissions(creds, ats, in.isDir(), uint16(in.diskInode.Mode()), in.diskInode.UID(), in.diskInode.GID())
+ return vfs.GenericCheckPermissions(creds, ats, in.diskInode.Mode(), in.diskInode.UID(), in.diskInode.GID())
}
// statTo writes the statx fields to the output parameter.
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 5cfb0dc4c..1e43df9ec 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -118,7 +119,7 @@ func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *
if !d.isDir() {
return nil, syserror.ENOTDIR
}
- if err := d.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
afterSymlink:
@@ -313,7 +314,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if err != nil {
return err
}
- if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil {
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
if parent.isDeleted() {
@@ -377,7 +378,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
if err != nil {
return err
}
- if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil {
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
if err := rp.Mount().CheckBeginWrite(); err != nil {
@@ -453,6 +454,9 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
}
if fs.opts.interop != InteropModeShared {
parent.touchCMtime(ctx)
+ if dir {
+ parent.decLinks()
+ }
parent.cacheNegativeChildLocked(name)
parent.dirents = nil
}
@@ -499,6 +503,18 @@ func (fs *filesystem) renameMuUnlockAndCheckCaching(ds **[]*dentry) {
putDentrySlice(*ds)
}
+// AccessAt implements vfs.Filesystem.Impl.AccessAt.
+func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+ return d.checkPermissions(creds, ats)
+}
+
// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
var ds *[]*dentry
@@ -512,7 +528,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
if !d.isDir() {
return nil, syserror.ENOTDIR
}
- if err := d.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
}
@@ -556,8 +572,13 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string) error {
creds := rp.Credentials()
- _, err := parent.file.mkdir(ctx, name, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
- return err
+ if _, err := parent.file.mkdir(ctx, name, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)); err != nil {
+ return err
+ }
+ if fs.opts.interop != InteropModeShared {
+ parent.incLinks()
+ }
+ return nil
})
}
@@ -603,7 +624,7 @@ afterTrailingSymlink:
return nil, err
}
// Check for search permission in the parent directory.
- if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
// Determine whether or not we need to create a file.
@@ -640,7 +661,7 @@ afterTrailingSymlink:
// Preconditions: fs.renameMu must be locked.
func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
ats := vfs.AccessTypesForOpenFlags(opts)
- if err := d.checkPermissions(rp.Credentials(), ats, d.isDir()); err != nil {
+ if err := d.checkPermissions(rp.Credentials(), ats); err != nil {
return nil, err
}
mnt := rp.Mount()
@@ -701,7 +722,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
// Preconditions: d.fs.renameMu must be locked. d.dirMu must be locked.
func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
- if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil {
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
return nil, err
}
if d.isDeleted() {
@@ -863,7 +884,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
return err
}
}
- if err := oldParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil {
+ if err := oldParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
vfsObj := rp.VirtualFilesystem()
@@ -883,7 +904,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
return syserror.EINVAL
}
if oldParent != newParent {
- if err := renamed.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil {
+ if err := renamed.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
return err
}
}
@@ -894,7 +915,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
if oldParent != newParent {
- if err := newParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil {
+ if err := newParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
newParent.dirMu.Lock()
@@ -949,6 +970,10 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
oldParent.dirents = nil
delete(newParent.negativeChildren, newName)
newParent.dirents = nil
+ if renamed.isDir() {
+ oldParent.decLinks()
+ newParent.incLinks()
+ }
}
vfsObj.CommitRenameReplaceDentry(&renamed.vfsd, &newParent.vfsd, newName, replacedVFSD)
return nil
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index d00850e25..cf276a417 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -485,6 +485,11 @@ type dentry struct {
// locked to mutate it).
size uint64
+ // nlink counts the number of hard links to this dentry. It's updated and
+ // accessed using atomic operations. It's not protected by metadataMu like the
+ // other metadata fields.
+ nlink uint32
+
mapsMu sync.Mutex
// If this dentry represents a regular file, mappings tracks mappings of
@@ -604,6 +609,9 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
if mask.BTime {
d.btime = dentryTimestampFromP9(attr.BTimeSeconds, attr.BTimeNanoSeconds)
}
+ if mask.NLink {
+ d.nlink = uint32(attr.NLink)
+ }
d.vfsd.Init(d)
fs.syncMu.Lock()
@@ -645,6 +653,9 @@ func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) {
if mask.BTime {
atomic.StoreInt64(&d.btime, dentryTimestampFromP9(attr.BTimeSeconds, attr.BTimeNanoSeconds))
}
+ if mask.NLink {
+ atomic.StoreUint32(&d.nlink, uint32(attr.NLink))
+ }
if mask.Size {
d.dataMu.Lock()
atomic.StoreUint64(&d.size, attr.Size)
@@ -687,10 +698,7 @@ func (d *dentry) fileType() uint32 {
func (d *dentry) statTo(stat *linux.Statx) {
stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME
stat.Blksize = atomic.LoadUint32(&d.blockSize)
- stat.Nlink = 1
- if d.isDir() {
- stat.Nlink = 2
- }
+ stat.Nlink = atomic.LoadUint32(&d.nlink)
stat.UID = atomic.LoadUint32(&d.uid)
stat.GID = atomic.LoadUint32(&d.gid)
stat.Mode = uint16(atomic.LoadUint32(&d.mode))
@@ -703,7 +711,7 @@ func (d *dentry) statTo(stat *linux.Statx) {
stat.Btime = statxTimestampFromDentry(atomic.LoadInt64(&d.btime))
stat.Ctime = statxTimestampFromDentry(atomic.LoadInt64(&d.ctime))
stat.Mtime = statxTimestampFromDentry(atomic.LoadInt64(&d.mtime))
- // TODO(jamieliu): device number
+ // TODO(gvisor.dev/issue/1198): device number
}
func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mnt *vfs.Mount) error {
@@ -713,7 +721,8 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
if stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_SIZE) != 0 {
return syserror.EPERM
}
- if err := vfs.CheckSetStat(creds, stat, uint16(atomic.LoadUint32(&d.mode))&^linux.S_IFMT, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ mode := linux.FileMode(atomic.LoadUint32(&d.mode))
+ if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
return err
}
if err := mnt.CheckBeginWrite(); err != nil {
@@ -835,8 +844,8 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
return nil
}
-func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes, isDir bool) error {
- return vfs.GenericCheckPermissions(creds, ats, isDir, uint16(atomic.LoadUint32(&d.mode))&0777, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
+func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+ return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
}
// IncRef implements vfs.DentryImpl.IncRef.
@@ -1045,13 +1054,13 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
// using the old file descriptor, preventing us from safely
// closing it. We could handle this by invalidating existing
// memmap.Translations, but this is expensive. Instead, use
- // dup2() to make the old file descriptor refer to the new file
+ // dup3 to make the old file descriptor refer to the new file
// description, then close the new file descriptor (which is no
// longer needed). Racing callers may use the old or new file
// description, but this doesn't matter since they refer to the
// same file (unless d.fs.opts.overlayfsStaleRead is true,
// which we handle separately).
- if err := syscall.Dup2(int(h.fd), int(d.handle.fd)); err != nil {
+ if err := syscall.Dup3(int(h.fd), int(d.handle.fd), 0); err != nil {
d.handleMu.Unlock()
ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, d.handle.fd, err)
h.close(ctx)
@@ -1094,6 +1103,26 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
return nil
}
+// incLinks increments link count.
+//
+// Preconditions: d.nlink != 0 && d.nlink < math.MaxUint32.
+func (d *dentry) incLinks() {
+ v := atomic.AddUint32(&d.nlink, 1)
+ if v < 2 {
+ panic(fmt.Sprintf("dentry.nlink is invalid (was 0 or overflowed): %d", v))
+ }
+}
+
+// decLinks decrements link count.
+//
+// Preconditions: d.nlink > 1.
+func (d *dentry) decLinks() {
+ v := atomic.AddUint32(&d.nlink, ^uint32(0))
+ if v == 0 {
+ panic(fmt.Sprintf("dentry.nlink must be greater than 0: %d", v))
+ }
+}
+
// fileDescription is embedded by gofer implementations of
// vfs.FileDescriptionImpl.
type fileDescription struct {
@@ -1112,7 +1141,8 @@ func (fd *fileDescription) dentry() *dentry {
// Stat implements vfs.FileDescriptionImpl.Stat.
func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
d := fd.dentry()
- if d.fs.opts.interop == InteropModeShared && opts.Mask&(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME|linux.STATX_SIZE|linux.STATX_BLOCKS|linux.STATX_BTIME) != 0 && opts.Sync != linux.AT_STATX_DONT_SYNC {
+ const validMask = uint32(linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME)
+ if d.fs.opts.interop == InteropModeShared && opts.Mask&(validMask) != 0 && opts.Sync != linux.AT_STATX_DONT_SYNC {
// TODO(jamieliu): Use specialFileFD.handle.file for the getattr if
// available?
if err := d.updateFromGetattr(ctx); err != nil {
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 54c1031a7..3593eb1d5 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -126,6 +126,11 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
if opts.Flags != 0 {
return 0, syserror.EOPNOTSUPP
}
+ limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes())
+ if err != nil {
+ return 0, err
+ }
+ src = src.TakeFirst64(limit)
d := fd.dentry()
d.metadataMu.Lock()
@@ -361,8 +366,15 @@ func (rw *dentryReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, erro
rw.d.handleMu.RLock()
if (rw.d.handle.fd >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct {
n, err := rw.d.handle.writeFromBlocksAt(rw.ctx, srcs, rw.off)
- rw.d.handleMu.RUnlock()
rw.off += n
+ rw.d.dataMu.Lock()
+ if rw.off > rw.d.size {
+ atomic.StoreUint64(&rw.d.size, rw.off)
+ // The remote file's size will implicitly be extended to the correct
+ // value when we write back to it.
+ }
+ rw.d.dataMu.Unlock()
+ rw.d.handleMu.RUnlock()
return n, err
}
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index 08c691c47..274f7346f 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -107,6 +107,14 @@ func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
return 0, syserror.EOPNOTSUPP
}
+ if fd.dentry().fileType() == linux.S_IFREG {
+ limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes())
+ if err != nil {
+ return 0, err
+ }
+ src = src.TakeFirst64(limit)
+ }
+
// Do a buffered write. See rationale in PRead.
if d := fd.dentry(); d.fs.opts.interop != InteropModeShared {
d.touchCMtime(ctx)
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
new file mode 100644
index 000000000..82e1fb74b
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -0,0 +1,34 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "host",
+ srcs = [
+ "host.go",
+ "ioctl_unsafe.go",
+ "tty.go",
+ "util.go",
+ "util_unsafe.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fd",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/unimpl",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
new file mode 100644
index 000000000..a54985ef5
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -0,0 +1,616 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package host provides a filesystem implementation for host files imported as
+// file descriptors.
+package host
+
+import (
+ "errors"
+ "fmt"
+ "math"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ kernfs.Filesystem
+}
+
+// NewMount returns a new disconnected mount in vfsObj that may be passed to ImportFD.
+func NewMount(vfsObj *vfs.VirtualFilesystem) (*vfs.Mount, error) {
+ fs := &filesystem{}
+ fs.Init(vfsObj)
+ vfsfs := fs.VFSFilesystem()
+ // NewDisconnectedMount will take an additional reference on vfsfs.
+ defer vfsfs.DecRef()
+ return vfsObj.NewDisconnectedMount(vfsfs, nil, &vfs.MountOptions{})
+}
+
+// ImportFD sets up and returns a vfs.FileDescription from a donated fd.
+func ImportFD(mnt *vfs.Mount, hostFD int, ownerUID auth.KUID, ownerGID auth.KGID, isTTY bool) (*vfs.FileDescription, error) {
+ fs, ok := mnt.Filesystem().Impl().(*kernfs.Filesystem)
+ if !ok {
+ return nil, fmt.Errorf("can't import host FDs into filesystems of type %T", mnt.Filesystem().Impl())
+ }
+
+ // Retrieve metadata.
+ var s syscall.Stat_t
+ if err := syscall.Fstat(hostFD, &s); err != nil {
+ return nil, err
+ }
+
+ fileMode := linux.FileMode(s.Mode)
+ fileType := fileMode.FileType()
+ // Pipes, character devices, and sockets.
+ isStream := fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK
+
+ i := &inode{
+ hostFD: hostFD,
+ isStream: isStream,
+ isTTY: isTTY,
+ canMap: canMap(uint32(fileType)),
+ ino: fs.NextIno(),
+ mode: fileMode,
+ uid: ownerUID,
+ gid: ownerGID,
+ // For simplicity, set offset to 0. Technically, we should
+ // only set to 0 on files that are not seekable (sockets, pipes, etc.),
+ // and use the offset from the host fd otherwise.
+ offset: 0,
+ }
+
+ // These files can't be memory mapped, assert this.
+ if i.isStream && i.canMap {
+ panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped")
+ }
+
+ d := &kernfs.Dentry{}
+ d.Init(i)
+ // i.open will take a reference on d.
+ defer d.DecRef()
+
+ return i.open(d.VFSDentry(), mnt)
+}
+
+// inode implements kernfs.Inode.
+type inode struct {
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+
+ // When the reference count reaches zero, the host fd is closed.
+ refs.AtomicRefCount
+
+ // hostFD contains the host fd that this file was originally created from,
+ // which must be available at time of restore.
+ //
+ // This field is initialized at creation time and is immutable.
+ hostFD int
+
+ // isStream is true if the host fd points to a file representing a stream,
+ // e.g. a socket or a pipe. Such files are not seekable and can return
+ // EWOULDBLOCK for I/O operations.
+ //
+ // This field is initialized at creation time and is immutable.
+ isStream bool
+
+ // isTTY is true if this file represents a TTY.
+ //
+ // This field is initialized at creation time and is immutable.
+ isTTY bool
+
+ // canMap specifies whether we allow the file to be memory mapped.
+ //
+ // This field is initialized at creation time and is immutable.
+ canMap bool
+
+ // ino is an inode number unique within this filesystem.
+ //
+ // This field is initialized at creation time and is immutable.
+ ino uint64
+
+ // TODO(gvisor.dev/issue/1672): protect mode, uid, and gid with mutex.
+
+ // mode is the file mode of this inode. Note that this value may become out
+ // of date if the mode is changed on the host, e.g. with chmod.
+ mode linux.FileMode
+
+ // uid and gid of the file owner. Note that these refer to the owner of the
+ // file created on import, not the fd on the host.
+ uid auth.KUID
+ gid auth.KGID
+
+ // offsetMu protects offset.
+ offsetMu sync.Mutex
+
+ // offset specifies the current file offset.
+ offset int64
+}
+
+// Note that these flags may become out of date, since they can be modified
+// on the host, e.g. with fcntl.
+func fileFlagsFromHostFD(fd int) (int, error) {
+ flags, err := unix.FcntlInt(uintptr(fd), syscall.F_GETFL, 0)
+ if err != nil {
+ log.Warningf("Failed to get file flags for donated FD %d: %v", fd, err)
+ return 0, err
+ }
+ // TODO(gvisor.dev/issue/1672): implement behavior corresponding to these allowed flags.
+ flags &= syscall.O_ACCMODE | syscall.O_DIRECT | syscall.O_NONBLOCK | syscall.O_DSYNC | syscall.O_SYNC | syscall.O_APPEND
+ return flags, nil
+}
+
+// CheckPermissions implements kernfs.Inode.
+func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ return vfs.GenericCheckPermissions(creds, ats, i.mode, i.uid, i.gid)
+}
+
+// Mode implements kernfs.Inode.
+func (i *inode) Mode() linux.FileMode {
+ return i.mode
+}
+
+// Stat implements kernfs.Inode.
+func (i *inode) Stat(_ *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ if opts.Mask&linux.STATX__RESERVED != 0 {
+ return linux.Statx{}, syserror.EINVAL
+ }
+ if opts.Sync&linux.AT_STATX_SYNC_TYPE == linux.AT_STATX_SYNC_TYPE {
+ return linux.Statx{}, syserror.EINVAL
+ }
+
+ // Limit our host call only to known flags.
+ mask := opts.Mask & linux.STATX_ALL
+ var s unix.Statx_t
+ err := unix.Statx(i.hostFD, "", int(unix.AT_EMPTY_PATH|opts.Sync), int(mask), &s)
+ // Fallback to fstat(2), if statx(2) is not supported on the host.
+ //
+ // TODO(b/151263641): Remove fallback.
+ if err == syserror.ENOSYS {
+ return i.fstat(opts)
+ } else if err != nil {
+ return linux.Statx{}, err
+ }
+
+ ls := linux.Statx{Mask: mask}
+ // Unconditionally fill blksize, attributes, and device numbers, as indicated
+ // by /include/uapi/linux/stat.h.
+ //
+ // RdevMajor/RdevMinor are left as zero, so as not to expose host device
+ // numbers.
+ //
+ // TODO(gvisor.dev/issue/1672): Use kernfs-specific, internally defined
+ // device numbers. If we use the device number from the host, it may collide
+ // with another sentry-internal device number. We handle device/inode
+ // numbers without relying on the host to prevent collisions.
+ ls.Blksize = s.Blksize
+ ls.Attributes = s.Attributes
+ ls.AttributesMask = s.Attributes_mask
+
+ if mask|linux.STATX_TYPE != 0 {
+ ls.Mode |= s.Mode & linux.S_IFMT
+ }
+ if mask|linux.STATX_MODE != 0 {
+ ls.Mode |= s.Mode &^ linux.S_IFMT
+ }
+ if mask|linux.STATX_NLINK != 0 {
+ ls.Nlink = s.Nlink
+ }
+ if mask|linux.STATX_ATIME != 0 {
+ ls.Atime = unixToLinuxStatxTimestamp(s.Atime)
+ }
+ if mask|linux.STATX_BTIME != 0 {
+ ls.Btime = unixToLinuxStatxTimestamp(s.Btime)
+ }
+ if mask|linux.STATX_CTIME != 0 {
+ ls.Ctime = unixToLinuxStatxTimestamp(s.Ctime)
+ }
+ if mask|linux.STATX_MTIME != 0 {
+ ls.Mtime = unixToLinuxStatxTimestamp(s.Mtime)
+ }
+ if mask|linux.STATX_SIZE != 0 {
+ ls.Size = s.Size
+ }
+ if mask|linux.STATX_BLOCKS != 0 {
+ ls.Blocks = s.Blocks
+ }
+
+ // Use our own internal inode number and file owner.
+ if mask|linux.STATX_INO != 0 {
+ ls.Ino = i.ino
+ }
+ if mask|linux.STATX_UID != 0 {
+ ls.UID = uint32(i.uid)
+ }
+ if mask|linux.STATX_GID != 0 {
+ ls.GID = uint32(i.gid)
+ }
+
+ return ls, nil
+}
+
+// fstat is a best-effort fallback for inode.Stat() if the host does not
+// support statx(2).
+//
+// We ignore the mask and sync flags in opts and simply supply
+// STATX_BASIC_STATS, as fstat(2) itself does not allow the specification
+// of a mask or sync flags. fstat(2) does not provide any metadata
+// equivalent to Statx.Attributes, Statx.AttributesMask, or Statx.Btime, so
+// those fields remain empty.
+func (i *inode) fstat(opts vfs.StatOptions) (linux.Statx, error) {
+ var s unix.Stat_t
+ if err := unix.Fstat(i.hostFD, &s); err != nil {
+ return linux.Statx{}, err
+ }
+
+ // Note that rdev numbers are left as 0; do not expose host device numbers.
+ ls := linux.Statx{
+ Mask: linux.STATX_BASIC_STATS,
+ Blksize: uint32(s.Blksize),
+ Nlink: uint32(s.Nlink),
+ Mode: uint16(s.Mode),
+ Size: uint64(s.Size),
+ Blocks: uint64(s.Blocks),
+ Atime: timespecToStatxTimestamp(s.Atim),
+ Ctime: timespecToStatxTimestamp(s.Ctim),
+ Mtime: timespecToStatxTimestamp(s.Mtim),
+ }
+
+ // Use our own internal inode number and file owner.
+ //
+ // TODO(gvisor.dev/issue/1672): Use a kernfs-specific device number as well.
+ // If we use the device number from the host, it may collide with another
+ // sentry-internal device number. We handle device/inode numbers without
+ // relying on the host to prevent collisions.
+ ls.Ino = i.ino
+ ls.UID = uint32(i.uid)
+ ls.GID = uint32(i.gid)
+
+ return ls, nil
+}
+
+// SetStat implements kernfs.Inode.
+func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ s := opts.Stat
+
+ m := s.Mask
+ if m == 0 {
+ return nil
+ }
+ if m&^(linux.STATX_MODE|linux.STATX_SIZE|linux.STATX_ATIME|linux.STATX_MTIME) != 0 {
+ return syserror.EPERM
+ }
+ if err := vfs.CheckSetStat(ctx, creds, &s, i.Mode(), i.uid, i.gid); err != nil {
+ return err
+ }
+
+ if m&linux.STATX_MODE != 0 {
+ if err := syscall.Fchmod(i.hostFD, uint32(s.Mode)); err != nil {
+ return err
+ }
+ i.mode = linux.FileMode(s.Mode)
+ }
+ if m&linux.STATX_SIZE != 0 {
+ if err := syscall.Ftruncate(i.hostFD, int64(s.Size)); err != nil {
+ return err
+ }
+ }
+ if m&(linux.STATX_ATIME|linux.STATX_MTIME) != 0 {
+ ts := [2]syscall.Timespec{
+ toTimespec(s.Atime, m&linux.STATX_ATIME == 0),
+ toTimespec(s.Mtime, m&linux.STATX_MTIME == 0),
+ }
+ if err := setTimestamps(i.hostFD, &ts); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// DecRef implements kernfs.Inode.
+func (i *inode) DecRef() {
+ i.AtomicRefCount.DecRefWithDestructor(i.Destroy)
+}
+
+// Destroy implements kernfs.Inode.
+func (i *inode) Destroy() {
+ if err := unix.Close(i.hostFD); err != nil {
+ log.Warningf("failed to close host fd %d: %v", i.hostFD, err)
+ }
+}
+
+// Open implements kernfs.Inode.
+func (i *inode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ return i.open(vfsd, rp.Mount())
+}
+
+func (i *inode) open(d *vfs.Dentry, mnt *vfs.Mount) (*vfs.FileDescription, error) {
+ fileType := i.mode.FileType()
+ if fileType == syscall.S_IFSOCK {
+ if i.isTTY {
+ return nil, errors.New("cannot use host socket as TTY")
+ }
+ // TODO(gvisor.dev/issue/1672): support importing sockets.
+ return nil, errors.New("importing host sockets not supported")
+ }
+
+ // TODO(gvisor.dev/issue/1672): Whitelist specific file types here, so that
+ // we don't allow importing arbitrary file types without proper support.
+ var (
+ vfsfd *vfs.FileDescription
+ fdImpl vfs.FileDescriptionImpl
+ )
+ if i.isTTY {
+ fd := &ttyFD{
+ fileDescription: fileDescription{inode: i},
+ termios: linux.DefaultSlaveTermios,
+ }
+ vfsfd = &fd.vfsfd
+ fdImpl = fd
+ } else {
+ // For simplicity, set offset to 0. Technically, we should
+ // only set to 0 on files that are not seekable (sockets, pipes, etc.),
+ // and use the offset from the host fd otherwise.
+ fd := &fileDescription{inode: i}
+ vfsfd = &fd.vfsfd
+ fdImpl = fd
+ }
+
+ flags, err := fileFlagsFromHostFD(i.hostFD)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := vfsfd.Init(fdImpl, uint32(flags), mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return vfsfd, nil
+}
+
+// fileDescription is embedded by host fd implementations of FileDescriptionImpl.
+//
+// TODO(gvisor.dev/issue/1672): Implement Waitable interface.
+type fileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+
+ // inode is vfsfd.Dentry().Impl().(*kernfs.Dentry).Inode().(*inode), but
+ // cached to reduce indirections and casting. fileDescription does not hold
+ // a reference on the inode through the inode field (since one is already
+ // held via the Dentry).
+ //
+ // inode is immutable after fileDescription creation.
+ inode *inode
+}
+
+// SetStat implements vfs.FileDescriptionImpl.
+func (f *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ creds := auth.CredentialsFromContext(ctx)
+ return f.inode.SetStat(ctx, nil, creds, opts)
+}
+
+// Stat implements vfs.FileDescriptionImpl.
+func (f *fileDescription) Stat(_ context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ return f.inode.Stat(nil, opts)
+}
+
+// Release implements vfs.FileDescriptionImpl.
+func (f *fileDescription) Release() {
+ // noop
+}
+
+// PRead implements FileDescriptionImpl.
+func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ i := f.inode
+ // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null.
+ if i.isStream {
+ return 0, syserror.ESPIPE
+ }
+
+ return readFromHostFD(ctx, i.hostFD, dst, offset, opts.Flags)
+}
+
+// Read implements FileDescriptionImpl.
+func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ i := f.inode
+ // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null.
+ if i.isStream {
+ n, err := readFromHostFD(ctx, i.hostFD, dst, -1, opts.Flags)
+ if isBlockError(err) {
+ // If we got any data at all, return it as a "completed" partial read
+ // rather than retrying until complete.
+ if n != 0 {
+ err = nil
+ } else {
+ err = syserror.ErrWouldBlock
+ }
+ }
+ return n, err
+ }
+ // TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so.
+ i.offsetMu.Lock()
+ n, err := readFromHostFD(ctx, i.hostFD, dst, i.offset, opts.Flags)
+ i.offset += n
+ i.offsetMu.Unlock()
+ return n, err
+}
+
+func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) {
+ // TODO(gvisor.dev/issue/1672): Support select preadv2 flags.
+ if flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ var reader safemem.Reader
+ if offset == -1 {
+ reader = safemem.FromIOReader{fd.NewReadWriter(hostFD)}
+ } else {
+ reader = safemem.FromVecReaderFunc{
+ func(srcs [][]byte) (int64, error) {
+ n, err := unix.Preadv(hostFD, srcs, offset)
+ return int64(n), err
+ },
+ }
+ }
+ n, err := dst.CopyOutFrom(ctx, reader)
+ return int64(n), err
+}
+
+// PWrite implements FileDescriptionImpl.
+func (f *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ i := f.inode
+ // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null.
+ if i.isStream {
+ return 0, syserror.ESPIPE
+ }
+
+ return writeToHostFD(ctx, i.hostFD, src, offset, opts.Flags)
+}
+
+// Write implements FileDescriptionImpl.
+func (f *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ i := f.inode
+ // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null.
+ if i.isStream {
+ n, err := writeToHostFD(ctx, i.hostFD, src, -1, opts.Flags)
+ if isBlockError(err) {
+ err = syserror.ErrWouldBlock
+ }
+ return n, err
+ }
+ // TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so.
+ // TODO(gvisor.dev/issue/1672): Write to end of file and update offset if O_APPEND is set on this file.
+ i.offsetMu.Lock()
+ n, err := writeToHostFD(ctx, i.hostFD, src, i.offset, opts.Flags)
+ i.offset += n
+ i.offsetMu.Unlock()
+ return n, err
+}
+
+func writeToHostFD(ctx context.Context, hostFD int, src usermem.IOSequence, offset int64, flags uint32) (int64, error) {
+ // TODO(gvisor.dev/issue/1672): Support select pwritev2 flags.
+ if flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ var writer safemem.Writer
+ if offset == -1 {
+ writer = safemem.FromIOWriter{fd.NewReadWriter(hostFD)}
+ } else {
+ writer = safemem.FromVecWriterFunc{
+ func(srcs [][]byte) (int64, error) {
+ n, err := unix.Pwritev(hostFD, srcs, offset)
+ return int64(n), err
+ },
+ }
+ }
+ n, err := src.CopyInTo(ctx, writer)
+ return int64(n), err
+}
+
+// Seek implements FileDescriptionImpl.
+//
+// Note that we do not support seeking on directories, since we do not even
+// allow directory fds to be imported at all.
+func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (int64, error) {
+ i := f.inode
+ // TODO(b/34716638): Some char devices do support seeking, e.g. /dev/null.
+ if i.isStream {
+ return 0, syserror.ESPIPE
+ }
+
+ i.offsetMu.Lock()
+ defer i.offsetMu.Unlock()
+
+ switch whence {
+ case linux.SEEK_SET:
+ if offset < 0 {
+ return i.offset, syserror.EINVAL
+ }
+ i.offset = offset
+
+ case linux.SEEK_CUR:
+ // Check for overflow. Note that underflow cannot occur, since i.offset >= 0.
+ if offset > math.MaxInt64-i.offset {
+ return i.offset, syserror.EOVERFLOW
+ }
+ if i.offset+offset < 0 {
+ return i.offset, syserror.EINVAL
+ }
+ i.offset += offset
+
+ case linux.SEEK_END:
+ var s syscall.Stat_t
+ if err := syscall.Fstat(i.hostFD, &s); err != nil {
+ return i.offset, err
+ }
+ size := s.Size
+
+ // Check for overflow. Note that underflow cannot occur, since size >= 0.
+ if offset > math.MaxInt64-size {
+ return i.offset, syserror.EOVERFLOW
+ }
+ if size+offset < 0 {
+ return i.offset, syserror.EINVAL
+ }
+ i.offset = size + offset
+
+ case linux.SEEK_DATA, linux.SEEK_HOLE:
+ // Modifying the offset in the host file table should not matter, since
+ // this is the only place where we use it.
+ //
+ // For reading and writing, we always rely on our internal offset.
+ n, err := unix.Seek(i.hostFD, offset, int(whence))
+ if err != nil {
+ return i.offset, err
+ }
+ i.offset = n
+
+ default:
+ // Invalid whence.
+ return i.offset, syserror.EINVAL
+ }
+
+ return i.offset, nil
+}
+
+// Sync implements FileDescriptionImpl.
+func (f *fileDescription) Sync(context.Context) error {
+ // TODO(gvisor.dev/issue/1672): Currently we do not support the SyncData optimization, so we always sync everything.
+ return unix.Fsync(f.inode.hostFD)
+}
+
+// ConfigureMMap implements FileDescriptionImpl.
+func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts) error {
+ if !f.inode.canMap {
+ return syserror.ENODEV
+ }
+ // TODO(gvisor.dev/issue/1672): Implement ConfigureMMap and Mappable interface.
+ return syserror.ENODEV
+}
diff --git a/pkg/sentry/fsimpl/host/ioctl_unsafe.go b/pkg/sentry/fsimpl/host/ioctl_unsafe.go
new file mode 100644
index 000000000..0983bf7d8
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/ioctl_unsafe.go
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+func ioctlGetTermios(fd int) (*linux.Termios, error) {
+ var t linux.Termios
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TCGETS, uintptr(unsafe.Pointer(&t)))
+ if errno != 0 {
+ return nil, errno
+ }
+ return &t, nil
+}
+
+func ioctlSetTermios(fd int, req uint64, t *linux.Termios) error {
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(unsafe.Pointer(t)))
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+func ioctlGetWinsize(fd int) (*linux.Winsize, error) {
+ var w linux.Winsize
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TIOCGWINSZ, uintptr(unsafe.Pointer(&w)))
+ if errno != 0 {
+ return nil, errno
+ }
+ return &w, nil
+}
+
+func ioctlSetWinsize(fd int, w *linux.Winsize) error {
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TIOCSWINSZ, uintptr(unsafe.Pointer(w)))
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go
new file mode 100644
index 000000000..8936afb06
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/tty.go
@@ -0,0 +1,379 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/unimpl"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// ttyFD implements vfs.FileDescriptionImpl for a host file descriptor
+// that wraps a TTY FD.
+type ttyFD struct {
+ fileDescription
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // session is the session attached to this ttyFD.
+ session *kernel.Session
+
+ // fgProcessGroup is the foreground process group that is currently
+ // connected to this TTY.
+ fgProcessGroup *kernel.ProcessGroup
+
+ // termios contains the terminal attributes for this TTY.
+ termios linux.KernelTermios
+}
+
+// InitForegroundProcessGroup sets the foreground process group and session for
+// the TTY. This should only be called once, after the foreground process group
+// has been created, but before it has started running.
+func (t *ttyFD) InitForegroundProcessGroup(pg *kernel.ProcessGroup) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if t.fgProcessGroup != nil {
+ panic("foreground process group is already set")
+ }
+ t.fgProcessGroup = pg
+ t.session = pg.Session()
+}
+
+// ForegroundProcessGroup returns the foreground process for the TTY.
+func (t *ttyFD) ForegroundProcessGroup() *kernel.ProcessGroup {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.fgProcessGroup
+}
+
+// Release implements fs.FileOperations.Release.
+func (t *ttyFD) Release() {
+ t.mu.Lock()
+ t.fgProcessGroup = nil
+ t.mu.Unlock()
+
+ t.fileDescription.Release()
+}
+
+// PRead implements vfs.FileDescriptionImpl.
+//
+// Reading from a TTY is only allowed for foreground process groups. Background
+// process groups will either get EIO or a SIGTTIN.
+func (t *ttyFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Are we allowed to do the read?
+ // drivers/tty/n_tty.c:n_tty_read()=>job_control()=>tty_check_change().
+ if err := t.checkChange(ctx, linux.SIGTTIN); err != nil {
+ return 0, err
+ }
+
+ // Do the read.
+ return t.fileDescription.PRead(ctx, dst, offset, opts)
+}
+
+// Read implements vfs.FileDescriptionImpl.
+//
+// Reading from a TTY is only allowed for foreground process groups. Background
+// process groups will either get EIO or a SIGTTIN.
+func (t *ttyFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Are we allowed to do the read?
+ // drivers/tty/n_tty.c:n_tty_read()=>job_control()=>tty_check_change().
+ if err := t.checkChange(ctx, linux.SIGTTIN); err != nil {
+ return 0, err
+ }
+
+ // Do the read.
+ return t.fileDescription.Read(ctx, dst, opts)
+}
+
+// PWrite implements vfs.FileDescriptionImpl.
+func (t *ttyFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Check whether TOSTOP is enabled. This corresponds to the check in
+ // drivers/tty/n_tty.c:n_tty_write().
+ if t.termios.LEnabled(linux.TOSTOP) {
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
+ }
+ return t.fileDescription.PWrite(ctx, src, offset, opts)
+}
+
+// Write implements vfs.FileDescriptionImpl.
+func (t *ttyFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Check whether TOSTOP is enabled. This corresponds to the check in
+ // drivers/tty/n_tty.c:n_tty_write().
+ if t.termios.LEnabled(linux.TOSTOP) {
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
+ }
+ return t.fileDescription.Write(ctx, src, opts)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.
+func (t *ttyFD) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ // Ignore arg[0]. This is the real FD:
+ fd := t.inode.hostFD
+ ioctl := args[1].Uint64()
+ switch ioctl {
+ case linux.TCGETS:
+ termios, err := ioctlGetTermios(fd)
+ if err != nil {
+ return 0, err
+ }
+ _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), termios, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TCSETS, linux.TCSETSW, linux.TCSETSF:
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
+
+ var termios linux.Termios
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &termios, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ err := ioctlSetTermios(fd, ioctl, &termios)
+ if err == nil {
+ t.termios.FromTermios(termios)
+ }
+ return 0, err
+
+ case linux.TIOCGPGRP:
+ // Args: pid_t *argp
+ // When successful, equivalent to *argp = tcgetpgrp(fd).
+ // Get the process group ID of the foreground process group on this
+ // terminal.
+
+ pidns := kernel.PIDNamespaceFromContext(ctx)
+ if pidns == nil {
+ return 0, syserror.ENOTTY
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Map the ProcessGroup into a ProcessGroupID in the task's PID namespace.
+ pgID := pidns.IDOfProcessGroup(t.fgProcessGroup)
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TIOCSPGRP:
+ // Args: const pid_t *argp
+ // Equivalent to tcsetpgrp(fd, *argp).
+ // Set the foreground process group ID of this terminal.
+
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ return 0, syserror.ENOTTY
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Check that we are allowed to set the process group.
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ // drivers/tty/tty_io.c:tiocspgrp() converts -EIO from tty_check_change()
+ // to -ENOTTY.
+ if err == syserror.EIO {
+ return 0, syserror.ENOTTY
+ }
+ return 0, err
+ }
+
+ // Check that calling task's process group is in the TTY session.
+ if task.ThreadGroup().Session() != t.session {
+ return 0, syserror.ENOTTY
+ }
+
+ var pgID kernel.ProcessGroupID
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+
+ // pgID must be non-negative.
+ if pgID < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Process group with pgID must exist in this PID namespace.
+ pidns := task.PIDNamespace()
+ pg := pidns.ProcessGroupWithID(pgID)
+ if pg == nil {
+ return 0, syserror.ESRCH
+ }
+
+ // Check that new process group is in the TTY session.
+ if pg.Session() != t.session {
+ return 0, syserror.EPERM
+ }
+
+ t.fgProcessGroup = pg
+ return 0, nil
+
+ case linux.TIOCGWINSZ:
+ // Args: struct winsize *argp
+ // Get window size.
+ winsize, err := ioctlGetWinsize(fd)
+ if err != nil {
+ return 0, err
+ }
+ _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), winsize, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TIOCSWINSZ:
+ // Args: const struct winsize *argp
+ // Set window size.
+
+ // Unlike setting the termios, any process group (even background ones) can
+ // set the winsize.
+
+ var winsize linux.Winsize
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &winsize, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ err := ioctlSetWinsize(fd, &winsize)
+ return 0, err
+
+ // Unimplemented commands.
+ case linux.TIOCSETD,
+ linux.TIOCSBRK,
+ linux.TIOCCBRK,
+ linux.TCSBRK,
+ linux.TCSBRKP,
+ linux.TIOCSTI,
+ linux.TIOCCONS,
+ linux.FIONBIO,
+ linux.TIOCEXCL,
+ linux.TIOCNXCL,
+ linux.TIOCGEXCL,
+ linux.TIOCNOTTY,
+ linux.TIOCSCTTY,
+ linux.TIOCGSID,
+ linux.TIOCGETD,
+ linux.TIOCVHANGUP,
+ linux.TIOCGDEV,
+ linux.TIOCMGET,
+ linux.TIOCMSET,
+ linux.TIOCMBIC,
+ linux.TIOCMBIS,
+ linux.TIOCGICOUNT,
+ linux.TCFLSH,
+ linux.TIOCSSERIAL,
+ linux.TIOCGPTPEER:
+
+ unimpl.EmitUnimplementedEvent(ctx)
+ fallthrough
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+// checkChange checks that the process group is allowed to read, write, or
+// change the state of the TTY.
+//
+// This corresponds to Linux drivers/tty/tty_io.c:tty_check_change(). The logic
+// is a bit convoluted, but documented inline.
+//
+// Preconditions: t.mu must be held.
+func (t *ttyFD) checkChange(ctx context.Context, sig linux.Signal) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ // No task? Linux does not have an analog for this case, but
+ // tty_check_change is more of a blacklist of cases than a
+ // whitelist, and is surprisingly permissive. Allowing the
+ // change seems most appropriate.
+ return nil
+ }
+
+ tg := task.ThreadGroup()
+ pg := tg.ProcessGroup()
+
+ // If the session for the task is different than the session for the
+ // controlling TTY, then the change is allowed. Seems like a bad idea,
+ // but that's exactly what linux does.
+ if tg.Session() != t.fgProcessGroup.Session() {
+ return nil
+ }
+
+ // If we are the foreground process group, then the change is allowed.
+ if pg == t.fgProcessGroup {
+ return nil
+ }
+
+ // We are not the foreground process group.
+
+ // Is the provided signal blocked or ignored?
+ if (task.SignalMask()&linux.SignalSetOf(sig) != 0) || tg.SignalHandlers().IsIgnored(sig) {
+ // If the signal is SIGTTIN, then we are attempting to read
+ // from the TTY. Don't send the signal and return EIO.
+ if sig == linux.SIGTTIN {
+ return syserror.EIO
+ }
+
+ // Otherwise, we are writing or changing terminal state. This is allowed.
+ return nil
+ }
+
+ // If the process group is an orphan, return EIO.
+ if pg.IsOrphan() {
+ return syserror.EIO
+ }
+
+ // Otherwise, send the signal to the process group and return ERESTARTSYS.
+ //
+ // Note that Linux also unconditionally sets TIF_SIGPENDING on current,
+ // but this isn't necessary in gVisor because the rationale given in
+ // 040b6362d58f "tty: fix leakage of -ERESTARTSYS to userland" doesn't
+ // apply: the sentry will handle -ERESTARTSYS in
+ // kernel.runApp.execute() even if the kernel.Task isn't interrupted.
+ //
+ // Linux ignores the result of kill_pgrp().
+ _ = pg.SendSignal(kernel.SignalInfoPriv(sig))
+ return kernel.ERESTARTSYS
+}
diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go
new file mode 100644
index 000000000..2bc757b1a
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/util.go
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+func toTimespec(ts linux.StatxTimestamp, omit bool) syscall.Timespec {
+ if omit {
+ return syscall.Timespec{
+ Sec: 0,
+ Nsec: unix.UTIME_OMIT,
+ }
+ }
+ return syscall.Timespec{
+ Sec: ts.Sec,
+ Nsec: int64(ts.Nsec),
+ }
+}
+
+func unixToLinuxStatxTimestamp(ts unix.StatxTimestamp) linux.StatxTimestamp {
+ return linux.StatxTimestamp{Sec: ts.Sec, Nsec: ts.Nsec}
+}
+
+func timespecToStatxTimestamp(ts unix.Timespec) linux.StatxTimestamp {
+ return linux.StatxTimestamp{Sec: int64(ts.Sec), Nsec: uint32(ts.Nsec)}
+}
+
+// wouldBlock returns true for file types that can return EWOULDBLOCK
+// for blocking operations, e.g. pipes, character devices, and sockets.
+func wouldBlock(fileType uint32) bool {
+ return fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK
+}
+
+// canMap returns true if a file with fileType is allowed to be memory mapped.
+// This is ported over from VFS1, but it's probably not the best way for us
+// to check if a file can be memory mapped.
+func canMap(fileType uint32) bool {
+ // TODO(gvisor.dev/issue/1672): Also allow "special files" to be mapped (see fs/host:canMap()).
+ //
+ // TODO(b/38213152): Some obscure character devices can be mapped.
+ return fileType == syscall.S_IFREG
+}
+
+// isBlockError checks if an error is EAGAIN or EWOULDBLOCK.
+// If so, they can be transformed into syserror.ErrWouldBlock.
+func isBlockError(err error) bool {
+ return err == syserror.EAGAIN || err == syserror.EWOULDBLOCK
+}
diff --git a/pkg/sentry/kernel/pipe/buffer_test.go b/pkg/sentry/fsimpl/host/util_unsafe.go
index 4d54b8b8f..5136ac844 100644
--- a/pkg/sentry/kernel/pipe/buffer_test.go
+++ b/pkg/sentry/fsimpl/host/util_unsafe.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,21 +12,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package pipe
+package host
import (
- "testing"
+ "syscall"
"unsafe"
-
- "gvisor.dev/gvisor/pkg/usermem"
)
-func TestBufferSize(t *testing.T) {
- bufferSize := unsafe.Sizeof(buffer{})
- if bufferSize < usermem.PageSize {
- t.Errorf("buffer is less than a page")
- }
- if bufferSize > (2 * usermem.PageSize) {
- t.Errorf("buffer is greater than two pages")
+func setTimestamps(fd int, ts *[2]syscall.Timespec) error {
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_UTIMENSAT,
+ uintptr(fd),
+ 0, /* path */
+ uintptr(unsafe.Pointer(ts)),
+ 0, /* flags */
+ 0, 0)
+ if errno != 0 {
+ return errno
}
+ return nil
}
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index d092ccb2a..d8bddbafa 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -61,9 +61,10 @@ func (f *DynamicBytesFile) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vf
return &fd.vfsfd, nil
}
-// SetStat implements Inode.SetStat.
-func (f *DynamicBytesFile) SetStat(*vfs.Filesystem, vfs.SetStatOptions) error {
- // DynamicBytesFiles are immutable.
+// SetStat implements Inode.SetStat. By default DynamicBytesFile doesn't allow
+// inode attributes to be changed. Override SetStat() making it call
+// f.InodeAttrs to allow it.
+func (*DynamicBytesFile) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
return syserror.EPERM
}
@@ -122,7 +123,7 @@ func (fd *DynamicBytesFD) Release() {}
// Stat implements vfs.FileDescriptionImpl.Stat.
func (fd *DynamicBytesFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
fs := fd.vfsfd.VirtualDentry().Mount().Filesystem()
- return fd.inode.Stat(fs), nil
+ return fd.inode.Stat(fs, opts)
}
// SetStat implements vfs.FileDescriptionImpl.SetStat.
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index 5650512e0..bfa786c88 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -17,6 +17,7 @@ package kernfs
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
@@ -107,9 +108,13 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
fs.mu.Lock()
defer fs.mu.Unlock()
+ opts := vfs.StatOptions{Mask: linux.STATX_INO}
// Handle ".".
if fd.off == 0 {
- stat := fd.inode().Stat(vfsFS)
+ stat, err := fd.inode().Stat(vfsFS, opts)
+ if err != nil {
+ return err
+ }
dirent := vfs.Dirent{
Name: ".",
Type: linux.DT_DIR,
@@ -125,7 +130,10 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
// Handle "..".
if fd.off == 1 {
parentInode := vfsd.ParentOrSelf().Impl().(*Dentry).inode
- stat := parentInode.Stat(vfsFS)
+ stat, err := parentInode.Stat(vfsFS, opts)
+ if err != nil {
+ return err
+ }
dirent := vfs.Dirent{
Name: "..",
Type: linux.FileMode(stat.Mode).DirentType(),
@@ -146,7 +154,10 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
childIdx := fd.off - 2
for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() {
inode := it.Dentry.Impl().(*Dentry).inode
- stat := inode.Stat(vfsFS)
+ stat, err := inode.Stat(vfsFS, opts)
+ if err != nil {
+ return err
+ }
dirent := vfs.Dirent{
Name: it.Name,
Type: linux.FileMode(stat.Mode).DirentType(),
@@ -190,12 +201,12 @@ func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int
func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
fs := fd.filesystem()
inode := fd.inode()
- return inode.Stat(fs), nil
+ return inode.Stat(fs, opts)
}
// SetStat implements vfs.FileDescriptionImpl.SetStat.
func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
- fs := fd.filesystem()
+ creds := auth.CredentialsFromContext(ctx)
inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
- return inode.SetStat(fs, opts)
+ return inode.SetStat(ctx, fd.filesystem(), creds, opts)
}
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index 292f58afd..31da8b511 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -12,16 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// This file implements vfs.FilesystemImpl for kernfs.
-
package kernfs
+// This file implements vfs.FilesystemImpl for kernfs.
+
import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -229,6 +230,19 @@ func (fs *Filesystem) Sync(ctx context.Context) error {
return nil
}
+// AccessAt implements vfs.Filesystem.Impl.AccessAt.
+func (fs *Filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ defer fs.processDeferredDecRefs()
+
+ _, inode, err := fs.walkExistingLocked(ctx, rp)
+ if err != nil {
+ return err
+ }
+ return inode.CheckPermissions(ctx, creds, ats)
+}
+
// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
fs.mu.RLock()
@@ -622,7 +636,7 @@ func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
if opts.Stat.Mask == 0 {
return nil
}
- return inode.SetStat(fs.VFSFilesystem(), opts)
+ return inode.SetStat(ctx, fs.VFSFilesystem(), rp.Credentials(), opts)
}
// StatAt implements vfs.FilesystemImpl.StatAt.
@@ -634,7 +648,7 @@ func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err != nil {
return linux.Statx{}, err
}
- return inode.Stat(fs.VFSFilesystem()), nil
+ return inode.Stat(fs.VFSFilesystem(), opts)
}
// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 099d70a16..5c84b10c9 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -36,20 +36,20 @@ type InodeNoopRefCount struct {
}
// IncRef implements Inode.IncRef.
-func (n *InodeNoopRefCount) IncRef() {
+func (InodeNoopRefCount) IncRef() {
}
// DecRef implements Inode.DecRef.
-func (n *InodeNoopRefCount) DecRef() {
+func (InodeNoopRefCount) DecRef() {
}
// TryIncRef implements Inode.TryIncRef.
-func (n *InodeNoopRefCount) TryIncRef() bool {
+func (InodeNoopRefCount) TryIncRef() bool {
return true
}
// Destroy implements Inode.Destroy.
-func (n *InodeNoopRefCount) Destroy() {
+func (InodeNoopRefCount) Destroy() {
}
// InodeDirectoryNoNewChildren partially implements the Inode interface.
@@ -58,27 +58,27 @@ func (n *InodeNoopRefCount) Destroy() {
type InodeDirectoryNoNewChildren struct{}
// NewFile implements Inode.NewFile.
-func (*InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) {
return nil, syserror.EPERM
}
// NewDir implements Inode.NewDir.
-func (*InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) {
return nil, syserror.EPERM
}
// NewLink implements Inode.NewLink.
-func (*InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) {
return nil, syserror.EPERM
}
// NewSymlink implements Inode.NewSymlink.
-func (*InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
return nil, syserror.EPERM
}
// NewNode implements Inode.NewNode.
-func (*InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
return nil, syserror.EPERM
}
@@ -90,62 +90,62 @@ type InodeNotDirectory struct {
}
// HasChildren implements Inode.HasChildren.
-func (*InodeNotDirectory) HasChildren() bool {
+func (InodeNotDirectory) HasChildren() bool {
return false
}
// NewFile implements Inode.NewFile.
-func (*InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) {
+func (InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) {
panic("NewFile called on non-directory inode")
}
// NewDir implements Inode.NewDir.
-func (*InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) {
+func (InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) {
panic("NewDir called on non-directory inode")
}
// NewLink implements Inode.NewLinkink.
-func (*InodeNotDirectory) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) {
+func (InodeNotDirectory) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) {
panic("NewLink called on non-directory inode")
}
// NewSymlink implements Inode.NewSymlink.
-func (*InodeNotDirectory) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
+func (InodeNotDirectory) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
panic("NewSymlink called on non-directory inode")
}
// NewNode implements Inode.NewNode.
-func (*InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
+func (InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
panic("NewNode called on non-directory inode")
}
// Unlink implements Inode.Unlink.
-func (*InodeNotDirectory) Unlink(context.Context, string, *vfs.Dentry) error {
+func (InodeNotDirectory) Unlink(context.Context, string, *vfs.Dentry) error {
panic("Unlink called on non-directory inode")
}
// RmDir implements Inode.RmDir.
-func (*InodeNotDirectory) RmDir(context.Context, string, *vfs.Dentry) error {
+func (InodeNotDirectory) RmDir(context.Context, string, *vfs.Dentry) error {
panic("RmDir called on non-directory inode")
}
// Rename implements Inode.Rename.
-func (*InodeNotDirectory) Rename(context.Context, string, string, *vfs.Dentry, *vfs.Dentry) (*vfs.Dentry, error) {
+func (InodeNotDirectory) Rename(context.Context, string, string, *vfs.Dentry, *vfs.Dentry) (*vfs.Dentry, error) {
panic("Rename called on non-directory inode")
}
// Lookup implements Inode.Lookup.
-func (*InodeNotDirectory) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+func (InodeNotDirectory) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
panic("Lookup called on non-directory inode")
}
// IterDirents implements Inode.IterDirents.
-func (*InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
+func (InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
panic("IterDirents called on non-directory inode")
}
// Valid implements Inode.Valid.
-func (*InodeNotDirectory) Valid(context.Context) bool {
+func (InodeNotDirectory) Valid(context.Context) bool {
return true
}
@@ -157,17 +157,17 @@ func (*InodeNotDirectory) Valid(context.Context) bool {
type InodeNoDynamicLookup struct{}
// Lookup implements Inode.Lookup.
-func (*InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+func (InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
return nil, syserror.ENOENT
}
// IterDirents implements Inode.IterDirents.
-func (*InodeNoDynamicLookup) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+func (InodeNoDynamicLookup) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
return offset, nil
}
// Valid implements Inode.Valid.
-func (*InodeNoDynamicLookup) Valid(ctx context.Context) bool {
+func (InodeNoDynamicLookup) Valid(ctx context.Context) bool {
return true
}
@@ -177,7 +177,7 @@ func (*InodeNoDynamicLookup) Valid(ctx context.Context) bool {
type InodeNotSymlink struct{}
// Readlink implements Inode.Readlink.
-func (*InodeNotSymlink) Readlink(context.Context) (string, error) {
+func (InodeNotSymlink) Readlink(context.Context) (string, error) {
return "", syserror.EINVAL
}
@@ -219,7 +219,7 @@ func (a *InodeAttrs) Mode() linux.FileMode {
// Stat partially implements Inode.Stat. Note that this function doesn't provide
// all the stat fields, and the embedder should consider extending the result
// with filesystem-specific fields.
-func (a *InodeAttrs) Stat(*vfs.Filesystem) linux.Statx {
+func (a *InodeAttrs) Stat(*vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) {
var stat linux.Statx
stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK
stat.Ino = atomic.LoadUint64(&a.ino)
@@ -228,13 +228,23 @@ func (a *InodeAttrs) Stat(*vfs.Filesystem) linux.Statx {
stat.GID = atomic.LoadUint32(&a.gid)
stat.Nlink = atomic.LoadUint32(&a.nlink)
- // TODO: Implement other stat fields like timestamps.
+ // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps.
- return stat
+ return stat, nil
}
// SetStat implements Inode.SetStat.
-func (a *InodeAttrs) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error {
+func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ if opts.Stat.Mask == 0 {
+ return nil
+ }
+ if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 {
+ return syserror.EPERM
+ }
+ if err := vfs.CheckSetStat(ctx, creds, &opts.Stat, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil {
+ return err
+ }
+
stat := opts.Stat
if stat.Mask&linux.STATX_MODE != 0 {
for {
@@ -256,19 +266,17 @@ func (a *InodeAttrs) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error {
// Note that not all fields are modifiable. For example, the file type and
// inode numbers are immutable after node creation.
- // TODO: Implement other stat fields like timestamps.
+ // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps.
return nil
}
// CheckPermissions implements Inode.CheckPermissions.
func (a *InodeAttrs) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
- mode := a.Mode()
return vfs.GenericCheckPermissions(
creds,
ats,
- mode.FileType() == linux.ModeDirectory,
- uint16(mode),
+ a.Mode(),
auth.KUID(atomic.LoadUint32(&a.uid)),
auth.KGID(atomic.LoadUint32(&a.gid)),
)
@@ -554,3 +562,16 @@ func (s *StaticDirectory) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs
fd.Init(rp.Mount(), vfsd, &s.OrderedChildren, &opts)
return fd.VFSFileDescription(), nil
}
+
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*StaticDirectory) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// AlwaysValid partially implements kernfs.inodeDynamicLookup.
+type AlwaysValid struct{}
+
+// Valid implements kernfs.inodeDynamicLookup.
+func (*AlwaysValid) Valid(context.Context) bool {
+ return true
+}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index c74fa999b..794e38908 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -176,8 +176,6 @@ type Dentry struct {
vfsd vfs.Dentry
inode Inode
- refs uint64
-
// flags caches useful information about the dentry from the inode. See the
// dflags* consts above. Must be accessed by atomic ops.
flags uint32
@@ -302,7 +300,8 @@ type Inode interface {
// this inode. The returned file description should hold a reference on the
// inode for its lifetime.
//
- // Precondition: !rp.Done(). vfsd.Impl() must be a kernfs Dentry.
+ // Precondition: rp.Done(). vfsd.Impl() must be the kernfs Dentry containing
+ // the inode on which Open() is being called.
Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error)
}
@@ -320,7 +319,7 @@ type inodeMetadata interface {
// CheckPermissions checks that creds may access this inode for the
// requested access type, per the the rules of
// fs/namei.c:generic_permission().
- CheckPermissions(ctx context.Context, creds *auth.Credentials, atx vfs.AccessTypes) error
+ CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error
// Mode returns the (struct stat)::st_mode value for this inode. This is
// separated from Stat for performance.
@@ -328,11 +327,13 @@ type inodeMetadata interface {
// Stat returns the metadata for this inode. This corresponds to
// vfs.FilesystemImpl.StatAt.
- Stat(fs *vfs.Filesystem) linux.Statx
+ Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error)
// SetStat updates the metadata for this inode. This corresponds to
- // vfs.FilesystemImpl.SetStatAt.
- SetStat(fs *vfs.Filesystem, opts vfs.SetStatOptions) error
+ // vfs.FilesystemImpl.SetStatAt. Implementations are responsible for checking
+ // if the operation can be performed (see vfs.CheckSetStat() for common
+ // checks).
+ SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error
}
// Precondition: All methods in this interface may only be called on directory
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index 0459fb305..fb0d25ad7 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -91,7 +91,7 @@ type attrs struct {
kernfs.InodeAttrs
}
-func (a *attrs) SetStat(fs *vfs.Filesystem, opt vfs.SetStatOptions) error {
+func (*attrs) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
return syserror.EPERM
}
diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go
index 0ee7eb9b7..5918d3309 100644
--- a/pkg/sentry/fsimpl/kernfs/symlink.go
+++ b/pkg/sentry/fsimpl/kernfs/symlink.go
@@ -18,6 +18,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
)
// StaticSymlink provides an Inode implementation for symlinks that point to
@@ -52,3 +54,8 @@ func (s *StaticSymlink) Init(creds *auth.Credentials, ino uint64, target string)
func (s *StaticSymlink) Readlink(_ context.Context) (string, error) {
return s.target, nil
}
+
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*StaticSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index a83245866..8156984eb 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -8,10 +8,11 @@ go_library(
"filesystem.go",
"subtasks.go",
"task.go",
+ "task_fds.go",
"task_files.go",
+ "task_net.go",
"tasks.go",
"tasks_files.go",
- "tasks_net.go",
"tasks_sys.go",
],
visibility = ["//pkg/sentry:internal"],
@@ -19,8 +20,10 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/log",
+ "//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/fs",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
@@ -53,6 +56,7 @@ go_test(
"//pkg/fspath",
"//pkg/sentry/contexttest",
"//pkg/sentry/fsimpl/testutil",
+ "//pkg/sentry/fsimpl/tmpfs",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
index f3f4e49b4..a21313666 100644
--- a/pkg/sentry/fsimpl/proc/subtasks.go
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -34,6 +35,7 @@ type subtasksInode struct {
kernfs.InodeDirectoryNoNewChildren
kernfs.InodeAttrs
kernfs.OrderedChildren
+ kernfs.AlwaysValid
task *kernel.Task
pidns *kernel.PIDNamespace
@@ -61,11 +63,6 @@ func newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, inoGen InoGenera
return dentry
}
-// Valid implements kernfs.inodeDynamicLookup.
-func (i *subtasksInode) Valid(ctx context.Context) bool {
- return true
-}
-
// Lookup implements kernfs.inodeDynamicLookup.
func (i *subtasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
tid, err := strconv.ParseUint(name, 10, 32)
@@ -121,8 +118,18 @@ func (i *subtasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.O
}
// Stat implements kernfs.Inode.
-func (i *subtasksInode) Stat(vsfs *vfs.Filesystem) linux.Statx {
- stat := i.InodeAttrs.Stat(vsfs)
- stat.Nlink += uint32(i.task.ThreadGroup().Count())
- return stat
+func (i *subtasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ stat, err := i.InodeAttrs.Stat(vsfs, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ if opts.Mask&linux.STATX_NLINK != 0 {
+ stat.Nlink += uint32(i.task.ThreadGroup().Count())
+ }
+ return stat, nil
+}
+
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*subtasksInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
}
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index 2d814668a..aee2a4392 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -45,28 +45,31 @@ var _ kernfs.Inode = (*taskInode)(nil)
func newTaskInode(inoGen InoGenerator, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) *kernfs.Dentry {
contents := map[string]*kernfs.Dentry{
- "auxv": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &auxvData{task: task}),
- "cmdline": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}),
- "comm": newComm(task, inoGen.NextIno(), 0444),
- "environ": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}),
- //"exe": newExe(t, msrc),
- //"fd": newFdDir(t, msrc),
- //"fdinfo": newFdInfoDir(t, msrc),
- "gid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: true}),
- "io": newTaskOwnedFile(task, inoGen.NextIno(), 0400, newIO(task, isThreadGroup)),
- "maps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &mapsData{task: task}),
- //"mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc),
- //"mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc),
+ "auxv": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &auxvData{task: task}),
+ "cmdline": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}),
+ "comm": newComm(task, inoGen.NextIno(), 0444),
+ "environ": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}),
+ "exe": newExeSymlink(task, inoGen.NextIno()),
+ "fd": newFDDirInode(task, inoGen),
+ "fdinfo": newFDInfoDirInode(task, inoGen),
+ "gid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: true}),
+ "io": newTaskOwnedFile(task, inoGen.NextIno(), 0400, newIO(task, isThreadGroup)),
+ "maps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &mapsData{task: task}),
+ "mountinfo": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &mountInfoData{task: task}),
+ "mounts": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &mountsData{task: task}),
+ "net": newTaskNetDir(task, inoGen),
"ns": newTaskOwnedDir(task, inoGen.NextIno(), 0511, map[string]*kernfs.Dentry{
"net": newNamespaceSymlink(task, inoGen.NextIno(), "net"),
"pid": newNamespaceSymlink(task, inoGen.NextIno(), "pid"),
"user": newNamespaceSymlink(task, inoGen.NextIno(), "user"),
}),
- "smaps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &smapsData{task: task}),
- "stat": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}),
- "statm": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statmData{task: task}),
- "status": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statusData{task: task, pidns: pidns}),
- "uid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: false}),
+ "oom_score": newTaskOwnedFile(task, inoGen.NextIno(), 0444, newStaticFile("0\n")),
+ "oom_score_adj": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &oomScoreAdj{task: task}),
+ "smaps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &smapsData{task: task}),
+ "stat": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}),
+ "statm": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statmData{task: task}),
+ "status": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statusData{task: task, pidns: pidns}),
+ "uid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: false}),
}
if isThreadGroup {
contents["task"] = newSubtasks(task, pidns, inoGen, cgroupControllers)
@@ -104,13 +107,9 @@ func (i *taskInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenO
return fd.VFSFileDescription(), nil
}
-// SetStat implements kernfs.Inode.
-func (i *taskInode) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error {
- stat := opts.Stat
- if stat.Mask&linux.STATX_MODE != 0 {
- return syserror.EPERM
- }
- return nil
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*taskInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
}
// taskOwnedInode implements kernfs.Inode and overrides inode owner with task
@@ -152,26 +151,28 @@ func newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, childre
}
// Stat implements kernfs.Inode.
-func (i *taskOwnedInode) Stat(fs *vfs.Filesystem) linux.Statx {
- stat := i.Inode.Stat(fs)
- uid, gid := i.getOwner(linux.FileMode(stat.Mode))
- stat.UID = uint32(uid)
- stat.GID = uint32(gid)
- return stat
+func (i *taskOwnedInode) Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ stat, err := i.Inode.Stat(fs, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ if opts.Mask&(linux.STATX_UID|linux.STATX_GID) != 0 {
+ uid, gid := i.getOwner(linux.FileMode(stat.Mode))
+ if opts.Mask&linux.STATX_UID != 0 {
+ stat.UID = uint32(uid)
+ }
+ if opts.Mask&linux.STATX_GID != 0 {
+ stat.GID = uint32(gid)
+ }
+ }
+ return stat, nil
}
// CheckPermissions implements kernfs.Inode.
func (i *taskOwnedInode) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
mode := i.Mode()
uid, gid := i.getOwner(mode)
- return vfs.GenericCheckPermissions(
- creds,
- ats,
- mode.FileType() == linux.ModeDirectory,
- uint16(mode),
- uid,
- gid,
- )
+ return vfs.GenericCheckPermissions(creds, ats, mode, uid, gid)
}
func (i *taskOwnedInode) getOwner(mode linux.FileMode) (auth.KUID, auth.KGID) {
@@ -234,7 +235,7 @@ func newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentr
// member, there is one entry containing three colon-separated fields:
// hierarchy-ID:controller-list:cgroup-path"
func newCgroupData(controllers map[string]string) dynamicInode {
- buf := bytes.Buffer{}
+ var buf bytes.Buffer
// The hierarchy ids must be positive integers (for cgroup v1), but the
// exact number does not matter, so long as they are unique. We can
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
new file mode 100644
index 000000000..76bfc5307
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -0,0 +1,287 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+type fdDir struct {
+ inoGen InoGenerator
+ task *kernel.Task
+
+ // When produceSymlinks is set, dirents produces for the FDs are reported
+ // as symlink. Otherwise, they are reported as regular files.
+ produceSymlink bool
+}
+
+func (i *fdDir) lookup(name string) (*vfs.FileDescription, kernel.FDFlags, error) {
+ fd, err := strconv.ParseUint(name, 10, 64)
+ if err != nil {
+ return nil, kernel.FDFlags{}, syserror.ENOENT
+ }
+
+ var (
+ file *vfs.FileDescription
+ flags kernel.FDFlags
+ )
+ i.task.WithMuLocked(func(t *kernel.Task) {
+ if fdTable := t.FDTable(); fdTable != nil {
+ file, flags = fdTable.GetVFS2(int32(fd))
+ }
+ })
+ if file == nil {
+ return nil, kernel.FDFlags{}, syserror.ENOENT
+ }
+ return file, flags, nil
+}
+
+// IterDirents implements kernfs.inodeDynamicLookup.
+func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, absOffset, relOffset int64) (int64, error) {
+ var fds []int32
+ i.task.WithMuLocked(func(t *kernel.Task) {
+ if fdTable := t.FDTable(); fdTable != nil {
+ fds = fdTable.GetFDs()
+ }
+ })
+
+ offset := absOffset + relOffset
+ typ := uint8(linux.DT_REG)
+ if i.produceSymlink {
+ typ = linux.DT_LNK
+ }
+
+ // Find the appropriate starting point.
+ idx := sort.Search(len(fds), func(i int) bool { return fds[i] >= int32(relOffset) })
+ if idx >= len(fds) {
+ return offset, nil
+ }
+ for _, fd := range fds[idx:] {
+ dirent := vfs.Dirent{
+ Name: strconv.FormatUint(uint64(fd), 10),
+ Type: typ,
+ Ino: i.inoGen.NextIno(),
+ NextOff: offset + 1,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+ return offset, nil
+}
+
+// fdDirInode represents the inode for /proc/[pid]/fd directory.
+//
+// +stateify savable
+type fdDirInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
+ kernfs.AlwaysValid
+ fdDir
+}
+
+var _ kernfs.Inode = (*fdDirInode)(nil)
+
+func newFDDirInode(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry {
+ inode := &fdDirInode{
+ fdDir: fdDir{
+ inoGen: inoGen,
+ task: task,
+ produceSymlink: true,
+ },
+ }
+ inode.InodeAttrs.Init(task.Credentials(), inoGen.NextIno(), linux.ModeDirectory|0555)
+
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
+ inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+
+ return dentry
+}
+
+// Lookup implements kernfs.inodeDynamicLookup.
+func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ file, _, err := i.lookup(name)
+ if err != nil {
+ return nil, err
+ }
+ taskDentry := newFDSymlink(i.task.Credentials(), file, i.inoGen.NextIno())
+ return taskDentry.VFSDentry(), nil
+}
+
+// Open implements kernfs.Inode.
+func (i *fdDirInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &kernfs.GenericDirectoryFD{}
+ fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
+ return fd.VFSFileDescription(), nil
+}
+
+// CheckPermissions implements kernfs.Inode.
+//
+// This is to match Linux, which uses a special permission handler to guarantee
+// that a process can still access /proc/self/fd after it has executed
+// setuid. See fs/proc/fd.c:proc_fd_permission.
+func (i *fdDirInode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ err := i.InodeAttrs.CheckPermissions(ctx, creds, ats)
+ if err == nil {
+ // Access granted, no extra check needed.
+ return nil
+ }
+ if t := kernel.TaskFromContext(ctx); t != nil {
+ // Allow access if the task trying to access it is in the thread group
+ // corresponding to this directory.
+ if i.task.ThreadGroup() == t.ThreadGroup() {
+ // Access granted (overridden).
+ return nil
+ }
+ }
+ return err
+}
+
+// fdSymlink is an symlink for the /proc/[pid]/fd/[fd] file.
+//
+// +stateify savable
+type fdSymlink struct {
+ refs.AtomicRefCount
+ kernfs.InodeAttrs
+ kernfs.InodeSymlink
+
+ file *vfs.FileDescription
+}
+
+var _ kernfs.Inode = (*fdSymlink)(nil)
+
+func newFDSymlink(creds *auth.Credentials, file *vfs.FileDescription, ino uint64) *kernfs.Dentry {
+ file.IncRef()
+ inode := &fdSymlink{file: file}
+ inode.Init(creds, ino, linux.ModeSymlink|0777)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+func (s *fdSymlink) Readlink(ctx context.Context) (string, error) {
+ root := vfs.RootFromContext(ctx)
+ defer root.DecRef()
+
+ vfsObj := s.file.VirtualDentry().Mount().Filesystem().VirtualFilesystem()
+ return vfsObj.PathnameWithDeleted(ctx, root, s.file.VirtualDentry())
+}
+
+func (s *fdSymlink) DecRef() {
+ s.AtomicRefCount.DecRefWithDestructor(func() {
+ s.Destroy()
+ })
+}
+
+func (s *fdSymlink) Destroy() {
+ s.file.DecRef()
+}
+
+// fdInfoDirInode represents the inode for /proc/[pid]/fdinfo directory.
+//
+// +stateify savable
+type fdInfoDirInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
+ kernfs.AlwaysValid
+ fdDir
+}
+
+var _ kernfs.Inode = (*fdInfoDirInode)(nil)
+
+func newFDInfoDirInode(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry {
+ inode := &fdInfoDirInode{
+ fdDir: fdDir{
+ inoGen: inoGen,
+ task: task,
+ },
+ }
+ inode.InodeAttrs.Init(task.Credentials(), inoGen.NextIno(), linux.ModeDirectory|0555)
+
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
+ inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+
+ return dentry
+}
+
+// Lookup implements kernfs.inodeDynamicLookup.
+func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ file, flags, err := i.lookup(name)
+ if err != nil {
+ return nil, err
+ }
+
+ data := &fdInfoData{file: file, flags: flags}
+ dentry := newTaskOwnedFile(i.task, i.inoGen.NextIno(), 0444, data)
+ return dentry.VFSDentry(), nil
+}
+
+// Open implements kernfs.Inode.
+func (i *fdInfoDirInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &kernfs.GenericDirectoryFD{}
+ fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
+ return fd.VFSFileDescription(), nil
+}
+
+// fdInfoData implements vfs.DynamicBytesSource for /proc/[pid]/fdinfo/[fd].
+//
+// +stateify savable
+type fdInfoData struct {
+ kernfs.DynamicBytesFile
+ refs.AtomicRefCount
+
+ file *vfs.FileDescription
+ flags kernel.FDFlags
+}
+
+var _ dynamicInode = (*fdInfoData)(nil)
+
+func (d *fdInfoData) DecRef() {
+ d.AtomicRefCount.DecRefWithDestructor(d.destroy)
+}
+
+func (d *fdInfoData) destroy() {
+ d.file.DecRef()
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *fdInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // TODO(b/121266871): Include pos, locks, and other data. For now we only
+ // have flags.
+ // See https://www.kernel.org/doc/Documentation/filesystems/proc.txt
+ flags := uint(d.file.StatusFlags()) | d.flags.ToLinuxFileFlags()
+ fmt.Fprintf(buf, "flags:\t0%o\n", flags)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index efd3b3453..8c743df8d 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -18,10 +18,14 @@ import (
"bytes"
"fmt"
"io"
+ "sort"
+ "strings"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -496,7 +500,7 @@ func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
return nil
}
-// ioUsage is the /proc/<pid>/io and /proc/<pid>/task/<tid>/io data provider.
+// ioUsage is the /proc/[pid]/io and /proc/[pid]/task/[tid]/io data provider.
type ioUsage interface {
// IOUsage returns the io usage data.
IOUsage() *usage.IO
@@ -525,3 +529,293 @@ func (i *ioData) Generate(ctx context.Context, buf *bytes.Buffer) error {
fmt.Fprintf(buf, "cancelled_write_bytes: %d\n", io.BytesWriteCancelled)
return nil
}
+
+// oomScoreAdj is a stub of the /proc/<pid>/oom_score_adj file.
+//
+// +stateify savable
+type oomScoreAdj struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ vfs.WritableDynamicBytesSource = (*oomScoreAdj)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (o *oomScoreAdj) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if o.task.ExitState() == kernel.TaskExitDead {
+ return syserror.ESRCH
+ }
+ fmt.Fprintf(buf, "%d\n", o.task.OOMScoreAdj())
+ return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (o *oomScoreAdj) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Limit input size so as not to impact performance if input size is large.
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ var v int32
+ n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
+ if err != nil {
+ return 0, err
+ }
+
+ if o.task.ExitState() == kernel.TaskExitDead {
+ return 0, syserror.ESRCH
+ }
+ if err := o.task.SetOOMScoreAdj(v); err != nil {
+ return 0, err
+ }
+
+ return n, nil
+}
+
+// exeSymlink is an symlink for the /proc/[pid]/exe file.
+//
+// +stateify savable
+type exeSymlink struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeSymlink
+
+ task *kernel.Task
+}
+
+var _ kernfs.Inode = (*exeSymlink)(nil)
+
+func newExeSymlink(task *kernel.Task, ino uint64) *kernfs.Dentry {
+ inode := &exeSymlink{task: task}
+ inode.Init(task.Credentials(), ino, linux.ModeSymlink|0777)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+// Readlink implements kernfs.Inode.
+func (s *exeSymlink) Readlink(ctx context.Context) (string, error) {
+ if !kernel.ContextCanTrace(ctx, s.task, false) {
+ return "", syserror.EACCES
+ }
+
+ // Pull out the executable for /proc/[pid]/exe.
+ exec, err := s.executable()
+ if err != nil {
+ return "", err
+ }
+ defer exec.DecRef()
+
+ return exec.PathnameWithDeleted(ctx), nil
+}
+
+func (s *exeSymlink) executable() (file fsbridge.File, err error) {
+ s.task.WithMuLocked(func(t *kernel.Task) {
+ mm := t.MemoryManager()
+ if mm == nil {
+ // TODO(b/34851096): Check shouldn't allow Readlink once the
+ // Task is zombied.
+ err = syserror.EACCES
+ return
+ }
+
+ // The MemoryManager may be destroyed, in which case
+ // MemoryManager.destroy will simply set the executable to nil
+ // (with locks held).
+ file = mm.Executable()
+ if file == nil {
+ err = syserror.ENOENT
+ }
+ })
+ return
+}
+
+// forEachMountSource runs f for the process root mount and each mount that is
+// a descendant of the root.
+func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) {
+ var fsctx *kernel.FSContext
+ t.WithMuLocked(func(t *kernel.Task) {
+ fsctx = t.FSContext()
+ })
+ if fsctx == nil {
+ // The task has been destroyed. Nothing to show here.
+ return
+ }
+
+ // All mount points must be relative to the rootDir, and mounts outside
+ // will be excluded.
+ rootDir := fsctx.RootDirectory()
+ if rootDir == nil {
+ // The task has been destroyed. Nothing to show here.
+ return
+ }
+ defer rootDir.DecRef()
+
+ mnt := t.MountNamespace().FindMount(rootDir)
+ if mnt == nil {
+ // Has it just been unmounted?
+ return
+ }
+ ms := t.MountNamespace().AllMountsUnder(mnt)
+ sort.Slice(ms, func(i, j int) bool {
+ return ms[i].ID < ms[j].ID
+ })
+ for _, m := range ms {
+ mroot := m.Root()
+ if mroot == nil {
+ continue // No longer valid.
+ }
+ mountPath, desc := mroot.FullName(rootDir)
+ mroot.DecRef()
+ if !desc {
+ // MountSources that are not descendants of the chroot jail are ignored.
+ continue
+ }
+ fn(mountPath, m)
+ }
+}
+
+// mountInfoData is used to implement /proc/[pid]/mountinfo.
+//
+// +stateify savable
+type mountInfoData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*mountInfoData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (i *mountInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ forEachMount(i.task, func(mountPath string, m *fs.Mount) {
+ mroot := m.Root()
+ if mroot == nil {
+ return // No longer valid.
+ }
+ defer mroot.DecRef()
+
+ // Format:
+ // 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
+ // (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
+
+ // (1) MountSource ID.
+ fmt.Fprintf(buf, "%d ", m.ID)
+
+ // (2) Parent ID (or this ID if there is no parent).
+ pID := m.ID
+ if !m.IsRoot() && !m.IsUndo() {
+ pID = m.ParentID
+ }
+ fmt.Fprintf(buf, "%d ", pID)
+
+ // (3) Major:Minor device ID. We don't have a superblock, so we
+ // just use the root inode device number.
+ sa := mroot.Inode.StableAttr
+ fmt.Fprintf(buf, "%d:%d ", sa.DeviceFileMajor, sa.DeviceFileMinor)
+
+ // (4) Root: the pathname of the directory in the filesystem
+ // which forms the root of this mount.
+ //
+ // NOTE(b/78135857): This will always be "/" until we implement
+ // bind mounts.
+ fmt.Fprintf(buf, "/ ")
+
+ // (5) Mount point (relative to process root).
+ fmt.Fprintf(buf, "%s ", mountPath)
+
+ // (6) Mount options.
+ flags := mroot.Inode.MountSource.Flags
+ opts := "rw"
+ if flags.ReadOnly {
+ opts = "ro"
+ }
+ if flags.NoAtime {
+ opts += ",noatime"
+ }
+ if flags.NoExec {
+ opts += ",noexec"
+ }
+ fmt.Fprintf(buf, "%s ", opts)
+
+ // (7) Optional fields: zero or more fields of the form "tag[:value]".
+ // (8) Separator: the end of the optional fields is marked by a single hyphen.
+ fmt.Fprintf(buf, "- ")
+
+ // (9) Filesystem type.
+ fmt.Fprintf(buf, "%s ", mroot.Inode.MountSource.FilesystemType)
+
+ // (10) Mount source: filesystem-specific information or "none".
+ fmt.Fprintf(buf, "none ")
+
+ // (11) Superblock options, and final newline.
+ fmt.Fprintf(buf, "%s\n", superBlockOpts(mountPath, mroot.Inode.MountSource))
+ })
+ return nil
+}
+
+func superBlockOpts(mountPath string, msrc *fs.MountSource) string {
+ // gVisor doesn't (yet) have a concept of super block options, so we
+ // use the ro/rw bit from the mount flag.
+ opts := "rw"
+ if msrc.Flags.ReadOnly {
+ opts = "ro"
+ }
+
+ // NOTE(b/147673608): If the mount is a cgroup, we also need to include
+ // the cgroup name in the options. For now we just read that from the
+ // path.
+ // TODO(gvisor.dev/issues/190): Once gVisor has full cgroup support, we
+ // should get this value from the cgroup itself, and not rely on the
+ // path.
+ if msrc.FilesystemType == "cgroup" {
+ splitPath := strings.Split(mountPath, "/")
+ cgroupType := splitPath[len(splitPath)-1]
+ opts += "," + cgroupType
+ }
+ return opts
+}
+
+// mountsData is used to implement /proc/[pid]/mounts.
+//
+// +stateify savable
+type mountsData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*mountInfoData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (i *mountsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ forEachMount(i.task, func(mountPath string, m *fs.Mount) {
+ // Format:
+ // <special device or remote filesystem> <mount point> <filesystem type> <mount options> <needs dump> <fsck order>
+ //
+ // We use the filesystem name as the first field, since there
+ // is no real block device we can point to, and we also should
+ // not expose anything about the remote filesystem.
+ //
+ // Only ro/rw option is supported for now.
+ //
+ // The "needs dump"and fsck flags are always 0, which is allowed.
+ root := m.Root()
+ if root == nil {
+ return // No longer valid.
+ }
+ defer root.DecRef()
+
+ flags := root.Inode.MountSource.Flags
+ opts := "rw"
+ if flags.ReadOnly {
+ opts = "ro"
+ }
+ fmt.Fprintf(buf, "%s %s %s %s %d %d\n", "none", mountPath, root.Inode.MountSource.FilesystemType, opts, 0, 0)
+ })
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks_net.go b/pkg/sentry/fsimpl/proc/task_net.go
index d4e1812d8..373a7b17d 100644
--- a/pkg/sentry/fsimpl/proc/tasks_net.go
+++ b/pkg/sentry/fsimpl/proc/task_net.go
@@ -37,12 +37,13 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-func newNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry {
+func newTaskNetDir(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry {
+ k := task.Kernel()
+ pidns := task.PIDNamespace()
+ root := auth.NewRootCredentials(pidns.UserNamespace())
+
var contents map[string]*kernfs.Dentry
- // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
- // network namespace of the calling process. We should make this per-process,
- // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net.
- if stack := k.RootNetworkNamespace().Stack(); stack != nil {
+ if stack := task.NetworkNamespace().Stack(); stack != nil {
const (
arp = "IP address HW type Flags HW address Mask Device\n"
netlink = "sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n"
@@ -53,6 +54,8 @@ func newNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *k
)
psched := fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond))
+ // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task
+ // network namespace.
contents = map[string]*kernfs.Dentry{
"dev": newDentry(root, inoGen.NextIno(), 0444, &netDevData{stack: stack}),
"snmp": newDentry(root, inoGen.NextIno(), 0444, &netSnmpData{stack: stack}),
@@ -84,7 +87,7 @@ func newNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *k
}
}
- return kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, contents)
+ return newTaskOwnedDir(task, inoGen.NextIno(), 0555, contents)
}
// ifinet6 implements vfs.DynamicBytesSource for /proc/net/if_inet6.
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index 10c08fa90..9f2ef8200 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -46,6 +46,7 @@ type tasksInode struct {
kernfs.InodeDirectoryNoNewChildren
kernfs.InodeAttrs
kernfs.OrderedChildren
+ kernfs.AlwaysValid
inoGen InoGenerator
pidns *kernel.PIDNamespace
@@ -66,23 +67,23 @@ var _ kernfs.Inode = (*tasksInode)(nil)
func newTasksInode(inoGen InoGenerator, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) (*tasksInode, *kernfs.Dentry) {
root := auth.NewRootCredentials(pidns.UserNamespace())
contents := map[string]*kernfs.Dentry{
- "cpuinfo": newDentry(root, inoGen.NextIno(), 0444, newStaticFile(cpuInfoData(k))),
- //"filesystems": newDentry(root, inoGen.NextIno(), 0444, &filesystemsData{}),
- "loadavg": newDentry(root, inoGen.NextIno(), 0444, &loadavgData{}),
- "sys": newSysDir(root, inoGen, k),
- "meminfo": newDentry(root, inoGen.NextIno(), 0444, &meminfoData{}),
- "mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/mounts"),
- "net": newNetDir(root, inoGen, k),
- "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{k: k}),
- "uptime": newDentry(root, inoGen.NextIno(), 0444, &uptimeData{}),
- "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{k: k}),
+ "cpuinfo": newDentry(root, inoGen.NextIno(), 0444, newStaticFileSetStat(cpuInfoData(k))),
+ "filesystems": newDentry(root, inoGen.NextIno(), 0444, &filesystemsData{}),
+ "loadavg": newDentry(root, inoGen.NextIno(), 0444, &loadavgData{}),
+ "sys": newSysDir(root, inoGen, k),
+ "meminfo": newDentry(root, inoGen.NextIno(), 0444, &meminfoData{}),
+ "mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/mounts"),
+ "net": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/net"),
+ "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{}),
+ "uptime": newDentry(root, inoGen.NextIno(), 0444, &uptimeData{}),
+ "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{}),
}
inode := &tasksInode{
pidns: pidns,
inoGen: inoGen,
- selfSymlink: newSelfSymlink(root, inoGen.NextIno(), 0444, pidns).VFSDentry(),
- threadSelfSymlink: newThreadSelfSymlink(root, inoGen.NextIno(), 0444, pidns).VFSDentry(),
+ selfSymlink: newSelfSymlink(root, inoGen.NextIno(), pidns).VFSDentry(),
+ threadSelfSymlink: newThreadSelfSymlink(root, inoGen.NextIno(), pidns).VFSDentry(),
cgroupControllers: cgroupControllers,
}
inode.InodeAttrs.Init(root, inoGen.NextIno(), linux.ModeDirectory|0555)
@@ -121,11 +122,6 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, erro
return taskDentry.VFSDentry(), nil
}
-// Valid implements kernfs.inodeDynamicLookup.
-func (i *tasksInode) Valid(ctx context.Context) bool {
- return true
-}
-
// IterDirents implements kernfs.inodeDynamicLookup.
func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) {
// fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256
@@ -211,17 +207,36 @@ func (i *tasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.Open
return fd.VFSFileDescription(), nil
}
-func (i *tasksInode) Stat(vsfs *vfs.Filesystem) linux.Statx {
- stat := i.InodeAttrs.Stat(vsfs)
+func (i *tasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ stat, err := i.InodeAttrs.Stat(vsfs, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
- // Add dynamic children to link count.
- for _, tg := range i.pidns.ThreadGroups() {
- if leader := tg.Leader(); leader != nil {
- stat.Nlink++
+ if opts.Mask&linux.STATX_NLINK != 0 {
+ // Add dynamic children to link count.
+ for _, tg := range i.pidns.ThreadGroups() {
+ if leader := tg.Leader(); leader != nil {
+ stat.Nlink++
+ }
}
}
- return stat
+ return stat, nil
+}
+
+// staticFileSetStat implements a special static file that allows inode
+// attributes to be set. This is to support /proc files that are readonly, but
+// allow attributes to be set.
+type staticFileSetStat struct {
+ dynamicBytesFileSetAttr
+ vfs.StaticData
+}
+
+var _ dynamicInode = (*staticFileSetStat)(nil)
+
+func newStaticFileSetStat(data string) *staticFileSetStat {
+ return &staticFileSetStat{StaticData: vfs.StaticData{Data: data}}
}
func cpuInfoData(k *kernel.Kernel) string {
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
index 434998910..882c1981e 100644
--- a/pkg/sentry/fsimpl/proc/tasks_files.go
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -40,9 +41,9 @@ type selfSymlink struct {
var _ kernfs.Inode = (*selfSymlink)(nil)
-func newSelfSymlink(creds *auth.Credentials, ino uint64, perm linux.FileMode, pidns *kernel.PIDNamespace) *kernfs.Dentry {
+func newSelfSymlink(creds *auth.Credentials, ino uint64, pidns *kernel.PIDNamespace) *kernfs.Dentry {
inode := &selfSymlink{pidns: pidns}
- inode.Init(creds, ino, linux.ModeSymlink|perm)
+ inode.Init(creds, ino, linux.ModeSymlink|0777)
d := &kernfs.Dentry{}
d.Init(inode)
@@ -62,6 +63,11 @@ func (s *selfSymlink) Readlink(ctx context.Context) (string, error) {
return strconv.FormatUint(uint64(tgid), 10), nil
}
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*selfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
type threadSelfSymlink struct {
kernfs.InodeAttrs
kernfs.InodeNoopRefCount
@@ -72,9 +78,9 @@ type threadSelfSymlink struct {
var _ kernfs.Inode = (*threadSelfSymlink)(nil)
-func newThreadSelfSymlink(creds *auth.Credentials, ino uint64, perm linux.FileMode, pidns *kernel.PIDNamespace) *kernfs.Dentry {
+func newThreadSelfSymlink(creds *auth.Credentials, ino uint64, pidns *kernel.PIDNamespace) *kernfs.Dentry {
inode := &threadSelfSymlink{pidns: pidns}
- inode.Init(creds, ino, linux.ModeSymlink|perm)
+ inode.Init(creds, ino, linux.ModeSymlink|0777)
d := &kernfs.Dentry{}
d.Init(inode)
@@ -95,6 +101,23 @@ func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) {
return fmt.Sprintf("%d/task/%d", tgid, tid), nil
}
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*threadSelfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// dynamicBytesFileSetAttr implements a special file that allows inode
+// attributes to be set. This is to support /proc files that are readonly, but
+// allow attributes to be set.
+type dynamicBytesFileSetAttr struct {
+ kernfs.DynamicBytesFile
+}
+
+// SetStat implements Inode.SetStat.
+func (d *dynamicBytesFileSetAttr) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ return d.DynamicBytesFile.InodeAttrs.SetStat(ctx, fs, creds, opts)
+}
+
// cpuStats contains the breakdown of CPU time for /proc/stat.
type cpuStats struct {
// user is time spent in userspace tasks with non-positive niceness.
@@ -137,22 +160,20 @@ func (c cpuStats) String() string {
//
// +stateify savable
type statData struct {
- kernfs.DynamicBytesFile
-
- // k is the owning Kernel.
- k *kernel.Kernel
+ dynamicBytesFileSetAttr
}
var _ dynamicInode = (*statData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
-func (s *statData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+func (*statData) Generate(ctx context.Context, buf *bytes.Buffer) error {
// TODO(b/37226836): We currently export only zero CPU stats. We could
// at least provide some aggregate stats.
var cpu cpuStats
fmt.Fprintf(buf, "cpu %s\n", cpu)
- for c, max := uint(0), s.k.ApplicationCores(); c < max; c++ {
+ k := kernel.KernelFromContext(ctx)
+ for c, max := uint(0), k.ApplicationCores(); c < max; c++ {
fmt.Fprintf(buf, "cpu%d %s\n", c, cpu)
}
@@ -176,7 +197,7 @@ func (s *statData) Generate(ctx context.Context, buf *bytes.Buffer) error {
fmt.Fprintf(buf, "ctxt 0\n")
// CLOCK_REALTIME timestamp from boot, in seconds.
- fmt.Fprintf(buf, "btime %d\n", s.k.Timekeeper().BootTime().Seconds())
+ fmt.Fprintf(buf, "btime %d\n", k.Timekeeper().BootTime().Seconds())
// Total number of clones.
// TODO(b/37226836): Count this.
@@ -203,13 +224,13 @@ func (s *statData) Generate(ctx context.Context, buf *bytes.Buffer) error {
//
// +stateify savable
type loadavgData struct {
- kernfs.DynamicBytesFile
+ dynamicBytesFileSetAttr
}
var _ dynamicInode = (*loadavgData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
-func (d *loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+func (*loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error {
// TODO(b/62345059): Include real data in fields.
// Column 1-3: CPU and IO utilization of the last 1, 5, and 10 minute periods.
// Column 4-5: currently running processes and the total number of processes.
@@ -222,17 +243,15 @@ func (d *loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error {
//
// +stateify savable
type meminfoData struct {
- kernfs.DynamicBytesFile
-
- // k is the owning Kernel.
- k *kernel.Kernel
+ dynamicBytesFileSetAttr
}
var _ dynamicInode = (*meminfoData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
-func (d *meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- mf := d.k.MemoryFile()
+func (*meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ k := kernel.KernelFromContext(ctx)
+ mf := k.MemoryFile()
mf.UpdateUsage()
snapshot, totalUsage := usage.MemoryAccounting.Copy()
totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
@@ -275,7 +294,7 @@ func (d *meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
//
// +stateify savable
type uptimeData struct {
- kernfs.DynamicBytesFile
+ dynamicBytesFileSetAttr
}
var _ dynamicInode = (*uptimeData)(nil)
@@ -294,17 +313,15 @@ func (*uptimeData) Generate(ctx context.Context, buf *bytes.Buffer) error {
//
// +stateify savable
type versionData struct {
- kernfs.DynamicBytesFile
-
- // k is the owning Kernel.
- k *kernel.Kernel
+ dynamicBytesFileSetAttr
}
var _ dynamicInode = (*versionData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
-func (v *versionData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- init := v.k.GlobalInit()
+func (*versionData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ k := kernel.KernelFromContext(ctx)
+ init := k.GlobalInit()
if init == nil {
// Attempted to read before the init Task is created. This can
// only occur during startup, which should never need to read
@@ -335,3 +352,19 @@ func (v *versionData) Generate(ctx context.Context, buf *bytes.Buffer) error {
fmt.Fprintf(buf, "%s version %s %s\n", ver.Sysname, ver.Release, ver.Version)
return nil
}
+
+// filesystemsData backs /proc/filesystems.
+//
+// +stateify savable
+type filesystemsData struct {
+ kernfs.DynamicBytesFile
+}
+
+var _ dynamicInode = (*filesystemsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *filesystemsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ k := kernel.KernelFromContext(ctx)
+ k.VFS().GenerateProcFilesystems(buf)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
index c5d531fe0..d0f97c137 100644
--- a/pkg/sentry/fsimpl/proc/tasks_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -47,10 +48,11 @@ var (
var (
tasksStaticFiles = map[string]testutil.DirentType{
"cpuinfo": linux.DT_REG,
+ "filesystems": linux.DT_REG,
"loadavg": linux.DT_REG,
"meminfo": linux.DT_REG,
"mounts": linux.DT_LNK,
- "net": linux.DT_DIR,
+ "net": linux.DT_LNK,
"self": linux.DT_LNK,
"stat": linux.DT_REG,
"sys": linux.DT_DIR,
@@ -63,21 +65,29 @@ var (
"thread-self": threadSelfLink.NextOff,
}
taskStaticFiles = map[string]testutil.DirentType{
- "auxv": linux.DT_REG,
- "cgroup": linux.DT_REG,
- "cmdline": linux.DT_REG,
- "comm": linux.DT_REG,
- "environ": linux.DT_REG,
- "gid_map": linux.DT_REG,
- "io": linux.DT_REG,
- "maps": linux.DT_REG,
- "ns": linux.DT_DIR,
- "smaps": linux.DT_REG,
- "stat": linux.DT_REG,
- "statm": linux.DT_REG,
- "status": linux.DT_REG,
- "task": linux.DT_DIR,
- "uid_map": linux.DT_REG,
+ "auxv": linux.DT_REG,
+ "cgroup": linux.DT_REG,
+ "cmdline": linux.DT_REG,
+ "comm": linux.DT_REG,
+ "environ": linux.DT_REG,
+ "exe": linux.DT_LNK,
+ "fd": linux.DT_DIR,
+ "fdinfo": linux.DT_DIR,
+ "gid_map": linux.DT_REG,
+ "io": linux.DT_REG,
+ "maps": linux.DT_REG,
+ "mountinfo": linux.DT_REG,
+ "mounts": linux.DT_REG,
+ "net": linux.DT_DIR,
+ "ns": linux.DT_DIR,
+ "oom_score": linux.DT_REG,
+ "oom_score_adj": linux.DT_REG,
+ "smaps": linux.DT_REG,
+ "stat": linux.DT_REG,
+ "statm": linux.DT_REG,
+ "status": linux.DT_REG,
+ "task": linux.DT_DIR,
+ "uid_map": linux.DT_REG,
}
)
@@ -93,17 +103,37 @@ func setup(t *testing.T) *testutil.System {
k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
- fsOpts := vfs.GetFilesystemOptions{
- InternalData: &InternalData{
- Cgroups: map[string]string{
- "cpuset": "/foo/cpuset",
- "memory": "/foo/memory",
+
+ mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", tmpfs.Name, &vfs.GetFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("NewMountNamespace(): %v", err)
+ }
+ pop := &vfs.PathOperation{
+ Root: mntns.Root(),
+ Start: mntns.Root(),
+ Path: fspath.Parse("/proc"),
+ }
+ if err := k.VFS().MkdirAt(ctx, creds, pop, &vfs.MkdirOptions{Mode: 0777}); err != nil {
+ t.Fatalf("MkDir(/proc): %v", err)
+ }
+
+ pop = &vfs.PathOperation{
+ Root: mntns.Root(),
+ Start: mntns.Root(),
+ Path: fspath.Parse("/proc"),
+ }
+ mntOpts := &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ InternalData: &InternalData{
+ Cgroups: map[string]string{
+ "cpuset": "/foo/cpuset",
+ "memory": "/foo/memory",
+ },
},
},
}
- mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", Name, &fsOpts)
- if err != nil {
- t.Fatalf("NewMountNamespace(): %v", err)
+ if err := k.VFS().MountAt(ctx, creds, "", pop, Name, mntOpts); err != nil {
+ t.Fatalf("MountAt(/proc): %v", err)
}
return testutil.NewSystem(ctx, t, k.VFS(), mntns)
}
@@ -112,7 +142,7 @@ func TestTasksEmpty(t *testing.T) {
s := setup(t)
defer s.Destroy()
- collector := s.ListDirents(s.PathOpAtRoot("/"))
+ collector := s.ListDirents(s.PathOpAtRoot("/proc"))
s.AssertAllDirentTypes(collector, tasksStaticFiles)
s.AssertDirentOffsets(collector, tasksStaticFilesNextOffs)
}
@@ -138,7 +168,7 @@ func TestTasks(t *testing.T) {
expectedDirents[fmt.Sprintf("%d", i+1)] = linux.DT_DIR
}
- collector := s.ListDirents(s.PathOpAtRoot("/"))
+ collector := s.ListDirents(s.PathOpAtRoot("/proc"))
s.AssertAllDirentTypes(collector, expectedDirents)
s.AssertDirentOffsets(collector, tasksStaticFilesNextOffs)
@@ -178,7 +208,7 @@ func TestTasks(t *testing.T) {
}
// Test lookup.
- for _, path := range []string{"/1", "/2"} {
+ for _, path := range []string{"/proc/1", "/proc/2"} {
fd, err := s.VFS.OpenAt(
s.Ctx,
s.Creds,
@@ -188,6 +218,7 @@ func TestTasks(t *testing.T) {
if err != nil {
t.Fatalf("vfsfs.OpenAt(%q) failed: %v", path, err)
}
+ defer fd.DecRef()
buf := make([]byte, 1)
bufIOSeq := usermem.BytesIOSequence(buf)
if _, err := fd.Read(s.Ctx, bufIOSeq, vfs.ReadOptions{}); err != syserror.EISDIR {
@@ -198,10 +229,10 @@ func TestTasks(t *testing.T) {
if _, err := s.VFS.OpenAt(
s.Ctx,
s.Creds,
- s.PathOpAtRoot("/9999"),
+ s.PathOpAtRoot("/proc/9999"),
&vfs.OpenOptions{},
); err != syserror.ENOENT {
- t.Fatalf("wrong error from vfsfs.OpenAt(/9999): %v", err)
+ t.Fatalf("wrong error from vfsfs.OpenAt(/proc/9999): %v", err)
}
}
@@ -299,12 +330,13 @@ func TestTasksOffset(t *testing.T) {
fd, err := s.VFS.OpenAt(
s.Ctx,
s.Creds,
- s.PathOpAtRoot("/"),
+ s.PathOpAtRoot("/proc"),
&vfs.OpenOptions{},
)
if err != nil {
t.Fatalf("vfsfs.OpenAt(/) failed: %v", err)
}
+ defer fd.DecRef()
if _, err := fd.Seek(s.Ctx, tc.offset, linux.SEEK_SET); err != nil {
t.Fatalf("Seek(%d, SEEK_SET): %v", tc.offset, err)
}
@@ -341,7 +373,7 @@ func TestTask(t *testing.T) {
t.Fatalf("CreateTask(): %v", err)
}
- collector := s.ListDirents(s.PathOpAtRoot("/1"))
+ collector := s.ListDirents(s.PathOpAtRoot("/proc/1"))
s.AssertAllDirentTypes(collector, taskStaticFiles)
}
@@ -359,14 +391,14 @@ func TestProcSelf(t *testing.T) {
collector := s.WithTemporaryContext(task).ListDirents(&vfs.PathOperation{
Root: s.Root,
Start: s.Root,
- Path: fspath.Parse("/self/"),
+ Path: fspath.Parse("/proc/self/"),
FollowFinalSymlink: true,
})
s.AssertAllDirentTypes(collector, taskStaticFiles)
}
func iterateDir(ctx context.Context, t *testing.T, s *testutil.System, fd *vfs.FileDescription) {
- t.Logf("Iterating: /proc%s", fd.MappedName(ctx))
+ t.Logf("Iterating: %s", fd.MappedName(ctx))
var collector testutil.DirentCollector
if err := fd.IterDirents(ctx, &collector); err != nil {
@@ -409,6 +441,7 @@ func iterateDir(ctx context.Context, t *testing.T, s *testutil.System, fd *vfs.F
t.Errorf("vfsfs.OpenAt(%v) failed: %v", childPath, err)
continue
}
+ defer child.DecRef()
stat, err := child.Stat(ctx, vfs.StatOptions{})
if err != nil {
t.Errorf("Stat(%v) failed: %v", childPath, err)
@@ -429,6 +462,22 @@ func TestTree(t *testing.T) {
defer s.Destroy()
k := kernel.KernelFromContext(s.Ctx)
+
+ pop := &vfs.PathOperation{
+ Root: s.Root,
+ Start: s.Root,
+ Path: fspath.Parse("test-file"),
+ }
+ opts := &vfs.OpenOptions{
+ Flags: linux.O_RDONLY | linux.O_CREAT,
+ Mode: 0777,
+ }
+ file, err := s.VFS.OpenAt(s.Ctx, s.Creds, pop, opts)
+ if err != nil {
+ t.Fatalf("failed to create test file: %v", err)
+ }
+ defer file.DecRef()
+
var tasks []*kernel.Task
for i := 0; i < 5; i++ {
tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
@@ -436,6 +485,8 @@ func TestTree(t *testing.T) {
if err != nil {
t.Fatalf("CreateTask(): %v", err)
}
+ // Add file to populate /proc/[pid]/fd and fdinfo directories.
+ task.FDTable().NewFDVFS2(task, 0, file, kernel.FDFlags{})
tasks = append(tasks, task)
}
@@ -443,11 +494,12 @@ func TestTree(t *testing.T) {
fd, err := s.VFS.OpenAt(
ctx,
auth.CredentialsFromContext(s.Ctx),
- &vfs.PathOperation{Root: s.Root, Start: s.Root, Path: fspath.Parse("/")},
+ &vfs.PathOperation{Root: s.Root, Start: s.Root, Path: fspath.Parse("/proc")},
&vfs.OpenOptions{},
)
if err != nil {
- t.Fatalf("vfsfs.OpenAt(/) failed: %v", err)
+ t.Fatalf("vfsfs.OpenAt(/proc) failed: %v", err)
}
iterateDir(ctx, t, s, fd)
+ fd.DecRef()
}
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index c36c4fa11..7abfd62f2 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -94,15 +94,17 @@ func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, conte
return &d.dentry
}
-// SetStat implements kernfs.Inode.SetStat.
-func (d *dir) SetStat(fs *vfs.Filesystem, opts vfs.SetStatOptions) error {
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*dir) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
return syserror.EPERM
}
// Open implements kernfs.Inode.Open.
func (d *dir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &kernfs.GenericDirectoryFD{}
- fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, &opts)
+ if err := fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, &opts); err != nil {
+ return nil, err
+ }
return fd.VFSFileDescription(), nil
}
diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD
index e4f36f4ae..0e4053a46 100644
--- a/pkg/sentry/fsimpl/testutil/BUILD
+++ b/pkg/sentry/fsimpl/testutil/BUILD
@@ -16,12 +16,14 @@ go_library(
"//pkg/cpuid",
"//pkg/fspath",
"//pkg/memutil",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/fsimpl/tmpfs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/sched",
"//pkg/sentry/limits",
"//pkg/sentry/loader",
+ "//pkg/sentry/mm",
"//pkg/sentry/pgalloc",
"//pkg/sentry/platform",
"//pkg/sentry/platform/kvm",
diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go
index 488478e29..c16a36cdb 100644
--- a/pkg/sentry/fsimpl/testutil/kernel.go
+++ b/pkg/sentry/fsimpl/testutil/kernel.go
@@ -23,13 +23,16 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/memutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/loader"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/time"
@@ -123,10 +126,17 @@ func Boot() (*kernel.Kernel, error) {
// CreateTask creates a new bare bones task for tests.
func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns *vfs.MountNamespace, root, cwd vfs.VirtualDentry) (*kernel.Task, error) {
k := kernel.KernelFromContext(ctx)
+ exe, err := newFakeExecutable(ctx, k.VFS(), auth.CredentialsFromContext(ctx), root)
+ if err != nil {
+ return nil, err
+ }
+ m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation)
+ m.SetExecutable(fsbridge.NewVFSFile(exe))
+
config := &kernel.TaskConfig{
Kernel: k,
ThreadGroup: tc,
- TaskContext: &kernel.TaskContext{Name: name},
+ TaskContext: &kernel.TaskContext{Name: name, MemoryManager: m},
Credentials: auth.CredentialsFromContext(ctx),
NetworkNamespace: k.RootNetworkNamespace(),
AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()),
@@ -135,10 +145,25 @@ func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns
AbstractSocketNamespace: kernel.NewAbstractSocketNamespace(),
MountNamespaceVFS2: mntns,
FSContext: kernel.NewFSContextVFS2(root, cwd, 0022),
+ FDTable: k.NewFDTable(),
}
return k.TaskSet().NewTask(config)
}
+func newFakeExecutable(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, root vfs.VirtualDentry) (*vfs.FileDescription, error) {
+ const name = "executable"
+ pop := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(name),
+ }
+ opts := &vfs.OpenOptions{
+ Flags: linux.O_RDONLY | linux.O_CREAT,
+ Mode: 0777,
+ }
+ return vfsObj.OpenAt(ctx, creds, pop, opts)
+}
+
func createMemoryFile() (*pgalloc.MemoryFile, error) {
const memfileName = "test-memory"
memfd, err := memutil.CreateMemFD(memfileName, 0)
diff --git a/pkg/sentry/fsimpl/tmpfs/device_file.go b/pkg/sentry/fsimpl/tmpfs/device_file.go
index 84b181b90..83bf885ee 100644
--- a/pkg/sentry/fsimpl/tmpfs/device_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/device_file.go
@@ -15,6 +15,8 @@
package tmpfs
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -33,6 +35,14 @@ func (fs *filesystem) newDeviceFile(creds *auth.Credentials, mode linux.FileMode
major: major,
minor: minor,
}
+ switch kind {
+ case vfs.BlockDevice:
+ mode |= linux.S_IFBLK
+ case vfs.CharDevice:
+ mode |= linux.S_IFCHR
+ default:
+ panic(fmt.Sprintf("invalid DeviceKind: %v", kind))
+ }
file.inode.init(file, fs, creds, mode)
file.inode.nlink = 1 // from parent directory
return &file.inode
diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go
index b4380af38..37c75ab64 100644
--- a/pkg/sentry/fsimpl/tmpfs/directory.go
+++ b/pkg/sentry/fsimpl/tmpfs/directory.go
@@ -34,16 +34,11 @@ type directory struct {
func (fs *filesystem) newDirectory(creds *auth.Credentials, mode linux.FileMode) *inode {
dir := &directory{}
- dir.inode.init(dir, fs, creds, mode)
+ dir.inode.init(dir, fs, creds, linux.S_IFDIR|mode)
dir.inode.nlink = 2 // from "." and parent directory or ".." for root
return &dir.inode
}
-func (i *inode) isDir() bool {
- _, ok := i.impl.(*directory)
- return ok
-}
-
type directoryFD struct {
fileDescription
vfs.DirectoryFileDescriptionDefaultImpl
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index e1b551422..12cc64385 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -40,7 +41,7 @@ func stepLocked(rp *vfs.ResolvingPath, d *dentry) (*dentry, error) {
if !d.inode.isDir() {
return nil, syserror.ENOTDIR
}
- if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
afterSymlink:
@@ -124,7 +125,7 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa
if err != nil {
return err
}
- if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
name := rp.Component()
@@ -154,6 +155,17 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa
return create(parent, name)
}
+// AccessAt implements vfs.Filesystem.Impl.AccessAt.
+func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(rp)
+ if err != nil {
+ return err
+ }
+ return d.inode.checkPermissions(creds, ats)
+}
+
// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
fs.mu.RLock()
@@ -166,7 +178,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
if !d.inode.isDir() {
return nil, syserror.ENOTDIR
}
- if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true /* isDir */); err != nil {
+ if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
}
@@ -289,7 +301,7 @@ afterTrailingSymlink:
return nil, err
}
// Check for search permission in the parent directory.
- if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
// Reject attempts to open directories with O_CREAT.
@@ -304,7 +316,7 @@ afterTrailingSymlink:
child, err := stepLocked(rp, parent)
if err == syserror.ENOENT {
// Already checked for searchability above; now check for writability.
- if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil {
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
return nil, err
}
if err := rp.Mount().CheckBeginWrite(); err != nil {
@@ -335,7 +347,7 @@ afterTrailingSymlink:
func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions, afterCreate bool) (*vfs.FileDescription, error) {
ats := vfs.AccessTypesForOpenFlags(opts)
if !afterCreate {
- if err := d.inode.checkPermissions(rp.Credentials(), ats, d.inode.isDir()); err != nil {
+ if err := d.inode.checkPermissions(rp.Credentials(), ats); err != nil {
return nil, err
}
}
@@ -416,7 +428,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
defer mnt.EndWrite()
oldParent := oldParentVD.Dentry().Impl().(*dentry)
- if err := oldParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ if err := oldParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
// Call vfs.Dentry.Child() instead of stepLocked() or rp.ResolveChild(),
@@ -433,7 +445,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
if oldParent != newParent {
// Writability is needed to change renamed's "..".
- if err := renamed.inode.checkPermissions(rp.Credentials(), vfs.MayWrite, true /* isDir */); err != nil {
+ if err := renamed.inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
return err
}
}
@@ -443,7 +455,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
}
- if err := newParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ if err := newParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
replacedVFSD := newParent.vfsd.Child(newName)
@@ -516,7 +528,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
if err != nil {
return err
}
- if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
name := rp.Component()
@@ -563,7 +575,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
if err != nil {
return err
}
- return d.inode.setStat(opts.Stat)
+ return d.inode.setStat(ctx, rp.Credentials(), &opts.Stat)
}
// StatAt implements vfs.FilesystemImpl.StatAt.
@@ -609,7 +621,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
if err != nil {
return err
}
- if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
name := rp.Component()
diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
index 0c57fdca3..2c5c739df 100644
--- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go
+++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
@@ -34,7 +34,7 @@ type namedPipe struct {
// * rp.Mount().CheckBeginWrite() has been called successfully.
func (fs *filesystem) newNamedPipe(creds *auth.Credentials, mode linux.FileMode) *inode {
file := &namedPipe{pipe: pipe.NewVFSPipe(pipe.DefaultPipeSize, usermem.PageSize)}
- file.inode.init(file, fs, creds, mode)
+ file.inode.init(file, fs, creds, linux.S_IFIFO|mode)
file.inode.nlink = 1 // Only the parent has a link.
return &file.inode
}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index 711442424..26cd65605 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -89,7 +89,7 @@ func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMod
file := &regularFile{
memFile: fs.memFile,
}
- file.inode.init(file, fs, creds, mode)
+ file.inode.init(file, fs, creds, linux.S_IFREG|mode)
file.inode.nlink = 1 // from parent directory
return &file.inode
}
@@ -308,11 +308,18 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
return 0, nil
}
f := fd.inode().impl.(*regularFile)
- end := offset + srclen
- if end < offset {
+ if end := offset + srclen; end < offset {
// Overflow.
return 0, syserror.EFBIG
}
+
+ var err error
+ srclen, err = vfs.CheckLimit(ctx, offset, srclen)
+ if err != nil {
+ return 0, err
+ }
+ src = src.TakeFirst64(srclen)
+
f.inode.mu.Lock()
rw := getRegularFileReadWriter(f, offset)
n, err := src.CopyInTo(ctx, rw)
diff --git a/pkg/sentry/fsimpl/tmpfs/symlink.go b/pkg/sentry/fsimpl/tmpfs/symlink.go
index 5246aca84..47e075ed4 100644
--- a/pkg/sentry/fsimpl/tmpfs/symlink.go
+++ b/pkg/sentry/fsimpl/tmpfs/symlink.go
@@ -15,6 +15,7 @@
package tmpfs
import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
@@ -27,7 +28,7 @@ func (fs *filesystem) newSymlink(creds *auth.Credentials, target string) *inode
link := &symlink{
target: target,
}
- link.inode.init(link, fs, creds, 0777)
+ link.inode.init(link, fs, creds, linux.S_IFLNK|0777)
link.inode.nlink = 1 // from parent directory
return &link.inode
}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 521206305..2f9e6c876 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -144,7 +144,7 @@ type inode struct {
// Inode metadata. Writing multiple fields atomically requires holding
// mu, othewise atomic operations can be used.
mu sync.Mutex
- mode uint32 // excluding file type bits, which are based on impl
+ mode uint32 // file type and mode
nlink uint32 // protected by filesystem.mu instead of inode.mu
uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
gid uint32 // auth.KGID, but ...
@@ -168,6 +168,9 @@ type inode struct {
const maxLinks = math.MaxUint32
func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode linux.FileMode) {
+ if mode.FileType() == 0 {
+ panic("file type is required in FileMode")
+ }
i.clock = fs.clock
i.refs = 1
i.mode = uint32(mode)
@@ -242,8 +245,9 @@ func (i *inode) decRef() {
}
}
-func (i *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes, isDir bool) error {
- return vfs.GenericCheckPermissions(creds, ats, isDir, uint16(atomic.LoadUint32(&i.mode)), auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid)))
+func (i *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+ mode := linux.FileMode(atomic.LoadUint32(&i.mode))
+ return vfs.GenericCheckPermissions(creds, ats, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid)))
}
// Go won't inline this function, and returning linux.Statx (which is quite
@@ -269,40 +273,37 @@ func (i *inode) statTo(stat *linux.Statx) {
// TODO(gvisor.dev/issues/1197): Device number.
switch impl := i.impl.(type) {
case *regularFile:
- stat.Mode |= linux.S_IFREG
stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS
stat.Size = uint64(atomic.LoadUint64(&impl.size))
// In tmpfs, this will be FileRangeSet.Span() / 512 (but also cached in
// a uint64 accessed using atomic memory operations to avoid taking
// locks).
stat.Blocks = allocatedBlocksForSize(stat.Size)
- case *directory:
- stat.Mode |= linux.S_IFDIR
case *symlink:
- stat.Mode |= linux.S_IFLNK
stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS
stat.Size = uint64(len(impl.target))
stat.Blocks = allocatedBlocksForSize(stat.Size)
- case *namedPipe:
- stat.Mode |= linux.S_IFIFO
case *deviceFile:
- switch impl.kind {
- case vfs.BlockDevice:
- stat.Mode |= linux.S_IFBLK
- case vfs.CharDevice:
- stat.Mode |= linux.S_IFCHR
- }
stat.RdevMajor = impl.major
stat.RdevMinor = impl.minor
+ case *directory, *namedPipe:
+ // Nothing to do.
default:
panic(fmt.Sprintf("unknown inode type: %T", i.impl))
}
}
-func (i *inode) setStat(stat linux.Statx) error {
+func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx) error {
if stat.Mask == 0 {
return nil
}
+ if stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME|linux.STATX_SIZE) != 0 {
+ return syserror.EPERM
+ }
+ mode := linux.FileMode(atomic.LoadUint32(&i.mode))
+ if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil {
+ return err
+ }
i.mu.Lock()
var (
needsMtimeBump bool
@@ -310,7 +311,8 @@ func (i *inode) setStat(stat linux.Statx) error {
)
mask := stat.Mask
if mask&linux.STATX_MODE != 0 {
- atomic.StoreUint32(&i.mode, uint32(stat.Mode))
+ ft := atomic.LoadUint32(&i.mode) & linux.S_IFMT
+ atomic.StoreUint32(&i.mode, ft|uint32(stat.Mode&^linux.S_IFMT))
needsCtimeBump = true
}
if mask&linux.STATX_UID != 0 {
@@ -433,6 +435,10 @@ func (i *inode) direntType() uint8 {
}
}
+func (i *inode) isDir() bool {
+ return linux.FileMode(i.mode).FileType() == linux.S_IFDIR
+}
+
// fileDescription is embedded by tmpfs implementations of
// vfs.FileDescriptionImpl.
type fileDescription struct {
@@ -457,5 +463,6 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
// SetStat implements vfs.FileDescriptionImpl.SetStat.
func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
- return fd.inode().setStat(opts.Stat)
+ creds := auth.CredentialsFromContext(ctx)
+ return fd.inode().setStat(ctx, creds, &opts.Stat)
}
diff --git a/pkg/sentry/inet/namespace.go b/pkg/sentry/inet/namespace.go
index c16667e7f..029af3025 100644
--- a/pkg/sentry/inet/namespace.go
+++ b/pkg/sentry/inet/namespace.go
@@ -23,7 +23,10 @@ type Namespace struct {
// creator allows kernel to create new network stack for network namespaces.
// If nil, no networking will function if network is namespaced.
- creator NetworkStackCreator
+ //
+ // At afterLoad(), creator will be used to create network stack. Stateify
+ // needs to wait for this field to be loaded before calling afterLoad().
+ creator NetworkStackCreator `state:"wait"`
// isRoot indicates whether this is the root network namespace.
isRoot bool
diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go
index 8bffb78fc..592650923 100644
--- a/pkg/sentry/kernel/epoll/epoll.go
+++ b/pkg/sentry/kernel/epoll/epoll.go
@@ -296,8 +296,10 @@ func (*readyCallback) Callback(w *waiter.Entry) {
e.waitingList.Remove(entry)
e.readyList.PushBack(entry)
entry.curList = &e.readyList
+ e.listsMu.Unlock()
e.Notify(waiter.EventIn)
+ return
}
e.listsMu.Unlock()
diff --git a/pkg/sentry/kernel/epoll/epoll_state.go b/pkg/sentry/kernel/epoll/epoll_state.go
index a0d35d350..8e9f200d0 100644
--- a/pkg/sentry/kernel/epoll/epoll_state.go
+++ b/pkg/sentry/kernel/epoll/epoll_state.go
@@ -38,11 +38,14 @@ func (e *EventPoll) afterLoad() {
}
}
- for it := e.waitingList.Front(); it != nil; it = it.Next() {
- if it.id.File.Readiness(it.mask) != 0 {
- e.waitingList.Remove(it)
- e.readyList.PushBack(it)
- it.curList = &e.readyList
+ for it := e.waitingList.Front(); it != nil; {
+ entry := it
+ it = it.Next()
+
+ if entry.id.File.Readiness(entry.mask) != 0 {
+ e.waitingList.Remove(entry)
+ e.readyList.PushBack(entry)
+ entry.curList = &e.readyList
e.Notify(waiter.EventIn)
}
}
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 58001d56c..d09d97825 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -191,10 +191,12 @@ func (f *FDTable) Size() int {
return int(size)
}
-// forEach iterates over all non-nil files.
+// forEach iterates over all non-nil files in sorted order.
//
// It is the caller's responsibility to acquire an appropriate lock.
func (f *FDTable) forEach(fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags)) {
+ // retries tracks the number of failed TryIncRef attempts for the same FD.
+ retries := 0
fd := int32(0)
for {
file, fileVFS2, flags, ok := f.getAll(fd)
@@ -204,17 +206,26 @@ func (f *FDTable) forEach(fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDes
switch {
case file != nil:
if !file.TryIncRef() {
+ retries++
+ if retries > 1000 {
+ panic(fmt.Sprintf("File in FD table has been destroyed. FD: %d, File: %+v, FileOps: %+v", fd, file, file.FileOperations))
+ }
continue // Race caught.
}
fn(fd, file, nil, flags)
file.DecRef()
case fileVFS2 != nil:
if !fileVFS2.TryIncRef() {
+ retries++
+ if retries > 1000 {
+ panic(fmt.Sprintf("File in FD table has been destroyed. FD: %d, File: %+v, Impl: %+v", fd, fileVFS2, fileVFS2.Impl()))
+ }
continue // Race caught.
}
fn(fd, nil, fileVFS2, flags)
fileVFS2.DecRef()
}
+ retries = 0
fd++
}
}
@@ -327,7 +338,7 @@ func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDesc
fd = f.next
}
for fd < end {
- if d, _, _ := f.get(fd); d == nil {
+ if d, _, _ := f.getVFS2(fd); d == nil {
f.setVFS2(fd, file, flags)
if fd == f.next {
// Update next search start position.
@@ -447,7 +458,10 @@ func (f *FDTable) GetVFS2(fd int32) (*vfs.FileDescription, FDFlags) {
}
}
-// GetFDs returns a list of valid fds.
+// GetFDs returns a sorted list of valid fds.
+//
+// Precondition: The caller must be running on the task goroutine, or Task.mu
+// must be locked.
func (f *FDTable) GetFDs() []int32 {
fds := make([]int32, 0, int(atomic.LoadInt32(&f.used)))
f.forEach(func(fd int32, _ *fs.File, _ *vfs.FileDescription, _ FDFlags) {
@@ -522,7 +536,9 @@ func (f *FDTable) Remove(fd int32) (*fs.File, *vfs.FileDescription) {
case orig2 != nil:
orig2.IncRef()
}
- f.setAll(fd, nil, nil, FDFlags{}) // Zap entry.
+ if orig != nil || orig2 != nil {
+ f.setAll(fd, nil, nil, FDFlags{}) // Zap entry.
+ }
return orig, orig2
}
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 8b76750e9..6feda8fa1 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -467,6 +467,11 @@ func (k *Kernel) flushMountSourceRefs() error {
//
// Precondition: Must be called with the kernel paused.
func (ts *TaskSet) forEachFDPaused(f func(*fs.File, *vfs.FileDescription) error) (err error) {
+ // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
+ if VFS2Enabled {
+ return nil
+ }
+
ts.mu.RLock()
defer ts.mu.RUnlock()
for t := range ts.Root.tids {
@@ -484,7 +489,7 @@ func (ts *TaskSet) forEachFDPaused(f func(*fs.File, *vfs.FileDescription) error)
}
func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error {
- // TODO(gvisor.dev/issues/1663): Add save support for VFS2.
+ // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
return ts.forEachFDPaused(func(file *fs.File, _ *vfs.FileDescription) error {
if flags := file.Flags(); !flags.Write {
return nil
@@ -533,6 +538,11 @@ func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error {
}
func (ts *TaskSet) unregisterEpollWaiters() {
+ // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
+ if VFS2Enabled {
+ return
+ }
+
ts.mu.RLock()
defer ts.mu.RUnlock()
for t := range ts.Root.tids {
@@ -755,6 +765,8 @@ func (ctx *createProcessContext) Value(key interface{}) interface{} {
return ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
case fs.CtxDirentCacheLimiter:
return ctx.k.DirentCacheLimiter
+ case inet.CtxStack:
+ return ctx.k.RootNetworkNamespace().Stack()
case ktime.CtxRealtimeClock:
return ctx.k.RealtimeClock()
case limits.CtxLimits:
@@ -1003,11 +1015,14 @@ func (k *Kernel) pauseTimeLocked() {
// This means we'll iterate FDTables shared by multiple tasks repeatedly,
// but ktime.Timer.Pause is idempotent so this is harmless.
if t.fdTable != nil {
- t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
- if tfd, ok := file.FileOperations.(*timerfd.TimerOperations); ok {
- tfd.PauseTimer()
- }
- })
+ // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
+ if !VFS2Enabled {
+ t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
+ if tfd, ok := file.FileOperations.(*timerfd.TimerOperations); ok {
+ tfd.PauseTimer()
+ }
+ })
+ }
}
}
k.timekeeper.PauseUpdates()
@@ -1032,12 +1047,15 @@ func (k *Kernel) resumeTimeLocked() {
it.ResumeTimer()
}
}
- if t.fdTable != nil {
- t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
- if tfd, ok := file.FileOperations.(*timerfd.TimerOperations); ok {
- tfd.ResumeTimer()
- }
- })
+ // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
+ if !VFS2Enabled {
+ if t.fdTable != nil {
+ t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
+ if tfd, ok := file.FileOperations.(*timerfd.TimerOperations); ok {
+ tfd.ResumeTimer()
+ }
+ })
+ }
}
}
}
@@ -1481,6 +1499,8 @@ func (ctx supervisorContext) Value(key interface{}) interface{} {
return ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
case fs.CtxDirentCacheLimiter:
return ctx.k.DirentCacheLimiter
+ case inet.CtxStack:
+ return ctx.k.RootNetworkNamespace().Stack()
case ktime.CtxRealtimeClock:
return ctx.k.RealtimeClock()
case limits.CtxLimits:
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index 4c049d5b4..f29dc0472 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -1,25 +1,10 @@
load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
-go_template_instance(
- name = "buffer_list",
- out = "buffer_list.go",
- package = "pipe",
- prefix = "buffer",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*buffer",
- "Linker": "*buffer",
- },
-)
-
go_library(
name = "pipe",
srcs = [
- "buffer.go",
- "buffer_list.go",
"device.go",
"node.go",
"pipe.go",
@@ -33,8 +18,8 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/amutex",
+ "//pkg/buffer",
"//pkg/context",
- "//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
@@ -51,7 +36,6 @@ go_test(
name = "pipe_test",
size = "small",
srcs = [
- "buffer_test.go",
"node_test.go",
"pipe_test.go",
],
diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go
deleted file mode 100644
index fe3be5dbd..000000000
--- a/pkg/sentry/kernel/pipe/buffer.go
+++ /dev/null
@@ -1,115 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package pipe
-
-import (
- "io"
-
- "gvisor.dev/gvisor/pkg/safemem"
- "gvisor.dev/gvisor/pkg/sync"
-)
-
-// buffer encapsulates a queueable byte buffer.
-//
-// Note that the total size is slightly less than two pages. This
-// is done intentionally to ensure that the buffer object aligns
-// with runtime internals. We have no hard size or alignment
-// requirements. This two page size will effectively minimize
-// internal fragmentation, but still have a large enough chunk
-// to limit excessive segmentation.
-//
-// +stateify savable
-type buffer struct {
- data [8144]byte
- read int
- write int
- bufferEntry
-}
-
-// Reset resets internal data.
-//
-// This must be called before use.
-func (b *buffer) Reset() {
- b.read = 0
- b.write = 0
-}
-
-// Empty indicates the buffer is empty.
-//
-// This indicates there is no data left to read.
-func (b *buffer) Empty() bool {
- return b.read == b.write
-}
-
-// Full indicates the buffer is full.
-//
-// This indicates there is no capacity left to write.
-func (b *buffer) Full() bool {
- return b.write == len(b.data)
-}
-
-// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
-func (b *buffer) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
- dst := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.write:]))
- n, err := safemem.CopySeq(dst, srcs)
- b.write += int(n)
- return n, err
-}
-
-// WriteFromReader writes to the buffer from an io.Reader.
-func (b *buffer) WriteFromReader(r io.Reader, count int64) (int64, error) {
- dst := b.data[b.write:]
- if count < int64(len(dst)) {
- dst = b.data[b.write:][:count]
- }
- n, err := r.Read(dst)
- b.write += n
- return int64(n), err
-}
-
-// ReadToBlocks implements safemem.Reader.ReadToBlocks.
-func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
- src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.read:b.write]))
- n, err := safemem.CopySeq(dsts, src)
- b.read += int(n)
- return n, err
-}
-
-// ReadToWriter reads from the buffer into an io.Writer.
-func (b *buffer) ReadToWriter(w io.Writer, count int64, dup bool) (int64, error) {
- src := b.data[b.read:b.write]
- if count < int64(len(src)) {
- src = b.data[b.read:][:count]
- }
- n, err := w.Write(src)
- if !dup {
- b.read += n
- }
- return int64(n), err
-}
-
-// bufferPool is a pool for buffers.
-var bufferPool = sync.Pool{
- New: func() interface{} {
- return new(buffer)
- },
-}
-
-// newBuffer grabs a new buffer from the pool.
-func newBuffer() *buffer {
- b := bufferPool.Get().(*buffer)
- b.Reset()
- return b
-}
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 08410283f..725e9db7d 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -20,6 +20,7 @@ import (
"sync/atomic"
"syscall"
+ "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sync"
@@ -70,10 +71,10 @@ type Pipe struct {
// mu protects all pipe internal state below.
mu sync.Mutex `state:"nosave"`
- // data is the buffer queue of pipe contents.
+ // view is the underlying set of buffers.
//
// This is protected by mu.
- data bufferList
+ view buffer.View
// max is the maximum size of the pipe in bytes. When this max has been
// reached, writers will get EWOULDBLOCK.
@@ -81,11 +82,6 @@ type Pipe struct {
// This is protected by mu.
max int64
- // size is the current size of the pipe in bytes.
- //
- // This is protected by mu.
- size int64
-
// hadWriter indicates if this pipe ever had a writer. Note that this
// does not necessarily indicate there is *currently* a writer, just
// that there has been a writer at some point since the pipe was
@@ -196,7 +192,7 @@ type readOps struct {
limit func(int64)
// read performs the actual read operation.
- read func(*buffer) (int64, error)
+ read func(*buffer.View) (int64, error)
}
// read reads data from the pipe into dst and returns the number of bytes
@@ -213,7 +209,7 @@ func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) {
defer p.mu.Unlock()
// Is the pipe empty?
- if p.size == 0 {
+ if p.view.Size() == 0 {
if !p.HasWriters() {
// There are no writers, return EOF.
return 0, nil
@@ -222,71 +218,13 @@ func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) {
}
// Limit how much we consume.
- if ops.left() > p.size {
- ops.limit(p.size)
+ if ops.left() > p.view.Size() {
+ ops.limit(p.view.Size())
}
- done := int64(0)
- for ops.left() > 0 {
- // Pop the first buffer.
- first := p.data.Front()
- if first == nil {
- break
- }
-
- // Copy user data.
- n, err := ops.read(first)
- done += int64(n)
- p.size -= n
-
- // Empty buffer?
- if first.Empty() {
- // Push to the free list.
- p.data.Remove(first)
- bufferPool.Put(first)
- }
-
- // Handle errors.
- if err != nil {
- return done, err
- }
- }
-
- return done, nil
-}
-
-// dup duplicates all data from this pipe into the given writer.
-//
-// There is no blocking behavior implemented here. The writer may propagate
-// some blocking error. All the writes must be complete writes.
-func (p *Pipe) dup(ctx context.Context, ops readOps) (int64, error) {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- // Is the pipe empty?
- if p.size == 0 {
- if !p.HasWriters() {
- // See above.
- return 0, nil
- }
- return 0, syserror.ErrWouldBlock
- }
-
- // Limit how much we consume.
- if ops.left() > p.size {
- ops.limit(p.size)
- }
-
- done := int64(0)
- for buf := p.data.Front(); buf != nil; buf = buf.Next() {
- n, err := ops.read(buf)
- done += n
- if err != nil {
- return done, err
- }
- }
-
- return done, nil
+ // Copy user data; the read op is responsible for trimming.
+ done, err := ops.read(&p.view)
+ return done, err
}
type writeOps struct {
@@ -297,7 +235,7 @@ type writeOps struct {
limit func(int64)
// write should write to the provided buffer.
- write func(*buffer) (int64, error)
+ write func(*buffer.View) (int64, error)
}
// write writes data from sv into the pipe and returns the number of bytes
@@ -317,33 +255,19 @@ func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) {
// POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be
// atomic, but requires no atomicity for writes larger than this.
wanted := ops.left()
- if avail := p.max - p.size; wanted > avail {
+ if avail := p.max - p.view.Size(); wanted > avail {
if wanted <= p.atomicIOBytes {
return 0, syserror.ErrWouldBlock
}
ops.limit(avail)
}
- done := int64(0)
- for ops.left() > 0 {
- // Need a new buffer?
- last := p.data.Back()
- if last == nil || last.Full() {
- // Add a new buffer to the data list.
- last = newBuffer()
- p.data.PushBack(last)
- }
-
- // Copy user data.
- n, err := ops.write(last)
- done += int64(n)
- p.size += n
-
- // Handle errors.
- if err != nil {
- return done, err
- }
+ // Copy user data.
+ done, err := ops.write(&p.view)
+ if err != nil {
+ return done, err
}
+
if wanted > done {
// Partial write due to full pipe.
return done, syserror.ErrWouldBlock
@@ -396,7 +320,7 @@ func (p *Pipe) HasWriters() bool {
// Precondition: mu must be held.
func (p *Pipe) rReadinessLocked() waiter.EventMask {
ready := waiter.EventMask(0)
- if p.HasReaders() && p.data.Front() != nil {
+ if p.HasReaders() && p.view.Size() != 0 {
ready |= waiter.EventIn
}
if !p.HasWriters() && p.hadWriter {
@@ -422,7 +346,7 @@ func (p *Pipe) rReadiness() waiter.EventMask {
// Precondition: mu must be held.
func (p *Pipe) wReadinessLocked() waiter.EventMask {
ready := waiter.EventMask(0)
- if p.HasWriters() && p.size < p.max {
+ if p.HasWriters() && p.view.Size() < p.max {
ready |= waiter.EventOut
}
if !p.HasReaders() {
@@ -451,7 +375,7 @@ func (p *Pipe) rwReadiness() waiter.EventMask {
func (p *Pipe) queued() int64 {
p.mu.Lock()
defer p.mu.Unlock()
- return p.size
+ return p.view.Size()
}
// FifoSize implements fs.FifoSizer.FifoSize.
@@ -474,7 +398,7 @@ func (p *Pipe) SetFifoSize(size int64) (int64, error) {
}
p.mu.Lock()
defer p.mu.Unlock()
- if size < p.size {
+ if size < p.view.Size() {
return 0, syserror.EBUSY
}
p.max = size
diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go
index 80158239e..5a1d4fd57 100644
--- a/pkg/sentry/kernel/pipe/pipe_util.go
+++ b/pkg/sentry/kernel/pipe/pipe_util.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/amutex"
+ "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sync"
@@ -49,9 +50,10 @@ func (p *Pipe) Read(ctx context.Context, dst usermem.IOSequence) (int64, error)
limit: func(l int64) {
dst = dst.TakeFirst64(l)
},
- read: func(buf *buffer) (int64, error) {
- n, err := dst.CopyOutFrom(ctx, buf)
+ read: func(view *buffer.View) (int64, error) {
+ n, err := dst.CopyOutFrom(ctx, view)
dst = dst.DropFirst64(n)
+ view.TrimFront(n)
return n, err
},
})
@@ -70,16 +72,15 @@ func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool)
limit: func(l int64) {
count = l
},
- read: func(buf *buffer) (int64, error) {
- n, err := buf.ReadToWriter(w, count, dup)
+ read: func(view *buffer.View) (int64, error) {
+ n, err := view.ReadToWriter(w, count)
+ if !dup {
+ view.TrimFront(n)
+ }
count -= n
return n, err
},
}
- if dup {
- // There is no notification for dup operations.
- return p.dup(ctx, ops)
- }
n, err := p.read(ctx, ops)
if n > 0 {
p.Notify(waiter.EventOut)
@@ -96,8 +97,8 @@ func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error)
limit: func(l int64) {
src = src.TakeFirst64(l)
},
- write: func(buf *buffer) (int64, error) {
- n, err := src.CopyInTo(ctx, buf)
+ write: func(view *buffer.View) (int64, error) {
+ n, err := src.CopyInTo(ctx, view)
src = src.DropFirst64(n)
return n, err
},
@@ -117,8 +118,8 @@ func (p *Pipe) ReadFrom(ctx context.Context, r io.Reader, count int64) (int64, e
limit: func(l int64) {
count = l
},
- write: func(buf *buffer) (int64, error) {
- n, err := buf.WriteFromReader(r, count)
+ write: func(view *buffer.View) (int64, error) {
+ n, err := view.WriteFromReader(r, count)
count -= n
return n, err
},
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index 1000f3287..c00fa1138 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -554,6 +554,7 @@ func (s *sem) wakeWaiters() {
for w := s.waiters.Front(); w != nil; {
if s.value < w.value {
// Still blocked, skip it.
+ w = w.Next()
continue
}
w.ch <- struct{}{}
diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go
index 047b5214d..0e19286de 100644
--- a/pkg/sentry/kernel/sessions.go
+++ b/pkg/sentry/kernel/sessions.go
@@ -246,7 +246,7 @@ func (pg *ProcessGroup) SendSignal(info *arch.SignalInfo) error {
var lastErr error
for tg := range tasks.Root.tgids {
- if tg.ProcessGroup() == pg {
+ if tg.processGroup == pg {
tg.signalHandlers.mu.Lock()
infoCopy := *info
if err := tg.leader.sendSignalLocked(&infoCopy, true /*group*/); err != nil {
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index 2cee2e6ed..8452ddf5b 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -37,6 +37,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -847,3 +848,18 @@ func (t *Task) AbstractSockets() *AbstractSocketNamespace {
func (t *Task) ContainerID() string {
return t.containerID
}
+
+// OOMScoreAdj gets the task's thread group's OOM score adjustment.
+func (t *Task) OOMScoreAdj() int32 {
+ return atomic.LoadInt32(&t.tg.oomScoreAdj)
+}
+
+// SetOOMScoreAdj sets the task's thread group's OOM score adjustment. The
+// value should be between -1000 and 1000 inclusive.
+func (t *Task) SetOOMScoreAdj(adj int32) error {
+ if adj > 1000 || adj < -1000 {
+ return syserror.EINVAL
+ }
+ atomic.StoreInt32(&t.tg.oomScoreAdj, adj)
+ return nil
+}
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index 78866f280..e1ecca99e 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -15,6 +15,8 @@
package kernel
import (
+ "sync/atomic"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
"gvisor.dev/gvisor/pkg/sentry/inet"
@@ -260,6 +262,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
sh = sh.Fork()
}
tg = t.k.NewThreadGroup(tg.mounts, pidns, sh, opts.TerminationSignal, tg.limits.GetCopy())
+ tg.oomScoreAdj = atomic.LoadInt32(&t.tg.oomScoreAdj)
rseqAddr = t.rseqAddr
rseqSignature = t.rseqSignature
}
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index 5568c91bc..799cbcd93 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -126,13 +126,39 @@ func (t *Task) doStop() {
}
}
+func (*runApp) handleCPUIDInstruction(t *Task) error {
+ if len(arch.CPUIDInstruction) == 0 {
+ // CPUID emulation isn't supported, but this code can be
+ // executed, because the ptrace platform returns
+ // ErrContextSignalCPUID on page faults too. Look at
+ // pkg/sentry/platform/ptrace/ptrace.go:context.Switch for more
+ // details.
+ return platform.ErrContextSignal
+ }
+ // Is this a CPUID instruction?
+ region := trace.StartRegion(t.traceContext, cpuidRegion)
+ expected := arch.CPUIDInstruction[:]
+ found := make([]byte, len(expected))
+ _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found)
+ if err == nil && bytes.Equal(expected, found) {
+ // Skip the cpuid instruction.
+ t.Arch().CPUIDEmulate(t)
+ t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected)))
+ region.End()
+
+ return nil
+ }
+ region.End() // Not an actual CPUID, but required copy-in.
+ return platform.ErrContextSignal
+}
+
// The runApp state checks for interrupts before executing untrusted
// application code.
//
// +stateify savable
type runApp struct{}
-func (*runApp) execute(t *Task) taskRunState {
+func (app *runApp) execute(t *Task) taskRunState {
if t.interrupted() {
// Checkpointing instructs tasks to stop by sending an interrupt, so we
// must check for stops before entering runInterrupt (instead of
@@ -237,21 +263,10 @@ func (*runApp) execute(t *Task) taskRunState {
return (*runApp)(nil)
case platform.ErrContextSignalCPUID:
- // Is this a CPUID instruction?
- region := trace.StartRegion(t.traceContext, cpuidRegion)
- expected := arch.CPUIDInstruction[:]
- found := make([]byte, len(expected))
- _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found)
- if err == nil && bytes.Equal(expected, found) {
- // Skip the cpuid instruction.
- t.Arch().CPUIDEmulate(t)
- t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected)))
- region.End()
-
+ if err := app.handleCPUIDInstruction(t); err == nil {
// Resume execution.
return (*runApp)(nil)
}
- region.End() // Not an actual CPUID, but required copy-in.
// The instruction at the given RIP was not a CPUID, and we
// fallthrough to the default signal deliver behavior below.
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 268f62e9d..52849f5b3 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -254,6 +254,13 @@ type ThreadGroup struct {
//
// tty is protected by the signal mutex.
tty *TTY
+
+ // oomScoreAdj is the thread group's OOM score adjustment. This is
+ // currently not used but is maintained for consistency.
+ // TODO(gvisor.dev/issue/1967)
+ //
+ // oomScoreAdj is accessed using atomic memory operations.
+ oomScoreAdj int32
}
// NewThreadGroup returns a new, empty thread group in PID namespace pidns. The
diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go
index 35cd55fef..4b23f7803 100644
--- a/pkg/sentry/platform/kvm/bluepill.go
+++ b/pkg/sentry/platform/kvm/bluepill.go
@@ -81,12 +81,6 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) {
// Save the death message, which will be thrown.
c.dieState.message = msg
- // Reload all registers to have an accurate stack trace when we return
- // to host mode. This means that the stack should be unwound correctly.
- if errno := c.getUserRegisters(&c.dieState.guestRegs); errno != 0 {
- throw(msg)
- }
-
// Setup the trampoline.
dieArchSetup(c, context, &c.dieState.guestRegs)
}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
index a63a6a071..99cac665d 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -31,6 +31,12 @@ import (
//
//go:nosplit
func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) {
+ // Reload all registers to have an accurate stack trace when we return
+ // to host mode. This means that the stack should be unwound correctly.
+ if errno := c.getUserRegisters(&c.dieState.guestRegs); errno != 0 {
+ throw(c.dieState.message)
+ }
+
// If the vCPU is in user mode, we set the stack to the stored stack
// value in the vCPU itself. We don't want to unwind the user stack.
if guestRegs.RFLAGS&ring0.UserFlagsSet == ring0.UserFlagsSet {
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s
index c61700892..04efa0147 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.s
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.s
@@ -82,6 +82,8 @@ fallback:
// dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation.
TEXT ·dieTrampoline(SB),NOSPLIT,$0
- // TODO(gvisor.dev/issue/1249): dieTrampoline supporting for Arm64.
- MOVD R9, 8(RSP)
- BL ·dieHandler(SB)
+ // R0: Fake the old PC as caller
+ // R1: First argument (vCPU)
+ MOVD.P R1, 8(RSP) // R1: First argument (vCPU)
+ MOVD.P R0, 8(RSP) // R0: Fake the old PC as caller
+ B ·dieHandler(SB)
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
index 2f02c03cf..eb5ed574e 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
@@ -17,10 +17,33 @@
package kvm
import (
+ "unsafe"
+
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
)
+// dieArchSetup initialies the state for dieTrampoline.
+//
+// The arm64 dieTrampoline requires the vCPU to be set in R1, and the last PC
+// to be in R0. The trampoline then simulates a call to dieHandler from the
+// provided PC.
+//
//go:nosplit
func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) {
- // TODO(gvisor.dev/issue/1249): dieTrampoline supporting for Arm64.
+ // If the vCPU is in user mode, we set the stack to the stored stack
+ // value in the vCPU itself. We don't want to unwind the user stack.
+ if guestRegs.Regs.Pstate&ring0.PSR_MODE_MASK == ring0.PSR_MODE_EL0t {
+ regs := c.CPU.Registers()
+ context.Regs[0] = regs.Regs[0]
+ context.Sp = regs.Sp
+ context.Regs[29] = regs.Regs[29] // stack base address
+ } else {
+ context.Regs[0] = guestRegs.Regs.Pc
+ context.Sp = guestRegs.Regs.Sp
+ context.Regs[29] = guestRegs.Regs.Regs[29]
+ context.Pstate = guestRegs.Regs.Pstate
+ }
+ context.Regs[1] = uint64(uintptr(unsafe.Pointer(c)))
+ context.Pc = uint64(dieTrampolineAddr)
}
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
index 972ba85c3..a9b4af43e 100644
--- a/pkg/sentry/platform/kvm/kvm.go
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -27,6 +27,38 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// userMemoryRegion is a region of physical memory.
+//
+// This mirrors kvm_memory_region.
+type userMemoryRegion struct {
+ slot uint32
+ flags uint32
+ guestPhysAddr uint64
+ memorySize uint64
+ userspaceAddr uint64
+}
+
+// runData is the run structure. This may be mapped for synchronous register
+// access (although that doesn't appear to be supported by my kernel at least).
+//
+// This mirrors kvm_run.
+type runData struct {
+ requestInterruptWindow uint8
+ _ [7]uint8
+
+ exitReason uint32
+ readyForInterruptInjection uint8
+ ifFlag uint8
+ _ [2]uint8
+
+ cr8 uint64
+ apicBase uint64
+
+ // This is the union data for exits. Interpretation depends entirely on
+ // the exitReason above (see vCPU code for more information).
+ data [32]uint64
+}
+
// KVM represents a lightweight VM context.
type KVM struct {
platform.NoCPUPreemptionDetection
diff --git a/pkg/sentry/platform/kvm/kvm_amd64.go b/pkg/sentry/platform/kvm/kvm_amd64.go
index c5a6f9c7d..093497bc4 100644
--- a/pkg/sentry/platform/kvm/kvm_amd64.go
+++ b/pkg/sentry/platform/kvm/kvm_amd64.go
@@ -21,17 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
)
-// userMemoryRegion is a region of physical memory.
-//
-// This mirrors kvm_memory_region.
-type userMemoryRegion struct {
- slot uint32
- flags uint32
- guestPhysAddr uint64
- memorySize uint64
- userspaceAddr uint64
-}
-
// userRegs represents KVM user registers.
//
// This mirrors kvm_regs.
@@ -169,27 +158,6 @@ type modelControlRegisters struct {
entries [16]modelControlRegister
}
-// runData is the run structure. This may be mapped for synchronous register
-// access (although that doesn't appear to be supported by my kernel at least).
-//
-// This mirrors kvm_run.
-type runData struct {
- requestInterruptWindow uint8
- _ [7]uint8
-
- exitReason uint32
- readyForInterruptInjection uint8
- ifFlag uint8
- _ [2]uint8
-
- cr8 uint64
- apicBase uint64
-
- // This is the union data for exits. Interpretation depends entirely on
- // the exitReason above (see vCPU code for more information).
- data [32]uint64
-}
-
// cpuidEntry is a single CPUID entry.
//
// This mirrors kvm_cpuid_entry2.
diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go
index 2319c86d3..79045651e 100644
--- a/pkg/sentry/platform/kvm/kvm_arm64.go
+++ b/pkg/sentry/platform/kvm/kvm_arm64.go
@@ -20,17 +20,6 @@ import (
"syscall"
)
-// userMemoryRegion is a region of physical memory.
-//
-// This mirrors kvm_memory_region.
-type userMemoryRegion struct {
- slot uint32
- flags uint32
- guestPhysAddr uint64
- memorySize uint64
- userspaceAddr uint64
-}
-
type kvmOneReg struct {
id uint64
addr uint64
@@ -53,27 +42,6 @@ type userRegs struct {
fpRegs userFpsimdState
}
-// runData is the run structure. This may be mapped for synchronous register
-// access (although that doesn't appear to be supported by my kernel at least).
-//
-// This mirrors kvm_run.
-type runData struct {
- requestInterruptWindow uint8
- _ [7]uint8
-
- exitReason uint32
- readyForInterruptInjection uint8
- ifFlag uint8
- _ [2]uint8
-
- cr8 uint64
- apicBase uint64
-
- // This is the union data for exits. Interpretation depends entirely on
- // the exitReason above (see vCPU code for more information).
- data [32]uint64
-}
-
// updateGlobalOnce does global initialization. It has to be called only once.
func updateGlobalOnce(fd int) error {
physicalInit()
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 1c8384e6b..b531f2f85 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -29,30 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// setMemoryRegion initializes a region.
-//
-// This may be called from bluepillHandler, and therefore returns an errno
-// directly (instead of wrapping in an error) to avoid allocations.
-//
-//go:nosplit
-func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr) syscall.Errno {
- userRegion := userMemoryRegion{
- slot: uint32(slot),
- flags: 0,
- guestPhysAddr: uint64(physical),
- memorySize: uint64(length),
- userspaceAddr: uint64(virtual),
- }
-
- // Set the region.
- _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(m.fd),
- _KVM_SET_USER_MEMORY_REGION,
- uintptr(unsafe.Pointer(&userRegion)))
- return errno
-}
-
type kvmVcpuInit struct {
target uint32
features [7]uint32
diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD
index 95abd321e..30402c2df 100644
--- a/pkg/sentry/platform/ptrace/BUILD
+++ b/pkg/sentry/platform/ptrace/BUILD
@@ -9,6 +9,7 @@ go_library(
"ptrace.go",
"ptrace_amd64.go",
"ptrace_arm64.go",
+ "ptrace_arm64_unsafe.go",
"ptrace_unsafe.go",
"stub_amd64.s",
"stub_arm64.s",
diff --git a/pkg/sentry/platform/ptrace/ptrace_amd64.go b/pkg/sentry/platform/ptrace/ptrace_amd64.go
index db0212538..24fc5dc62 100644
--- a/pkg/sentry/platform/ptrace/ptrace_amd64.go
+++ b/pkg/sentry/platform/ptrace/ptrace_amd64.go
@@ -31,3 +31,17 @@ func fpRegSet(useXsave bool) uintptr {
func stackPointer(r *syscall.PtraceRegs) uintptr {
return uintptr(r.Rsp)
}
+
+// x86 use the fs_base register to store the TLS pointer which can be
+// get/set in "func (t *thread) get/setRegs(regs *syscall.PtraceRegs)".
+// So both of the get/setTLS() operations are noop here.
+
+// getTLS gets the thread local storage register.
+func (t *thread) getTLS(tls *uint64) error {
+ return nil
+}
+
+// setTLS sets the thread local storage register.
+func (t *thread) setTLS(tls *uint64) error {
+ return nil
+}
diff --git a/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go
new file mode 100644
index 000000000..32b8a6be9
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go
@@ -0,0 +1,62 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ptrace
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// getTLS gets the thread local storage register.
+func (t *thread) getTLS(tls *uint64) error {
+ iovec := syscall.Iovec{
+ Base: (*byte)(unsafe.Pointer(tls)),
+ Len: uint64(unsafe.Sizeof(*tls)),
+ }
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_GETREGSET,
+ uintptr(t.tid),
+ linux.NT_ARM_TLS,
+ uintptr(unsafe.Pointer(&iovec)),
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+// setTLS sets the thread local storage register.
+func (t *thread) setTLS(tls *uint64) error {
+ iovec := syscall.Iovec{
+ Base: (*byte)(unsafe.Pointer(tls)),
+ Len: uint64(unsafe.Sizeof(*tls)),
+ }
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_SETREGSET,
+ uintptr(t.tid),
+ linux.NT_ARM_TLS,
+ uintptr(unsafe.Pointer(&iovec)),
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 31b7cec53..a644609ef 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -506,6 +506,9 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
regs := &ac.StateData().Regs
t.resetSysemuRegs(regs)
+ // Extract TLS register
+ tls := uint64(ac.TLS())
+
// Check for interrupts, and ensure that future interrupts will signal t.
if !c.interrupt.Enable(t) {
// Pending interrupt; simulate.
@@ -526,6 +529,9 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
if err := t.setFPRegs(fpState, uint64(fpLen), useXsave); err != nil {
panic(fmt.Sprintf("ptrace set fpregs (%+v) failed: %v", fpState, err))
}
+ if err := t.setTLS(&tls); err != nil {
+ panic(fmt.Sprintf("ptrace set tls (%+v) failed: %v", tls, err))
+ }
for {
// Start running until the next system call.
@@ -555,6 +561,12 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
if err := t.getFPRegs(fpState, uint64(fpLen), useXsave); err != nil {
panic(fmt.Sprintf("ptrace get fpregs failed: %v", err))
}
+ if err := t.getTLS(&tls); err != nil {
+ panic(fmt.Sprintf("ptrace get tls failed: %v", err))
+ }
+ if !ac.SetTLS(uintptr(tls)) {
+ panic(fmt.Sprintf("tls value %v is invalid", tls))
+ }
// Is it a system call?
if sig == (syscallEvent | syscall.SIGTRAP) {
diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go
index e99798c56..cd74945e7 100644
--- a/pkg/sentry/platform/ptrace/subprocess_amd64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go
@@ -21,6 +21,7 @@ import (
"strings"
"syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/seccomp"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -183,13 +184,76 @@ func enableCpuidFault() {
// appendArchSeccompRules append architecture specific seccomp rules when creating BPF program.
// Ref attachedThread() for more detail.
-func appendArchSeccompRules(rules []seccomp.RuleSet) []seccomp.RuleSet {
- return append(rules, seccomp.RuleSet{
- Rules: seccomp.SyscallRules{
- syscall.SYS_ARCH_PRCTL: []seccomp.Rule{
- {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)},
+func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFAction) []seccomp.RuleSet {
+ rules = append(rules,
+ // Rules for trapping vsyscall access.
+ seccomp.RuleSet{
+ Rules: seccomp.SyscallRules{
+ syscall.SYS_GETTIMEOFDAY: {},
+ syscall.SYS_TIME: {},
+ unix.SYS_GETCPU: {}, // SYS_GETCPU was not defined in package syscall on amd64.
},
- },
- Action: linux.SECCOMP_RET_ALLOW,
- })
+ Action: linux.SECCOMP_RET_TRAP,
+ Vsyscall: true,
+ })
+ if defaultAction != linux.SECCOMP_RET_ALLOW {
+ rules = append(rules,
+ seccomp.RuleSet{
+ Rules: seccomp.SyscallRules{
+ syscall.SYS_ARCH_PRCTL: []seccomp.Rule{
+ {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)},
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ })
+ }
+ return rules
+}
+
+// probeSeccomp returns true iff seccomp is run after ptrace notifications,
+// which is generally the case for kernel version >= 4.8. This check is dynamic
+// because kernels have be backported behavior.
+//
+// See createStub for more information.
+//
+// Precondition: the runtime OS thread must be locked.
+func probeSeccomp() bool {
+ // Create a completely new, destroyable process.
+ t, err := attachedThread(0, linux.SECCOMP_RET_ERRNO)
+ if err != nil {
+ panic(fmt.Sprintf("seccomp probe failed: %v", err))
+ }
+ defer t.destroy()
+
+ // Set registers to the yield system call. This call is not allowed
+ // by the filters specified in the attachThread function.
+ regs := createSyscallRegs(&t.initRegs, syscall.SYS_SCHED_YIELD)
+ if err := t.setRegs(&regs); err != nil {
+ panic(fmt.Sprintf("ptrace set regs failed: %v", err))
+ }
+
+ for {
+ // Attempt an emulation.
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, unix.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
+ panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
+ }
+
+ sig := t.wait(stopped)
+ if sig == (syscallEvent | syscall.SIGTRAP) {
+ // Did the seccomp errno hook already run? This would
+ // indicate that seccomp is first in line and we're
+ // less than 4.8.
+ if err := t.getRegs(&regs); err != nil {
+ panic(fmt.Sprintf("ptrace get-regs failed: %v", err))
+ }
+ if _, err := syscallReturnValue(&regs); err == nil {
+ // The seccomp errno mode ran first, and reset
+ // the error in the registers.
+ return false
+ }
+ // The seccomp hook did not run yet, and therefore it
+ // is safe to use RET_KILL mode for dispatched calls.
+ return true
+ }
+ }
}
diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go
index 7b975137f..7f5c393f0 100644
--- a/pkg/sentry/platform/ptrace/subprocess_arm64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go
@@ -160,6 +160,15 @@ func enableCpuidFault() {
// appendArchSeccompRules append architecture specific seccomp rules when creating BPF program.
// Ref attachedThread() for more detail.
-func appendArchSeccompRules(rules []seccomp.RuleSet) []seccomp.RuleSet {
+func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFAction) []seccomp.RuleSet {
return rules
}
+
+// probeSeccomp returns true if seccomp is run after ptrace notifications,
+// which is generally the case for kernel version >= 4.8.
+//
+// On arm64, the support of PTRACE_SYSEMU was added in the 5.3 kernel, so
+// probeSeccomp can always return true.
+func probeSeccomp() bool {
+ return true
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
index 74968dfdf..2ce528601 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -20,7 +20,6 @@ import (
"fmt"
"syscall"
- "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/procid"
@@ -30,54 +29,6 @@ import (
const syscallEvent syscall.Signal = 0x80
-// probeSeccomp returns true iff seccomp is run after ptrace notifications,
-// which is generally the case for kernel version >= 4.8. This check is dynamic
-// because kernels have be backported behavior.
-//
-// See createStub for more information.
-//
-// Precondition: the runtime OS thread must be locked.
-func probeSeccomp() bool {
- // Create a completely new, destroyable process.
- t, err := attachedThread(0, linux.SECCOMP_RET_ERRNO)
- if err != nil {
- panic(fmt.Sprintf("seccomp probe failed: %v", err))
- }
- defer t.destroy()
-
- // Set registers to the yield system call. This call is not allowed
- // by the filters specified in the attachThread function.
- regs := createSyscallRegs(&t.initRegs, syscall.SYS_SCHED_YIELD)
- if err := t.setRegs(&regs); err != nil {
- panic(fmt.Sprintf("ptrace set regs failed: %v", err))
- }
-
- for {
- // Attempt an emulation.
- if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, unix.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
- panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
- }
-
- sig := t.wait(stopped)
- if sig == (syscallEvent | syscall.SIGTRAP) {
- // Did the seccomp errno hook already run? This would
- // indicate that seccomp is first in line and we're
- // less than 4.8.
- if err := t.getRegs(&regs); err != nil {
- panic(fmt.Sprintf("ptrace get-regs failed: %v", err))
- }
- if _, err := syscallReturnValue(&regs); err == nil {
- // The seccomp errno mode ran first, and reset
- // the error in the registers.
- return false
- }
- // The seccomp hook did not run yet, and therefore it
- // is safe to use RET_KILL mode for dispatched calls.
- return true
- }
- }
-}
-
// createStub creates a fresh stub processes.
//
// Precondition: the runtime OS thread must be locked.
@@ -123,18 +74,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
// stub and all its children. This is used to create child stubs
// (below), so we must include the ability to fork, but otherwise lock
// down available calls only to what is needed.
- rules := []seccomp.RuleSet{
- // Rules for trapping vsyscall access.
- {
- Rules: seccomp.SyscallRules{
- syscall.SYS_GETTIMEOFDAY: {},
- syscall.SYS_TIME: {},
- unix.SYS_GETCPU: {}, // SYS_GETCPU was not defined in package syscall on amd64.
- },
- Action: linux.SECCOMP_RET_TRAP,
- Vsyscall: true,
- },
- }
+ rules := []seccomp.RuleSet{}
if defaultAction != linux.SECCOMP_RET_ALLOW {
rules = append(rules, seccomp.RuleSet{
Rules: seccomp.SyscallRules{
@@ -173,9 +113,8 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
},
Action: linux.SECCOMP_RET_ALLOW,
})
-
- rules = appendArchSeccompRules(rules)
}
+ rules = appendArchSeccompRules(rules, defaultAction)
instrs, err := seccomp.BuildProgram(rules, defaultAction)
if err != nil {
return nil, err
diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go
index f6da41c27..8122ac6e2 100644
--- a/pkg/sentry/platform/ring0/aarch64.go
+++ b/pkg/sentry/platform/ring0/aarch64.go
@@ -27,26 +27,27 @@ const (
_PTE_PGT_BASE = 0x7000
_PTE_PGT_SIZE = 0x1000
- _PSR_MODE_EL0t = 0x0
- _PSR_MODE_EL1t = 0x4
- _PSR_MODE_EL1h = 0x5
- _PSR_EL_MASK = 0xf
-
- _PSR_D_BIT = 0x200
- _PSR_A_BIT = 0x100
- _PSR_I_BIT = 0x80
- _PSR_F_BIT = 0x40
+ _PSR_D_BIT = 0x00000200
+ _PSR_A_BIT = 0x00000100
+ _PSR_I_BIT = 0x00000080
+ _PSR_F_BIT = 0x00000040
)
const (
+ // PSR bits
+ PSR_MODE_EL0t = 0x00000000
+ PSR_MODE_EL1t = 0x00000004
+ PSR_MODE_EL1h = 0x00000005
+ PSR_MODE_MASK = 0x0000000f
+
// KernelFlagsSet should always be set in the kernel.
- KernelFlagsSet = _PSR_MODE_EL1h
+ KernelFlagsSet = PSR_MODE_EL1h
// UserFlagsSet are always set in userspace.
- UserFlagsSet = _PSR_MODE_EL0t
+ UserFlagsSet = PSR_MODE_EL0t
- KernelFlagsClear = _PSR_EL_MASK
- UserFlagsClear = _PSR_EL_MASK
+ KernelFlagsClear = PSR_MODE_MASK
+ UserFlagsClear = PSR_MODE_MASK
PsrDefaultSet = _PSR_D_BIT | _PSR_A_BIT | _PSR_I_BIT | _PSR_F_BIT
)
diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD
index 4f2406ce3..581841555 100644
--- a/pkg/sentry/platform/ring0/pagetables/BUILD
+++ b/pkg/sentry/platform/ring0/pagetables/BUILD
@@ -80,7 +80,7 @@ go_library(
"pagetables_amd64.go",
"pagetables_arm64.go",
"pagetables_x86.go",
- "pcids_x86.go",
+ "pcids.go",
"walker_amd64.go",
"walker_arm64.go",
"walker_empty.go",
diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go b/pkg/sentry/platform/ring0/pagetables/pcids.go
index e199bae18..9206030bf 100644
--- a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go
+++ b/pkg/sentry/platform/ring0/pagetables/pcids.go
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
-
package pagetables
import (
diff --git a/pkg/sentry/sighandling/sighandling.go b/pkg/sentry/sighandling/sighandling.go
index ba1f9043d..83195d5a1 100644
--- a/pkg/sentry/sighandling/sighandling.go
+++ b/pkg/sentry/sighandling/sighandling.go
@@ -85,6 +85,11 @@ func StartSignalForwarding(handler func(linux.Signal)) func() {
for sig := 1; sig <= numSignals+1; sig++ {
sigchan := make(chan os.Signal, 1)
sigchans = append(sigchans, sigchan)
+
+ // SIGURG is used by Go's runtime scheduler.
+ if sig == int(linux.SIGURG) {
+ continue
+ }
signal.Notify(sigchan, syscall.Signal(sig))
}
// Start up our listener.
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index 7cd2ce55b..e801abeb8 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -22,7 +22,6 @@ go_library(
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/usermem",
],
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
index b4b244abf..0336a32d8 100644
--- a/pkg/sentry/socket/netfilter/extensions.go
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -19,7 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -37,12 +37,12 @@ type matchMaker interface {
// name is the matcher name as stored in the xt_entry_match struct.
name() string
- // marshal converts from an iptables.Matcher to an ABI struct.
- marshal(matcher iptables.Matcher) []byte
+ // marshal converts from an stack.Matcher to an ABI struct.
+ marshal(matcher stack.Matcher) []byte
// unmarshal converts from the ABI matcher struct to an
- // iptables.Matcher.
- unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error)
+ // stack.Matcher.
+ unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error)
}
// matchMakers maps the name of supported matchers to the matchMaker that
@@ -58,7 +58,7 @@ func registerMatchMaker(mm matchMaker) {
matchMakers[mm.name()] = mm
}
-func marshalMatcher(matcher iptables.Matcher) []byte {
+func marshalMatcher(matcher stack.Matcher) []byte {
matchMaker, ok := matchMakers[matcher.Name()]
if !ok {
panic(fmt.Sprintf("Unknown matcher of type %T.", matcher))
@@ -86,7 +86,7 @@ func marshalEntryMatch(name string, data []byte) []byte {
return append(buf, make([]byte, size-len(buf))...)
}
-func unmarshalMatcher(match linux.XTEntryMatch, filter iptables.IPHeaderFilter, buf []byte) (iptables.Matcher, error) {
+func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) {
matchMaker, ok := matchMakers[match.Name.String()]
if !ok {
return nil, fmt.Errorf("unsupported matcher with name %q", match.Name.String())
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 2ec11f6ac..55bcc3ace 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -26,7 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -35,6 +35,11 @@ import (
// shouldn't be reached - an error has occurred if we fall through to one.
const errorTargetName = "ERROR"
+// redirectTargetName is used to mark targets as redirect targets. Redirect
+// targets should be reached for only NAT and Mangle tables. These targets will
+// change the destination port/destination IP for packets.
+const redirectTargetName = "REDIRECT"
+
// Metadata is used to verify that we are correctly serializing and
// deserializing iptables into structs consumable by the iptables tool. We save
// a metadata struct when the tables are written, and when they are read out we
@@ -123,19 +128,19 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen
return entries, nil
}
-func findTable(stack *stack.Stack, tablename linux.TableName) (iptables.Table, error) {
- ipt := stack.IPTables()
+func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) {
+ ipt := stk.IPTables()
table, ok := ipt.Tables[tablename.String()]
if !ok {
- return iptables.Table{}, fmt.Errorf("couldn't find table %q", tablename)
+ return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename)
}
return table, nil
}
// FillDefaultIPTables sets stack's IPTables to the default tables and
// populates them with metadata.
-func FillDefaultIPTables(stack *stack.Stack) {
- ipt := iptables.DefaultTables()
+func FillDefaultIPTables(stk *stack.Stack) {
+ ipt := stack.DefaultTables()
// In order to fill in the metadata, we have to translate ipt from its
// netstack format to Linux's giant-binary-blob format.
@@ -148,14 +153,14 @@ func FillDefaultIPTables(stack *stack.Stack) {
ipt.Tables[name] = table
}
- stack.SetIPTables(ipt)
+ stk.SetIPTables(ipt)
}
// convertNetstackToBinary converts the iptables as stored in netstack to the
// format expected by the iptables tool. Linux stores each table as a binary
// blob that can only be traversed by parsing a bit, reading some offsets,
// jumping to those offsets, parsing again, etc.
-func convertNetstackToBinary(tablename string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, error) {
+func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelIPTGetEntries, metadata, error) {
// Return values.
var entries linux.KernelIPTGetEntries
var meta metadata
@@ -228,18 +233,20 @@ func convertNetstackToBinary(tablename string, table iptables.Table) (linux.Kern
return entries, meta, nil
}
-func marshalTarget(target iptables.Target) []byte {
+func marshalTarget(target stack.Target) []byte {
switch tg := target.(type) {
- case iptables.AcceptTarget:
- return marshalStandardTarget(iptables.RuleAccept)
- case iptables.DropTarget:
- return marshalStandardTarget(iptables.RuleDrop)
- case iptables.ErrorTarget:
+ case stack.AcceptTarget:
+ return marshalStandardTarget(stack.RuleAccept)
+ case stack.DropTarget:
+ return marshalStandardTarget(stack.RuleDrop)
+ case stack.ErrorTarget:
return marshalErrorTarget(errorTargetName)
- case iptables.UserChainTarget:
+ case stack.UserChainTarget:
return marshalErrorTarget(tg.Name)
- case iptables.ReturnTarget:
- return marshalStandardTarget(iptables.RuleReturn)
+ case stack.ReturnTarget:
+ return marshalStandardTarget(stack.RuleReturn)
+ case stack.RedirectTarget:
+ return marshalRedirectTarget()
case JumpTarget:
return marshalJumpTarget(tg)
default:
@@ -247,7 +254,7 @@ func marshalTarget(target iptables.Target) []byte {
}
}
-func marshalStandardTarget(verdict iptables.RuleVerdict) []byte {
+func marshalStandardTarget(verdict stack.RuleVerdict) []byte {
nflog("convert to binary: marshalling standard target")
// The target's name will be the empty string.
@@ -276,6 +283,19 @@ func marshalErrorTarget(errorName string) []byte {
return binary.Marshal(ret, usermem.ByteOrder, target)
}
+func marshalRedirectTarget() []byte {
+ // This is a redirect target named redirect
+ target := linux.XTRedirectTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTRedirectTarget,
+ },
+ }
+ copy(target.Target.Name[:], redirectTargetName)
+
+ ret := make([]byte, 0, linux.SizeOfXTRedirectTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
func marshalJumpTarget(jt JumpTarget) []byte {
nflog("convert to binary: marshalling jump target")
@@ -295,13 +315,13 @@ func marshalJumpTarget(jt JumpTarget) []byte {
// translateFromStandardVerdict translates verdicts the same way as the iptables
// tool.
-func translateFromStandardVerdict(verdict iptables.RuleVerdict) int32 {
+func translateFromStandardVerdict(verdict stack.RuleVerdict) int32 {
switch verdict {
- case iptables.RuleAccept:
+ case stack.RuleAccept:
return -linux.NF_ACCEPT - 1
- case iptables.RuleDrop:
+ case stack.RuleDrop:
return -linux.NF_DROP - 1
- case iptables.RuleReturn:
+ case stack.RuleReturn:
return linux.NF_RETURN
default:
// TODO(gvisor.dev/issue/170): Support Jump.
@@ -310,18 +330,18 @@ func translateFromStandardVerdict(verdict iptables.RuleVerdict) int32 {
}
// translateToStandardTarget translates from the value in a
-// linux.XTStandardTarget to an iptables.Verdict.
-func translateToStandardTarget(val int32) (iptables.Target, error) {
+// linux.XTStandardTarget to an stack.Verdict.
+func translateToStandardTarget(val int32) (stack.Target, error) {
// TODO(gvisor.dev/issue/170): Support other verdicts.
switch val {
case -linux.NF_ACCEPT - 1:
- return iptables.AcceptTarget{}, nil
+ return stack.AcceptTarget{}, nil
case -linux.NF_DROP - 1:
- return iptables.DropTarget{}, nil
+ return stack.DropTarget{}, nil
case -linux.NF_QUEUE - 1:
return nil, errors.New("unsupported iptables verdict QUEUE")
case linux.NF_RETURN:
- return iptables.ReturnTarget{}, nil
+ return stack.ReturnTarget{}, nil
default:
return nil, fmt.Errorf("unknown iptables verdict %d", val)
}
@@ -329,7 +349,7 @@ func translateToStandardTarget(val int32) (iptables.Target, error) {
// SetEntries sets iptables rules for a single table. See
// net/ipv4/netfilter/ip_tables.c:translate_table for reference.
-func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
+func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// Get the basic rules data (struct ipt_replace).
if len(optVal) < linux.SizeOfIPTReplace {
nflog("optVal has insufficient size for replace %d", len(optVal))
@@ -341,10 +361,12 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace)
// TODO(gvisor.dev/issue/170): Support other tables.
- var table iptables.Table
+ var table stack.Table
switch replace.Name.String() {
- case iptables.TablenameFilter:
- table = iptables.EmptyFilterTable()
+ case stack.TablenameFilter:
+ table = stack.EmptyFilterTable()
+ case stack.TablenameNat:
+ table = stack.EmptyNatTable()
default:
nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
return syserr.ErrInvalidArgument
@@ -404,14 +426,14 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal))
return syserr.ErrInvalidArgument
}
- target, err := parseTarget(optVal[:targetSize])
+ target, err := parseTarget(filter, optVal[:targetSize])
if err != nil {
nflog("failed to parse target: %v", err)
return syserr.ErrInvalidArgument
}
optVal = optVal[targetSize:]
- table.Rules = append(table.Rules, iptables.Rule{
+ table.Rules = append(table.Rules, stack.Rule{
Filter: filter,
Target: target,
Matchers: matchers,
@@ -442,11 +464,11 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
table.Underflows[hk] = ruleIdx
}
}
- if ruleIdx := table.BuiltinChains[hk]; ruleIdx == iptables.HookUnset {
+ if ruleIdx := table.BuiltinChains[hk]; ruleIdx == stack.HookUnset {
nflog("hook %v is unset.", hk)
return syserr.ErrInvalidArgument
}
- if ruleIdx := table.Underflows[hk]; ruleIdx == iptables.HookUnset {
+ if ruleIdx := table.Underflows[hk]; ruleIdx == stack.HookUnset {
nflog("underflow %v is unset.", hk)
return syserr.ErrInvalidArgument
}
@@ -455,7 +477,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
// Add the user chains.
for ruleIdx, rule := range table.Rules {
- target, ok := rule.Target.(iptables.UserChainTarget)
+ target, ok := rule.Target.(stack.UserChainTarget)
if !ok {
continue
}
@@ -495,11 +517,12 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
}
// TODO(gvisor.dev/issue/170): Support other chains.
- // Since we only support modifying the INPUT chain right now, make sure
- // all other chains point to ACCEPT rules.
+ // Since we only support modifying the INPUT chain and redirect for
+ // PREROUTING chain right now, make sure all other chains point to
+ // ACCEPT rules.
for hook, ruleIdx := range table.BuiltinChains {
- if hook != iptables.Input {
- if _, ok := table.Rules[ruleIdx].Target.(iptables.AcceptTarget); !ok {
+ if hook != stack.Input && hook != stack.Prerouting {
+ if _, ok := table.Rules[ruleIdx].Target.(stack.AcceptTarget); !ok {
nflog("hook %d is unsupported.", hook)
return syserr.ErrInvalidArgument
}
@@ -511,7 +534,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
// - There are no chains without an unconditional final rule.
// - There are no chains without an unconditional underflow rule.
- ipt := stack.IPTables()
+ ipt := stk.IPTables()
table.SetMetadata(metadata{
HookEntry: replace.HookEntry,
Underflow: replace.Underflow,
@@ -519,16 +542,16 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
Size: replace.Size,
})
ipt.Tables[replace.Name.String()] = table
- stack.SetIPTables(ipt)
+ stk.SetIPTables(ipt)
return nil
}
// parseMatchers parses 0 or more matchers from optVal. optVal should contain
// only the matchers.
-func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Matcher, error) {
+func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, error) {
nflog("set entries: parsing matchers of size %d", len(optVal))
- var matchers []iptables.Matcher
+ var matchers []stack.Matcher
for len(optVal) > 0 {
nflog("set entries: optVal has len %d", len(optVal))
@@ -570,7 +593,7 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma
// parseTarget parses a target from optVal. optVal should contain only the
// target.
-func parseTarget(optVal []byte) (iptables.Target, error) {
+func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, error) {
nflog("set entries: parsing target of size %d", len(optVal))
if len(optVal) < linux.SizeOfXTEntryTarget {
return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal))
@@ -614,67 +637,125 @@ func parseTarget(optVal []byte) (iptables.Target, error) {
switch name := errorTarget.Name.String(); name {
case errorTargetName:
nflog("set entries: error target")
- return iptables.ErrorTarget{}, nil
+ return stack.ErrorTarget{}, nil
default:
// User defined chain.
nflog("set entries: user-defined target %q", name)
- return iptables.UserChainTarget{Name: name}, nil
+ return stack.UserChainTarget{Name: name}, nil
+ }
+
+ case redirectTargetName:
+ // Redirect target.
+ if len(optVal) < linux.SizeOfXTRedirectTarget {
+ return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal))
+ }
+
+ if filter.Protocol != header.TCPProtocolNumber && filter.Protocol != header.UDPProtocolNumber {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
}
+
+ var redirectTarget linux.XTRedirectTarget
+ buf = optVal[:linux.SizeOfXTRedirectTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget)
+
+ // Copy linux.XTRedirectTarget to stack.RedirectTarget.
+ var target stack.RedirectTarget
+ nfRange := redirectTarget.NfRange
+
+ // RangeSize should be 1.
+ if nfRange.RangeSize != 1 {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+
+ // TODO(gvisor.dev/issue/170): Check if the flags are valid.
+ // Also check if we need to map ports or IP.
+ // For now, redirect target only supports destination port change.
+ // Port range and IP range are not supported yet.
+ if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+ target.RangeProtoSpecified = true
+
+ target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
+ target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:])
+
+ // TODO(gvisor.dev/issue/170): Port range is not supported yet.
+ if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+
+ // Convert port from big endian to little endian.
+ port := make([]byte, 2)
+ binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MinPort)
+ target.MinPort = binary.LittleEndian.Uint16(port)
+
+ binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MaxPort)
+ target.MaxPort = binary.LittleEndian.Uint16(port)
+ return target, nil
}
// Unknown target.
return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet.", target.Name.String())
}
-func filterFromIPTIP(iptip linux.IPTIP) (iptables.IPHeaderFilter, error) {
+func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) {
if containsUnsupportedFields(iptip) {
- return iptables.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip)
+ return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip)
+ }
+ if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize {
+ return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask))
}
- return iptables.IPHeaderFilter{
- Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
+ return stack.IPHeaderFilter{
+ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
+ Dst: tcpip.Address(iptip.Dst[:]),
+ DstMask: tcpip.Address(iptip.DstMask[:]),
+ DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0,
}, nil
}
func containsUnsupportedFields(iptip linux.IPTIP) bool {
- // Currently we check that everything except protocol is zeroed.
+ // The following features are supported:
+ // - Protocol
+ // - Dst and DstMask
+ // - The inverse destination IP check flag
var emptyInetAddr = linux.InetAddr{}
var emptyInterface = [linux.IFNAMSIZ]byte{}
- return iptip.Dst != emptyInetAddr ||
- iptip.Src != emptyInetAddr ||
+ // Disable any supported inverse flags.
+ inverseMask := uint8(linux.IPT_INV_DSTIP)
+ return iptip.Src != emptyInetAddr ||
iptip.SrcMask != emptyInetAddr ||
- iptip.DstMask != emptyInetAddr ||
iptip.InputInterface != emptyInterface ||
iptip.OutputInterface != emptyInterface ||
iptip.InputInterfaceMask != emptyInterface ||
iptip.OutputInterfaceMask != emptyInterface ||
iptip.Flags != 0 ||
- iptip.InverseFlags != 0
+ iptip.InverseFlags&^inverseMask != 0
}
-func validUnderflow(rule iptables.Rule) bool {
+func validUnderflow(rule stack.Rule) bool {
if len(rule.Matchers) != 0 {
return false
}
switch rule.Target.(type) {
- case iptables.AcceptTarget, iptables.DropTarget:
+ case stack.AcceptTarget, stack.DropTarget:
return true
default:
return false
}
}
-func hookFromLinux(hook int) iptables.Hook {
+func hookFromLinux(hook int) stack.Hook {
switch hook {
case linux.NF_INET_PRE_ROUTING:
- return iptables.Prerouting
+ return stack.Prerouting
case linux.NF_INET_LOCAL_IN:
- return iptables.Input
+ return stack.Input
case linux.NF_INET_FORWARD:
- return iptables.Forward
+ return stack.Forward
case linux.NF_INET_LOCAL_OUT:
- return iptables.Output
+ return stack.Output
case linux.NF_INET_POST_ROUTING:
- return iptables.Postrouting
+ return stack.Postrouting
}
panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook))
}
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index c421b87cf..c948de876 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -15,11 +15,10 @@
package netfilter
import (
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
-// JumpTarget implements iptables.Target.
+// JumpTarget implements stack.Target.
type JumpTarget struct {
// Offset is the byte offset of the rule to jump to. It is used for
// marshaling and unmarshaling.
@@ -29,7 +28,7 @@ type JumpTarget struct {
RuleNum int
}
-// Action implements iptables.Target.Action.
-func (jt JumpTarget) Action(tcpip.PacketBuffer) (iptables.RuleVerdict, int) {
- return iptables.RuleJump, jt.RuleNum
+// Action implements stack.Target.Action.
+func (jt JumpTarget) Action(stack.PacketBuffer) (stack.RuleVerdict, int) {
+ return stack.RuleJump, jt.RuleNum
}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index f9945e214..ff1cfd8f6 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -19,9 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -40,7 +39,7 @@ func (tcpMarshaler) name() string {
}
// marshal implements matchMaker.marshal.
-func (tcpMarshaler) marshal(mr iptables.Matcher) []byte {
+func (tcpMarshaler) marshal(mr stack.Matcher) []byte {
matcher := mr.(*TCPMatcher)
xttcp := linux.XTTCP{
SourcePortStart: matcher.sourcePortStart,
@@ -53,7 +52,7 @@ func (tcpMarshaler) marshal(mr iptables.Matcher) []byte {
}
// unmarshal implements matchMaker.unmarshal.
-func (tcpMarshaler) unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) {
+func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
if len(buf) < linux.SizeOfXTTCP {
return nil, fmt.Errorf("buf has insufficient size for TCP match: %d", len(buf))
}
@@ -97,7 +96,7 @@ func (*TCPMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (tm *TCPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) {
+func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) {
netHeader := header.IPv4(pkt.NetworkHeader)
if netHeader.TransportProtocol() != header.TCPProtocolNumber {
@@ -115,7 +114,7 @@ func (tm *TCPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfac
// Now we need the transport header. However, this may not have been set
// yet.
// TODO(gvisor.dev/issue/170): Parsing the transport header should
- // ultimately be moved into the iptables.Check codepath as matchers are
+ // ultimately be moved into the stack.Check codepath as matchers are
// added.
var tcpHeader header.TCP
if pkt.TransportHeader != nil {
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index 86aa11696..3359418c1 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -19,9 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -40,7 +39,7 @@ func (udpMarshaler) name() string {
}
// marshal implements matchMaker.marshal.
-func (udpMarshaler) marshal(mr iptables.Matcher) []byte {
+func (udpMarshaler) marshal(mr stack.Matcher) []byte {
matcher := mr.(*UDPMatcher)
xtudp := linux.XTUDP{
SourcePortStart: matcher.sourcePortStart,
@@ -53,7 +52,7 @@ func (udpMarshaler) marshal(mr iptables.Matcher) []byte {
}
// unmarshal implements matchMaker.unmarshal.
-func (udpMarshaler) unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) {
+func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
if len(buf) < linux.SizeOfXTUDP {
return nil, fmt.Errorf("buf has insufficient size for UDP match: %d", len(buf))
}
@@ -94,11 +93,11 @@ func (*UDPMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (um *UDPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) {
+func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) {
netHeader := header.IPv4(pkt.NetworkHeader)
// TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
- // into the iptables.Check codepath as matchers are added.
+ // into the stack.Check codepath as matchers are added.
if netHeader.TransportProtocol() != header.UDPProtocolNumber {
return false, false
}
@@ -114,7 +113,7 @@ func (um *UDPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfac
// Now we need the transport header. However, this may not have been set
// yet.
// TODO(gvisor.dev/issue/170): Parsing the transport header should
- // ultimately be moved into the iptables.Check codepath as matchers are
+ // ultimately be moved into the stack.Check codepath as matchers are
// added.
var udpHeader header.UDP
if pkt.TransportHeader != nil {
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index ab01cb4fa..cbf46b1e9 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -38,7 +38,6 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index e187276c5..f14c336b9 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -29,6 +29,7 @@ import (
"io"
"math"
"reflect"
+ "sync/atomic"
"syscall"
"time"
@@ -264,6 +265,12 @@ type SocketOperations struct {
skType linux.SockType
protocol int
+ // readViewHasData is 1 iff readView has data to be read, 0 otherwise.
+ // Must be accessed using atomic operations. It must only be written
+ // with readMu held but can be read without holding readMu. The latter
+ // is required to avoid deadlocks in epoll Readiness checks.
+ readViewHasData uint32
+
// readMu protects access to the below fields.
readMu sync.Mutex `state:"nosave"`
// readView contains the remaining payload from the last packet.
@@ -410,21 +417,24 @@ func (s *SocketOperations) isPacketBased() bool {
// fetchReadView updates the readView field of the socket if it's currently
// empty. It assumes that the socket is locked.
+//
+// Precondition: s.readMu must be held.
func (s *SocketOperations) fetchReadView() *syserr.Error {
if len(s.readView) > 0 {
return nil
}
-
s.readView = nil
s.sender = tcpip.FullAddress{}
v, cms, err := s.Endpoint.Read(&s.sender)
if err != nil {
+ atomic.StoreUint32(&s.readViewHasData, 0)
return syserr.TranslateNetstackError(err)
}
s.readView = v
s.readCM = cms
+ atomic.StoreUint32(&s.readViewHasData, 1)
return nil
}
@@ -623,11 +633,9 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
// Check our cached value iff the caller asked for readability and the
// endpoint itself is currently not readable.
if (mask & ^r & waiter.EventIn) != 0 {
- s.readMu.Lock()
- if len(s.readView) > 0 {
+ if atomic.LoadUint32(&s.readViewHasData) == 1 {
r |= waiter.EventIn
}
- s.readMu.Unlock()
}
return r
@@ -712,14 +720,44 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
- addr, family, err := AddressAndFamily(sockaddr)
- if err != nil {
- return err
+ if len(sockaddr) < 2 {
+ return syserr.ErrInvalidArgument
}
- if err := s.checkFamily(family, true /* exact */); err != nil {
- return err
+
+ family := usermem.ByteOrder.Uint16(sockaddr)
+ var addr tcpip.FullAddress
+
+ // Bind for AF_PACKET requires only family, protocol and ifindex.
+ // In function AddressAndFamily, we check the address length which is
+ // not needed for AF_PACKET bind.
+ if family == linux.AF_PACKET {
+ var a linux.SockAddrLink
+ if len(sockaddr) < sockAddrLinkSize {
+ return syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(sockaddr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+
+ if a.Protocol != uint16(s.protocol) {
+ return syserr.ErrInvalidArgument
+ }
+
+ addr = tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }
+ } else {
+ var err *syserr.Error
+ addr, family, err = AddressAndFamily(sockaddr)
+ if err != nil {
+ return err
+ }
+
+ if err = s.checkFamily(family, true /* exact */); err != nil {
+ return err
+ }
+
+ addr = s.mapFamily(addr, family)
}
- addr = s.mapFamily(addr, family)
// Issue the bind request to the endpoint.
return syserr.TranslateNetstackError(s.Endpoint.Bind(addr))
@@ -2304,6 +2342,10 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq
}
copied += n
s.readView.TrimFront(n)
+ if len(s.readView) == 0 {
+ atomic.StoreUint32(&s.readViewHasData, 0)
+ }
+
dst = dst.DropFirst(n)
if e != nil {
err = syserr.FromError(e)
@@ -2350,9 +2392,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
// caller-supplied buffer.
s.readMu.Lock()
n, err := s.coalescingRead(ctx, dst, trunc)
- s.readMu.Unlock()
cmsg := s.controlMessages()
s.fillCmsgInq(&cmsg)
+ s.readMu.Unlock()
return n, 0, nil, 0, cmsg, err
}
@@ -2426,6 +2468,10 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
s.readView.TrimFront(int(n))
}
+ if len(s.readView) == 0 {
+ atomic.StoreUint32(&s.readViewHasData, 0)
+ }
+
var flags int
if msgLen > int(n) {
flags |= linux.MSG_TRUNC
@@ -2637,7 +2683,9 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
}
// Add bytes removed from the endpoint but not yet sent to the caller.
+ s.readMu.Lock()
v += len(s.readView)
+ s.readMu.Unlock()
if v > math.MaxInt32 {
v = math.MaxInt32
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 0692482e9..f5fa18136 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -200,36 +199,66 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
// Statistics implements inet.Stack.Statistics.
func (s *Stack) Statistics(stat interface{}, arg string) error {
switch stats := stat.(type) {
+ case *inet.StatDev:
+ for _, ni := range s.Stack.NICInfo() {
+ if ni.Name != arg {
+ continue
+ }
+ // TODO(gvisor.dev/issue/2103) Support stubbed stats.
+ *stats = inet.StatDev{
+ // Receive section.
+ ni.Stats.Rx.Bytes.Value(), // bytes.
+ ni.Stats.Rx.Packets.Value(), // packets.
+ 0, // errs.
+ 0, // drop.
+ 0, // fifo.
+ 0, // frame.
+ 0, // compressed.
+ 0, // multicast.
+ // Transmit section.
+ ni.Stats.Tx.Bytes.Value(), // bytes.
+ ni.Stats.Tx.Packets.Value(), // packets.
+ 0, // errs.
+ 0, // drop.
+ 0, // fifo.
+ 0, // colls.
+ 0, // carrier.
+ 0, // compressed.
+ }
+ break
+ }
case *inet.StatSNMPIP:
ip := Metrics.IP
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPIP{
- 0, // TODO(gvisor.dev/issue/969): Support Ip/Forwarding.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/DefaultTTL.
+ 0, // Ip/Forwarding.
+ 0, // Ip/DefaultTTL.
ip.PacketsReceived.Value(), // InReceives.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/InHdrErrors.
+ 0, // Ip/InHdrErrors.
ip.InvalidDestinationAddressesReceived.Value(), // InAddrErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ForwDatagrams.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/InUnknownProtos.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/InDiscards.
+ 0, // Ip/ForwDatagrams.
+ 0, // Ip/InUnknownProtos.
+ 0, // Ip/InDiscards.
ip.PacketsDelivered.Value(), // InDelivers.
ip.PacketsSent.Value(), // OutRequests.
ip.OutgoingPacketErrors.Value(), // OutDiscards.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/OutNoRoutes.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmTimeout.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmReqds.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmOKs.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmFails.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/FragOKs.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/FragFails.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/FragCreates.
+ 0, // Ip/OutNoRoutes.
+ 0, // Support Ip/ReasmTimeout.
+ 0, // Support Ip/ReasmReqds.
+ 0, // Support Ip/ReasmOKs.
+ 0, // Support Ip/ReasmFails.
+ 0, // Support Ip/FragOKs.
+ 0, // Support Ip/FragFails.
+ 0, // Support Ip/FragCreates.
}
case *inet.StatSNMPICMP:
in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats
out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPICMP{
- 0, // TODO(gvisor.dev/issue/969): Support Icmp/InMsgs.
+ 0, // Icmp/InMsgs.
Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Icmp/InCsumErrors.
+ 0, // Icmp/InCsumErrors.
in.DstUnreachable.Value(), // InDestUnreachs.
in.TimeExceeded.Value(), // InTimeExcds.
in.ParamProblem.Value(), // InParmProbs.
@@ -241,7 +270,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
in.TimestampReply.Value(), // InTimestampReps.
in.InfoRequest.Value(), // InAddrMasks.
in.InfoReply.Value(), // InAddrMaskReps.
- 0, // TODO(gvisor.dev/issue/969): Support Icmp/OutMsgs.
+ 0, // Icmp/OutMsgs.
Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors.
out.DstUnreachable.Value(), // OutDestUnreachs.
out.TimeExceeded.Value(), // OutTimeExcds.
@@ -277,15 +306,16 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
}
case *inet.StatSNMPUDP:
udp := Metrics.UDP
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPUDP{
udp.PacketsReceived.Value(), // InDatagrams.
udp.UnknownPortErrors.Value(), // NoPorts.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/InErrors.
+ 0, // Udp/InErrors.
udp.PacketsSent.Value(), // OutDatagrams.
udp.ReceiveBufferErrors.Value(), // RcvbufErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/SndbufErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/InCsumErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/IgnoredMulti.
+ 0, // Udp/SndbufErrors.
+ 0, // Udp/InCsumErrors.
+ 0, // Udp/IgnoredMulti.
}
default:
return syserr.ErrEndpointOperation.ToError()
@@ -332,7 +362,7 @@ func (s *Stack) RouteTable() []inet.Route {
}
// IPTables returns the stack's iptables.
-func (s *Stack) IPTables() (iptables.IPTables, error) {
+func (s *Stack) IPTables() (stack.IPTables, error) {
return s.Stack.IPTables(), nil
}
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index c7883e68e..0d24fd3c4 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -42,6 +42,8 @@ go_library(
"sys_socket.go",
"sys_splice.go",
"sys_stat.go",
+ "sys_stat_amd64.go",
+ "sys_stat_arm64.go",
"sys_sync.go",
"sys_sysinfo.go",
"sys_syslog.go",
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index d10a9bed8..35a98212a 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -514,7 +514,7 @@ func (ac accessContext) Value(key interface{}) interface{} {
}
}
-func accessAt(t *kernel.Task, dirFD int32, addr usermem.Addr, resolve bool, mode uint) error {
+func accessAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode uint) error {
const rOK = 4
const wOK = 2
const xOK = 1
@@ -529,7 +529,7 @@ func accessAt(t *kernel.Task, dirFD int32, addr usermem.Addr, resolve bool, mode
return syserror.EINVAL
}
- return fileOpOn(t, dirFD, path, resolve, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ return fileOpOn(t, dirFD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
// access(2) and faccessat(2) check permissions using real
// UID/GID, not effective UID/GID.
//
@@ -564,17 +564,23 @@ func Access(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
addr := args[0].Pointer()
mode := args[1].ModeT()
- return 0, nil, accessAt(t, linux.AT_FDCWD, addr, true, mode)
+ return 0, nil, accessAt(t, linux.AT_FDCWD, addr, mode)
}
// Faccessat implements linux syscall faccessat(2).
+//
+// Note that the faccessat() system call does not take a flags argument:
+// "The raw faccessat() system call takes only the first three arguments. The
+// AT_EACCESS and AT_SYMLINK_NOFOLLOW flags are actually implemented within
+// the glibc wrapper function for faccessat(). If either of these flags is
+// specified, then the wrapper function employs fstatat(2) to determine access
+// permissions." - faccessat(2)
func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
dirFD := args[0].Int()
addr := args[1].Pointer()
mode := args[2].ModeT()
- flags := args[3].Int()
- return 0, nil, accessAt(t, dirFD, addr, flags&linux.AT_SYMLINK_NOFOLLOW == 0, mode)
+ return 0, nil, accessAt(t, dirFD, addr, mode)
}
// LINT.ThenChange(vfs2/filesystem.go)
diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go
index 701b27b4a..a11a87cd1 100644
--- a/pkg/sentry/syscalls/linux/sys_stat.go
+++ b/pkg/sentry/syscalls/linux/sys_stat.go
@@ -25,24 +25,6 @@ import (
// LINT.IfChange
-func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat {
- return linux.Stat{
- Dev: sattr.DeviceID,
- Ino: sattr.InodeID,
- Nlink: uattr.Links,
- Mode: sattr.Type.LinuxType() | uint32(uattr.Perms.LinuxMode()),
- UID: uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()),
- GID: uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()),
- Rdev: uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)),
- Size: uattr.Size,
- Blksize: sattr.BlockSize,
- Blocks: uattr.Usage / 512,
- ATime: uattr.AccessTime.Timespec(),
- MTime: uattr.ModificationTime.Timespec(),
- CTime: uattr.StatusChangeTime.Timespec(),
- }
-}
-
// Stat implements linux syscall stat(2).
func Stat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
@@ -154,7 +136,10 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
mask := args[3].Uint()
statxAddr := args[4].Pointer()
- if mask&linux.STATX__RESERVED > 0 {
+ if mask&linux.STATX__RESERVED != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if flags&^(linux.AT_SYMLINK_NOFOLLOW|linux.AT_EMPTY_PATH|linux.AT_STATX_SYNC_TYPE) != 0 {
return 0, nil, syserror.EINVAL
}
if flags&linux.AT_STATX_SYNC_TYPE == linux.AT_STATX_SYNC_TYPE {
diff --git a/pkg/sentry/syscalls/linux/sys_stat_amd64.go b/pkg/sentry/syscalls/linux/sys_stat_amd64.go
new file mode 100644
index 000000000..0a04a6113
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_stat_amd64.go
@@ -0,0 +1,45 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// LINT.IfChange
+
+func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat {
+ return linux.Stat{
+ Dev: sattr.DeviceID,
+ Ino: sattr.InodeID,
+ Nlink: uattr.Links,
+ Mode: sattr.Type.LinuxType() | uint32(uattr.Perms.LinuxMode()),
+ UID: uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()),
+ GID: uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)),
+ Size: uattr.Size,
+ Blksize: sattr.BlockSize,
+ Blocks: uattr.Usage / 512,
+ ATime: uattr.AccessTime.Timespec(),
+ MTime: uattr.ModificationTime.Timespec(),
+ CTime: uattr.StatusChangeTime.Timespec(),
+ }
+}
+
+// LINT.ThenChange(vfs2/stat_amd64.go)
diff --git a/pkg/sentry/syscalls/linux/sys_stat_arm64.go b/pkg/sentry/syscalls/linux/sys_stat_arm64.go
new file mode 100644
index 000000000..5a3b1bfad
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_stat_arm64.go
@@ -0,0 +1,45 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// LINT.IfChange
+
+func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat {
+ return linux.Stat{
+ Dev: sattr.DeviceID,
+ Ino: sattr.InodeID,
+ Nlink: uint32(uattr.Links),
+ Mode: sattr.Type.LinuxType() | uint32(uattr.Perms.LinuxMode()),
+ UID: uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()),
+ GID: uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)),
+ Size: uattr.Size,
+ Blksize: int32(sattr.BlockSize),
+ Blocks: uattr.Usage / 512,
+ ATime: uattr.AccessTime.Timespec(),
+ MTime: uattr.ModificationTime.Timespec(),
+ CTime: uattr.StatusChangeTime.Timespec(),
+ }
+}
+
+// LINT.ThenChange(vfs2/stat_arm64.go)
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index f51761e81..2eb210014 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -22,6 +22,8 @@ go_library(
"read_write.go",
"setstat.go",
"stat.go",
+ "stat_amd64.go",
+ "stat_arm64.go",
"sync.go",
"xattr.go",
],
@@ -29,6 +31,7 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
+ "//pkg/bits",
"//pkg/fspath",
"//pkg/gohacks",
"//pkg/sentry/arch",
diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
index fc5ceea4c..a859095e2 100644
--- a/pkg/sentry/syscalls/linux/vfs2/filesystem.go
+++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
@@ -250,7 +250,7 @@ func rmdirat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
if err != nil {
return err
}
- tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, followFinalSymlink)
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
if err != nil {
return err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go
index ddc140b65..a61cc5059 100644
--- a/pkg/sentry/syscalls/linux/vfs2/getdents.go
+++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go
@@ -97,7 +97,7 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
// char d_name[]; /* Filename (null-terminated) */
// };
size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name)
- if size < cb.remaining {
+ if size > cb.remaining {
return syserror.EINVAL
}
buf = cb.t.CopyScratchBuffer(size)
@@ -125,7 +125,7 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
panic(fmt.Sprintf("unsupported sizeof(unsigned long): %d", cb.t.Arch().Width()))
}
size := 8 + 8 + 2 + 1 + 1 + 1 + len(dirent.Name)
- if size < cb.remaining {
+ if size > cb.remaining {
return syserror.EINVAL
}
buf = cb.t.CopyScratchBuffer(size)
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
index 9250659ff..136453ccc 100644
--- a/pkg/sentry/syscalls/linux/vfs2/setstat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -173,12 +173,13 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, err
}
- return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{
+ err = setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{
Stat: linux.Statx{
Mask: linux.STATX_SIZE,
Size: uint64(length),
},
})
+ return 0, nil, handleSetSizeError(t, err)
}
// Ftruncate implements Linux syscall ftruncate(2).
@@ -196,12 +197,13 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
defer file.DecRef()
- return 0, nil, file.SetStat(t, vfs.SetStatOptions{
+ err := file.SetStat(t, vfs.SetStatOptions{
Stat: linux.Statx{
Mask: linux.STATX_SIZE,
Size: uint64(length),
},
})
+ return 0, nil, handleSetSizeError(t, err)
}
// Utime implements Linux syscall utime(2).
@@ -378,3 +380,12 @@ func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPa
FollowFinalSymlink: bool(shouldFollowFinalSymlink),
}, opts)
}
+
+func handleSetSizeError(t *kernel.Task, err error) error {
+ if err == syserror.ErrExceedsFileSizeLimit {
+ // Convert error to EFBIG and send a SIGXFSZ per setrlimit(2).
+ t.SendSignal(kernel.SignalInfoNoInfo(linux.SIGXFSZ, t, t))
+ return syserror.EFBIG
+ }
+ return err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go
index dca8d7011..fdfe49243 100644
--- a/pkg/sentry/syscalls/linux/vfs2/stat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/stat.go
@@ -16,6 +16,7 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -113,29 +114,6 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags
return stat.CopyOut(t, statAddr)
}
-// This takes both input and output as pointer arguments to avoid copying large
-// structs.
-func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) {
- // Linux just copies fields from struct kstat without regard to struct
- // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too.
- userns := t.UserNamespace()
- *stat = linux.Stat{
- Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)),
- Ino: statx.Ino,
- Nlink: uint64(statx.Nlink),
- Mode: uint32(statx.Mode),
- UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()),
- GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()),
- Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)),
- Size: int64(statx.Size),
- Blksize: int64(statx.Blksize),
- Blocks: int64(statx.Blocks),
- ATime: timespecFromStatxTimestamp(statx.Atime),
- MTime: timespecFromStatxTimestamp(statx.Mtime),
- CTime: timespecFromStatxTimestamp(statx.Ctime),
- }
-}
-
func timespecFromStatxTimestamp(sxts linux.StatxTimestamp) linux.Timespec {
return linux.Timespec{
Sec: sxts.Sec,
@@ -173,7 +151,15 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
mask := args[3].Uint()
statxAddr := args[4].Pointer()
- if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW|linux.AT_STATX_SYNC_TYPE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ // Make sure that only one sync type option is set.
+ syncType := uint32(flags & linux.AT_STATX_SYNC_TYPE)
+ if syncType != 0 && !bits.IsPowerOfTwo32(syncType) {
+ return 0, nil, syserror.EINVAL
+ }
+ if mask&linux.STATX__RESERVED != 0 {
return 0, nil, syserror.EINVAL
}
@@ -251,14 +237,65 @@ func Readlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Access implements Linux syscall access(2).
func Access(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- // FIXME(jamieliu): actually implement
- return 0, nil, nil
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+
+ return 0, nil, accessAt(t, linux.AT_FDCWD, addr, mode)
}
-// Faccessat implements Linux syscall access(2).
+// Faccessat implements Linux syscall faccessat(2).
+//
+// Note that the faccessat() system call does not take a flags argument:
+// "The raw faccessat() system call takes only the first three arguments. The
+// AT_EACCESS and AT_SYMLINK_NOFOLLOW flags are actually implemented within
+// the glibc wrapper function for faccessat(). If either of these flags is
+// specified, then the wrapper function employs fstatat(2) to determine access
+// permissions." - faccessat(2)
func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- // FIXME(jamieliu): actually implement
- return 0, nil, nil
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+
+ return 0, nil, accessAt(t, dirfd, addr, mode)
+}
+
+func accessAt(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) error {
+ const rOK = 4
+ const wOK = 2
+ const xOK = 1
+
+ // Sanity check the mode.
+ if mode&^(rOK|wOK|xOK) != 0 {
+ return syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+
+ // access(2) and faccessat(2) check permissions using real
+ // UID/GID, not effective UID/GID.
+ //
+ // "access() needs to use the real uid/gid, not the effective
+ // uid/gid. We do this by temporarily clearing all FS-related
+ // capabilities and switching the fsuid/fsgid around to the
+ // real ones." -fs/open.c:faccessat
+ creds := t.Credentials().Fork()
+ creds.EffectiveKUID = creds.RealKUID
+ creds.EffectiveKGID = creds.RealKGID
+ if creds.EffectiveKUID.In(creds.UserNamespace) == auth.RootUID {
+ creds.EffectiveCaps = creds.PermittedCaps
+ } else {
+ creds.EffectiveCaps = 0
+ }
+
+ return t.Kernel().VFS().AccessAt(t, creds, vfs.AccessTypes(mode), &tpop.pop)
}
// Readlinkat implements Linux syscall mknodat(2).
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go b/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go
new file mode 100644
index 000000000..2da538fc6
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go
@@ -0,0 +1,46 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// This takes both input and output as pointer arguments to avoid copying large
+// structs.
+func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) {
+ // Linux just copies fields from struct kstat without regard to struct
+ // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too.
+ userns := t.UserNamespace()
+ *stat = linux.Stat{
+ Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)),
+ Ino: statx.Ino,
+ Nlink: uint64(statx.Nlink),
+ Mode: uint32(statx.Mode),
+ UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()),
+ GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)),
+ Size: int64(statx.Size),
+ Blksize: int64(statx.Blksize),
+ Blocks: int64(statx.Blocks),
+ ATime: timespecFromStatxTimestamp(statx.Atime),
+ MTime: timespecFromStatxTimestamp(statx.Mtime),
+ CTime: timespecFromStatxTimestamp(statx.Ctime),
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go b/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go
new file mode 100644
index 000000000..88b9c7627
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go
@@ -0,0 +1,46 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// This takes both input and output as pointer arguments to avoid copying large
+// structs.
+func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) {
+ // Linux just copies fields from struct kstat without regard to struct
+ // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too.
+ userns := t.UserNamespace()
+ *stat = linux.Stat{
+ Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)),
+ Ino: statx.Ino,
+ Nlink: uint32(statx.Nlink),
+ Mode: uint32(statx.Mode),
+ UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()),
+ GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)),
+ Size: int64(statx.Size),
+ Blksize: int32(statx.Blksize),
+ Blocks: int64(statx.Blocks),
+ ATime: timespecFromStatxTimestamp(statx.Atime),
+ MTime: timespecFromStatxTimestamp(statx.Mtime),
+ CTime: timespecFromStatxTimestamp(statx.Ctime),
+ }
+}
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 07c8383e6..a2a06fc8f 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -42,17 +42,22 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/fd",
"//pkg/fspath",
"//pkg/gohacks",
"//pkg/log",
+ "//pkg/safemem",
"//pkg/sentry/arch",
+ "//pkg/sentry/fs",
"//pkg/sentry/fs/lock",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/limits",
"//pkg/sentry/memmap",
"//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go
index 2db25be49..a62e43589 100644
--- a/pkg/sentry/vfs/anonfs.go
+++ b/pkg/sentry/vfs/anonfs.go
@@ -41,7 +41,14 @@ func (vfs *VirtualFilesystem) NewAnonVirtualDentry(name string) VirtualDentry {
}
}
-const anonfsBlockSize = usermem.PageSize // via fs/libfs.c:pseudo_fs_fill_super()
+const (
+ anonfsBlockSize = usermem.PageSize // via fs/libfs.c:pseudo_fs_fill_super()
+
+ // Mode, UID, and GID for a generic anonfs file.
+ anonFileMode = 0600 // no type is correct
+ anonFileUID = auth.RootKUID
+ anonFileGID = auth.RootKGID
+)
// anonFilesystem is the implementation of FilesystemImpl that backs
// VirtualDentries returned by VirtualFilesystem.NewAnonVirtualDentry().
@@ -69,6 +76,16 @@ func (fs *anonFilesystem) Sync(ctx context.Context) error {
return nil
}
+// AccessAt implements vfs.Filesystem.Impl.AccessAt.
+//
+// TODO(gvisor.dev/issue/1965): Implement access permissions.
+func (fs *anonFilesystem) AccessAt(ctx context.Context, rp *ResolvingPath, creds *auth.Credentials, ats AccessTypes) error {
+ if !rp.Done() {
+ return syserror.ENOTDIR
+ }
+ return GenericCheckPermissions(creds, ats, anonFileMode, anonFileUID, anonFileGID)
+}
+
// GetDentryAt implements FilesystemImpl.GetDentryAt.
func (fs *anonFilesystem) GetDentryAt(ctx context.Context, rp *ResolvingPath, opts GetDentryOptions) (*Dentry, error) {
if !rp.Done() {
@@ -167,9 +184,9 @@ func (fs *anonFilesystem) StatAt(ctx context.Context, rp *ResolvingPath, opts St
Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS,
Blksize: anonfsBlockSize,
Nlink: 1,
- UID: uint32(auth.RootKUID),
- GID: uint32(auth.RootKGID),
- Mode: 0600, // no type is correct
+ UID: uint32(anonFileUID),
+ GID: uint32(anonFileGID),
+ Mode: anonFileMode,
Ino: 1,
Size: 0,
Blocks: 0,
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 9a1ad630c..8ee549dc2 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -286,7 +286,8 @@ type FileDescriptionImpl interface {
Stat(ctx context.Context, opts StatOptions) (linux.Statx, error)
// SetStat updates metadata for the file represented by the
- // FileDescription.
+ // FileDescription. Implementations are responsible for checking if the
+ // operation can be performed (see vfs.CheckSetStat() for common checks).
SetStat(ctx context.Context, opts SetStatOptions) error
// StatFS returns metadata for the filesystem containing the file
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index c2a52ec1b..d45e602ce 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -33,8 +33,8 @@ import (
// implementations to adapt:
// - Have a local fileDescription struct (containing FileDescription) which
// embeds FileDescriptionDefaultImpl and overrides the default methods
-// which are common to all fd implementations for that for that filesystem
-// like StatusFlags, SetStatusFlags, Stat, SetStat, StatFS, etc.
+// which are common to all fd implementations for that filesystem like
+// StatusFlags, SetStatusFlags, Stat, SetStat, StatFS, etc.
// - This should be embedded in all file description implementations as the
// first field by value.
// - Directory FDs would also embed DirectoryFileDescriptionDefaultImpl.
@@ -339,6 +339,11 @@ func (fd *DynamicBytesFileDescriptionImpl) pwriteLocked(ctx context.Context, src
if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 {
return 0, syserror.EOPNOTSUPP
}
+ limit, err := CheckLimit(ctx, offset, src.NumBytes())
+ if err != nil {
+ return 0, err
+ }
+ src = src.TakeFirst64(limit)
writable, ok := fd.data.(WritableDynamicBytesSource)
if !ok {
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
index 556976d0b..332decce6 100644
--- a/pkg/sentry/vfs/filesystem.go
+++ b/pkg/sentry/vfs/filesystem.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
// A Filesystem is a tree of nodes represented by Dentries, which forms part of
@@ -144,6 +145,9 @@ type FilesystemImpl interface {
// file data to be written to the underlying [filesystem]", as by syncfs(2).
Sync(ctx context.Context) error
+ // AccessAt checks whether a user with creds can access the file at rp.
+ AccessAt(ctx context.Context, rp *ResolvingPath, creds *auth.Credentials, ats AccessTypes) error
+
// GetDentryAt returns a Dentry representing the file at rp. A reference is
// taken on the returned Dentry.
//
@@ -362,7 +366,9 @@ type FilesystemImpl interface {
// ResolvingPath.Resolve*(), then !rp.Done().
RmdirAt(ctx context.Context, rp *ResolvingPath) error
- // SetStatAt updates metadata for the file at the given path.
+ // SetStatAt updates metadata for the file at the given path. Implementations
+ // are responsible for checking if the operation can be performed
+ // (see vfs.CheckSetStat() for common checks).
//
// Errors:
//
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 31a4e5480..05f6233f9 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -74,6 +74,10 @@ type Mount struct {
// umounted is true. umounted is protected by VirtualFilesystem.mountMu.
umounted bool
+ // flags contains settings as specified for mount(2), e.g. MS_NOEXEC, except
+ // for MS_RDONLY which is tracked in "writers".
+ flags MountFlags
+
// The lower 63 bits of writers is the number of calls to
// Mount.CheckBeginWrite() that have not yet been paired with a call to
// Mount.EndWrite(). The MSB of writers is set if MS_RDONLY is in effect.
@@ -81,6 +85,21 @@ type Mount struct {
writers int64
}
+func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *MountNamespace, opts *MountOptions) *Mount {
+ mnt := &Mount{
+ vfs: vfs,
+ fs: fs,
+ root: root,
+ flags: opts.Flags,
+ ns: mntns,
+ refs: 1,
+ }
+ if opts.ReadOnly {
+ mnt.setReadOnlyLocked(true)
+ }
+ return mnt
+}
+
// A MountNamespace is a collection of Mounts.
//
// MountNamespaces are reference-counted. Unless otherwise specified, all
@@ -129,13 +148,7 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth
refs: 1,
mountpoints: make(map[*Dentry]uint32),
}
- mntns.root = &Mount{
- vfs: vfs,
- fs: fs,
- root: root,
- ns: mntns,
- refs: 1,
- }
+ mntns.root = newMount(vfs, fs, root, mntns, &MountOptions{})
return mntns, nil
}
@@ -148,12 +161,7 @@ func (vfs *VirtualFilesystem) NewDisconnectedMount(fs *Filesystem, root *Dentry,
if root != nil {
root.IncRef()
}
- return &Mount{
- vfs: vfs,
- fs: fs,
- root: root,
- refs: 1,
- }, nil
+ return newMount(vfs, fs, root, nil /* mntns */, opts), nil
}
// MountAt creates and mounts a Filesystem configured by the given arguments.
@@ -218,13 +226,7 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia
// are directories, or neither are, and returns ENOTDIR if this is not the
// case.
mntns := vd.mount.ns
- mnt := &Mount{
- vfs: vfs,
- fs: fs,
- root: root,
- ns: mntns,
- refs: 1,
- }
+ mnt := newMount(vfs, fs, root, mntns, opts)
vfs.mounts.seq.BeginWrite()
vfs.connectLocked(mnt, vd, mntns)
vfs.mounts.seq.EndWrite()
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index 6af7fdac1..3e90dc4ed 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -46,8 +46,21 @@ type MknodOptions struct {
DevMinor uint32
}
+// MountFlags contains flags as specified for mount(2), e.g. MS_NOEXEC.
+// MS_RDONLY is not part of MountFlags because it's tracked in Mount.writers.
+type MountFlags struct {
+ // NoExec is equivalent to MS_NOEXEC.
+ NoExec bool
+}
+
// MountOptions contains options to VirtualFilesystem.MountAt().
type MountOptions struct {
+ // Flags contains flags as specified for mount(2), e.g. MS_NOEXEC.
+ Flags MountFlags
+
+ // ReadOnly is equivalent to MS_RDONLY.
+ ReadOnly bool
+
// GetFilesystemOptions contains options to FilesystemType.GetFilesystem().
GetFilesystemOptions GetFilesystemOptions
@@ -75,7 +88,8 @@ type OpenOptions struct {
// FileExec is set when the file is being opened to be executed.
// VirtualFilesystem.OpenAt() checks that the caller has execute permissions
- // on the file, and that the file is a regular file.
+ // on the file, that the file is a regular file, and that the mount doesn't
+ // have MS_NOEXEC set.
FileExec bool
}
diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go
index 8e250998a..f9647f90e 100644
--- a/pkg/sentry/vfs/permissions.go
+++ b/pkg/sentry/vfs/permissions.go
@@ -15,8 +15,12 @@
package vfs
import (
+ "math"
+
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -25,9 +29,9 @@ type AccessTypes uint16
// Bits in AccessTypes.
const (
+ MayExec AccessTypes = 1
+ MayWrite AccessTypes = 2
MayRead AccessTypes = 4
- MayWrite = 2
- MayExec = 1
)
// OnlyRead returns true if access _only_ allows read.
@@ -52,16 +56,17 @@ func (a AccessTypes) MayExec() bool {
// GenericCheckPermissions checks that creds has the given access rights on a
// file with the given permissions, UID, and GID, subject to the rules of
-// fs/namei.c:generic_permission(). isDir is true if the file is a directory.
-func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir bool, mode uint16, kuid auth.KUID, kgid auth.KGID) error {
+// fs/namei.c:generic_permission().
+func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
// Check permission bits.
- perms := mode
+ perms := uint16(mode.Permissions())
if creds.EffectiveKUID == kuid {
perms >>= 6
} else if creds.InGroup(kgid) {
perms >>= 3
}
if uint16(ats)&perms == uint16(ats) {
+ // All permission bits match, access granted.
return nil
}
@@ -73,7 +78,7 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir boo
}
// CAP_DAC_READ_SEARCH allows the caller to read and search arbitrary
// directories, and read arbitrary non-directory files.
- if (isDir && !ats.MayWrite()) || ats.OnlyRead() {
+ if (mode.IsDir() && !ats.MayWrite()) || ats.OnlyRead() {
if creds.HasCapability(linux.CAP_DAC_READ_SEARCH) {
return nil
}
@@ -81,7 +86,7 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir boo
// CAP_DAC_OVERRIDE allows arbitrary access to directories, read/write
// access to non-directory files, and execute access to non-directory files
// for which at least one execute bit is set.
- if isDir || !ats.MayExec() || (mode&0111 != 0) {
+ if mode.IsDir() || !ats.MayExec() || (mode.Permissions()&0111 != 0) {
if creds.HasCapability(linux.CAP_DAC_OVERRIDE) {
return nil
}
@@ -147,7 +152,16 @@ func MayWriteFileWithOpenFlags(flags uint32) bool {
// CheckSetStat checks that creds has permission to change the metadata of a
// file with the given permissions, UID, and GID as specified by stat, subject
// to the rules of Linux's fs/attr.c:setattr_prepare().
-func CheckSetStat(creds *auth.Credentials, stat *linux.Statx, mode uint16, kuid auth.KUID, kgid auth.KGID) error {
+func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ limit, err := CheckLimit(ctx, 0, int64(stat.Size))
+ if err != nil {
+ return err
+ }
+ if limit < int64(stat.Size) {
+ return syserror.ErrExceedsFileSizeLimit
+ }
+ }
if stat.Mask&linux.STATX_MODE != 0 {
if !CanActAsOwner(creds, kuid) {
return syserror.EPERM
@@ -177,11 +191,7 @@ func CheckSetStat(creds *auth.Credentials, stat *linux.Statx, mode uint16, kuid
(stat.Mask&linux.STATX_CTIME != 0 && stat.Ctime.Nsec != linux.UTIME_NOW) {
return syserror.EPERM
}
- // isDir is irrelevant in the following call to
- // GenericCheckPermissions since ats == MayWrite means that
- // CAP_DAC_READ_SEARCH does not apply, and CAP_DAC_OVERRIDE
- // applies, regardless of isDir.
- if err := GenericCheckPermissions(creds, MayWrite, false /* isDir */, mode, kuid, kgid); err != nil {
+ if err := GenericCheckPermissions(creds, MayWrite, mode, kuid, kgid); err != nil {
return err
}
}
@@ -205,3 +215,21 @@ func CanActAsOwner(creds *auth.Credentials, kuid auth.KUID) bool {
func HasCapabilityOnFile(creds *auth.Credentials, cp linux.Capability, kuid auth.KUID, kgid auth.KGID) bool {
return creds.HasCapability(cp) && creds.UserNamespace.MapFromKUID(kuid).Ok() && creds.UserNamespace.MapFromKGID(kgid).Ok()
}
+
+// CheckLimit enforces file size rlimits. It returns error if the write
+// operation must not proceed. Otherwise it returns the max length allowed to
+// without violating the limit.
+func CheckLimit(ctx context.Context, offset, size int64) (int64, error) {
+ fileSizeLimit := limits.FromContext(ctx).Get(limits.FileSize).Cur
+ if fileSizeLimit > math.MaxInt64 {
+ return size, nil
+ }
+ if offset >= int64(fileSizeLimit) {
+ return 0, syserror.ErrExceedsFileSizeLimit
+ }
+ remaining := int64(fileSizeLimit) - offset
+ if remaining < size {
+ return remaining, nil
+ }
+ return size, nil
+}
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
index eb4ebb511..8f31495da 100644
--- a/pkg/sentry/vfs/resolving_path.go
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -329,10 +329,22 @@ func (rp *ResolvingPath) ResolveComponent(d *Dentry) (*Dentry, error) {
// component in pcs represents a symbolic link, the symbolic link should be
// followed.
//
+// If path is terminated with '/', the '/' is considered the last element and
+// any symlink before that is followed:
+// - For most non-creating walks, the last path component is handled by
+// fs/namei.c:lookup_last(), which sets LOOKUP_FOLLOW if the first byte
+// after the path component is non-NULL (which is only possible if it's '/')
+// and the path component is of type LAST_NORM.
+//
+// - For open/openat/openat2 without O_CREAT, the last path component is
+// handled by fs/namei.c:do_last(), which does the same, though without the
+// LAST_NORM check.
+//
// Preconditions: !rp.Done().
func (rp *ResolvingPath) ShouldFollowSymlink() bool {
- // Non-final symlinks are always followed.
- return rp.flags&rpflagsFollowFinalSymlink != 0 || !rp.Final()
+ // Non-final symlinks are always followed. Paths terminated with '/' are also
+ // always followed.
+ return rp.flags&rpflagsFollowFinalSymlink != 0 || !rp.Final() || rp.MustBeDir()
}
// HandleSymlink is called when the current path component is a symbolic link
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index bde81e1ef..03d1fb943 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -174,6 +174,23 @@ type PathOperation struct {
FollowFinalSymlink bool
}
+// AccessAt checks whether a user with creds has access to the file at
+// the given path.
+func (vfs *VirtualFilesystem) AccessAt(ctx context.Context, creds *auth.Credentials, ats AccessTypes, pop *PathOperation) error {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.AccessAt(ctx, rp, creds, ats)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
// GetDentryAt returns a VirtualDentry representing the given path, at which a
// file must exist. A reference is taken on the returned VirtualDentry.
func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) {
@@ -385,9 +402,12 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
if err == nil {
vfs.putResolvingPath(rp)
- // TODO(gvisor.dev/issue/1193): Move inside fsimpl to avoid another call
- // to FileDescription.Stat().
if opts.FileExec {
+ if fd.Mount().flags.NoExec {
+ fd.DecRef()
+ return nil, syserror.EACCES
+ }
+
// Only a regular file can be executed.
stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_TYPE})
if err != nil {
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
index bfb2fac26..f7d6009a0 100644
--- a/pkg/sentry/watchdog/watchdog.go
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -221,7 +221,7 @@ func (w *Watchdog) waitForStart() {
return
}
var buf bytes.Buffer
- buf.WriteString("Watchdog.Start() not called within %s:\n")
+ buf.WriteString(fmt.Sprintf("Watchdog.Start() not called within %s", w.StartupTimeout))
w.doAction(w.StartupTimeoutAction, false, &buf)
}
@@ -325,7 +325,7 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo
func (w *Watchdog) reportStuckWatchdog() {
var buf bytes.Buffer
- buf.WriteString("Watchdog goroutine is stuck:\n")
+ buf.WriteString("Watchdog goroutine is stuck:")
w.doAction(w.TaskTimeoutAction, false, &buf)
}
@@ -359,7 +359,7 @@ func (w *Watchdog) doAction(action Action, skipStack bool, msg *bytes.Buffer) {
case <-metricsEmitted:
case <-time.After(1 * time.Second):
}
- panic(fmt.Sprintf("Stack for running G's are skipped while panicking.\n%s", msg.String()))
+ panic(fmt.Sprintf("%s\nStack for running G's are skipped while panicking.", msg.String()))
default:
panic(fmt.Sprintf("Unknown watchdog action %v", action))
diff --git a/pkg/sync/aliases.go b/pkg/sync/aliases.go
index d2d7132fa..0d4316254 100644
--- a/pkg/sync/aliases.go
+++ b/pkg/sync/aliases.go
@@ -29,3 +29,8 @@ type (
// Map is an alias of sync.Map.
Map = sync.Map
)
+
+// NewCond is a wrapper around sync.NewCond.
+func NewCond(l Locker) *Cond {
+ return sync.NewCond(l)
+}
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index 26f7ba86b..454e07662 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -5,8 +5,6 @@ package(licenses = ["notice"])
go_library(
name = "tcpip",
srcs = [
- "packet_buffer.go",
- "packet_buffer_state.go",
"tcpip.go",
"time_unsafe.go",
"timer.go",
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
index a984f1712..e57d45f2a 100644
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -22,6 +22,7 @@ go_test(
size = "small",
srcs = ["gonet_test.go"],
library = ":gonet",
+ tags = ["flaky"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index c6c160dfc..8dc0f7c0e 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -785,6 +785,52 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
}
}
+// ndpOptions checks that optsBuf only contains opts.
+func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) {
+ t.Helper()
+
+ it, err := optsBuf.Iter(true)
+ if err != nil {
+ t.Errorf("optsBuf.Iter(true): %s", err)
+ return
+ }
+
+ i := 0
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ // This should never happen as Iter(true) above did not return an error.
+ t.Fatalf("unexpected error when iterating over NDP options: %s", err)
+ }
+ if done {
+ break
+ }
+
+ if i >= len(opts) {
+ t.Errorf("got unexpected option: %s", opt)
+ continue
+ }
+
+ switch wantOpt := opts[i].(type) {
+ case header.NDPSourceLinkLayerAddressOption:
+ gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption)
+ if !ok {
+ t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
+ } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
+ t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
+ }
+ default:
+ t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt)
+ }
+
+ i++
+ }
+
+ if missing := opts[i:]; len(missing) > 0 {
+ t.Errorf("missing options: %s", missing)
+ }
+}
+
// NDPNSOptions creates a checker that checks that the packet contains the
// provided NDP options within an NDP Neighbor Solicitation message.
//
@@ -796,47 +842,31 @@ func NDPNSOptions(opts []header.NDPOption) TransportChecker {
icmp := h.(header.ICMPv6)
ns := header.NDPNeighborSolicit(icmp.NDPPayload())
- it, err := ns.Options().Iter(true)
- if err != nil {
- t.Errorf("opts.Iter(true): %s", err)
- return
- }
-
- i := 0
- for {
- opt, done, _ := it.Next()
- if done {
- break
- }
-
- if i >= len(opts) {
- t.Errorf("got unexpected option: %s", opt)
- continue
- }
-
- switch wantOpt := opts[i].(type) {
- case header.NDPSourceLinkLayerAddressOption:
- gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption)
- if !ok {
- t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
- } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
- t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
- }
- default:
- panic("not implemented")
- }
-
- i++
- }
-
- if missing := opts[i:]; len(missing) > 0 {
- t.Errorf("missing options: %s", missing)
- }
+ ndpOptions(t, ns.Options(), opts)
}
}
// NDPRS creates a checker that checks that the packet contains a valid NDP
// Router Solicitation message (as per the raw wire format).
-func NDPRS() NetworkChecker {
- return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize)
+//
+// checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDPRS as far as the size of the message is concerned. The values within the
+// message are up to checkers to validate.
+func NDPRS(checkers ...TransportChecker) NetworkChecker {
+ return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...)
+}
+
+// NDPRSOptions creates a checker that checks that the packet contains the
+// provided NDP options within an NDP Router Solicitation message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPRS message as far as the size is concerned.
+func NDPRSOptions(opts []header.NDPOption) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ rs := header.NDPRouterSolicit(icmp.NDPPayload())
+ ndpOptions(t, rs.Options(), opts)
+ }
}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index e5360e7c1..76839eb92 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -38,7 +38,8 @@ const (
// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
// fields of a packet that needs to be encoded.
type IPv4Fields struct {
- // IHL is the "internet header length" field of an IPv4 packet.
+ // IHL is the "internet header length" field of an IPv4 packet. The value
+ // is in bytes.
IHL uint8
// TOS is the "type of service" field of an IPv4 packet.
@@ -138,7 +139,7 @@ func IPVersion(b []byte) int {
}
// HeaderLength returns the value of the "header length" field of the ipv4
-// header.
+// header. The length returned is in bytes.
func (b IPv4) HeaderLength() uint8 {
return (b[versIHL] & 0xf) * 4
}
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
index 82cfe785c..13480687d 100644
--- a/pkg/tcpip/header/tcp.go
+++ b/pkg/tcpip/header/tcp.go
@@ -81,7 +81,8 @@ type TCPFields struct {
// AckNum is the "acknowledgement number" field of a TCP packet.
AckNum uint32
- // DataOffset is the "data offset" field of a TCP packet.
+ // DataOffset is the "data offset" field of a TCP packet. It is the length of
+ // the TCP header in bytes.
DataOffset uint8
// Flags is the "flags" field of a TCP packet.
@@ -213,7 +214,8 @@ func (b TCP) AckNumber() uint32 {
return binary.BigEndian.Uint32(b[TCPAckNumOffset:])
}
-// DataOffset returns the "data offset" field of the tcp header.
+// DataOffset returns the "data offset" field of the tcp header. The return
+// value is the length of the TCP header in bytes.
func (b TCP) DataOffset() uint8 {
return (b[TCPDataOffset] >> 4) * 4
}
@@ -238,6 +240,11 @@ func (b TCP) Checksum() uint16 {
return binary.BigEndian.Uint16(b[TCPChecksumOffset:])
}
+// UrgentPointer returns the "urgent pointer" field of the tcp header.
+func (b TCP) UrgentPointer() uint16 {
+ return binary.BigEndian.Uint16(b[TCPUrgentPtrOffset:])
+}
+
// SetSourcePort sets the "source port" field of the tcp header.
func (b TCP) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], port)
@@ -253,6 +260,37 @@ func (b TCP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[TCPChecksumOffset:], checksum)
}
+// SetDataOffset sets the data offset field of the tcp header. headerLen should
+// be the length of the TCP header in bytes.
+func (b TCP) SetDataOffset(headerLen uint8) {
+ b[TCPDataOffset] = (headerLen / 4) << 4
+}
+
+// SetSequenceNumber sets the sequence number field of the tcp header.
+func (b TCP) SetSequenceNumber(seqNum uint32) {
+ binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seqNum)
+}
+
+// SetAckNumber sets the ack number field of the tcp header.
+func (b TCP) SetAckNumber(ackNum uint32) {
+ binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ackNum)
+}
+
+// SetFlags sets the flags field of the tcp header.
+func (b TCP) SetFlags(flags uint8) {
+ b[TCPFlagsOffset] = flags
+}
+
+// SetWindowSize sets the window size field of the tcp header.
+func (b TCP) SetWindowSize(rcvwnd uint16) {
+ binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd)
+}
+
+// SetUrgentPoiner sets the window size field of the tcp header.
+func (b TCP) SetUrgentPoiner(urgentPointer uint16) {
+ binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], urgentPointer)
+}
+
// CalculateChecksum calculates the checksum of the tcp segment.
// partialChecksum is the checksum of the network-layer pseudo-header
// and the checksum of the segment data.
diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD
deleted file mode 100644
index d1b73cfdf..000000000
--- a/pkg/tcpip/iptables/BUILD
+++ /dev/null
@@ -1,18 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "iptables",
- srcs = [
- "iptables.go",
- "targets.go",
- "types.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/log",
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- ],
-)
diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go
deleted file mode 100644
index 81a2e39a2..000000000
--- a/pkg/tcpip/iptables/targets.go
+++ /dev/null
@@ -1,65 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package iptables
-
-import (
- "gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/tcpip"
-)
-
-// AcceptTarget accepts packets.
-type AcceptTarget struct{}
-
-// Action implements Target.Action.
-func (AcceptTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) {
- return RuleAccept, 0
-}
-
-// DropTarget drops packets.
-type DropTarget struct{}
-
-// Action implements Target.Action.
-func (DropTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) {
- return RuleDrop, 0
-}
-
-// ErrorTarget logs an error and drops the packet. It represents a target that
-// should be unreachable.
-type ErrorTarget struct{}
-
-// Action implements Target.Action.
-func (ErrorTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) {
- log.Debugf("ErrorTarget triggered.")
- return RuleDrop, 0
-}
-
-// UserChainTarget marks a rule as the beginning of a user chain.
-type UserChainTarget struct {
- Name string
-}
-
-// Action implements Target.Action.
-func (UserChainTarget) Action(tcpip.PacketBuffer) (RuleVerdict, int) {
- panic("UserChainTarget should never be called.")
-}
-
-// ReturnTarget returns from the current chain. If the chain is a built-in, the
-// hook's underflow should be called.
-type ReturnTarget struct{}
-
-// Action implements Target.Action.
-func (ReturnTarget) Action(tcpip.PacketBuffer) (RuleVerdict, int) {
- return RuleReturn, 0
-}
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 5944ba190..a8d6653ce 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -28,7 +28,7 @@ import (
// PacketInfo holds all the information about an outbound packet.
type PacketInfo struct {
- Pkt tcpip.PacketBuffer
+ Pkt stack.PacketBuffer
Proto tcpip.NetworkProtocolNumber
GSO *stack.GSO
Route stack.Route
@@ -203,12 +203,12 @@ func (e *Endpoint) NumQueued() int {
}
// InjectInbound injects an inbound packet.
-func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
e.InjectLinkAddr(protocol, "", pkt)
}
// InjectLinkAddr injects an inbound packet with a remote link address.
-func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt tcpip.PacketBuffer) {
+func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt stack.PacketBuffer) {
e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt)
}
@@ -251,7 +251,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
}
// WritePacket stores outbound packets into the channel.
-func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
// Clone r then release its resource so we only get the relevant fields from
// stack.Route without holding a reference to a NIC's endpoint.
route := r.Clone()
@@ -269,7 +269,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
}
// WritePackets stores outbound packets into the channel.
-func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
// Clone r then release its resource so we only get the relevant fields from
// stack.Route without holding a reference to a NIC's endpoint.
route := r.Clone()
@@ -280,7 +280,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac
off := pkt.DataOffset
size := pkt.DataSize
p := PacketInfo{
- Pkt: tcpip.PacketBuffer{
+ Pkt: stack.PacketBuffer{
Header: pkt.Header,
Data: buffer.NewViewFromBytes(payloadView[off : off+size]).ToVectorisedView(),
},
@@ -301,7 +301,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
p := PacketInfo{
- Pkt: tcpip.PacketBuffer{Data: vv},
+ Pkt: stack.PacketBuffer{Data: vv},
Proto: 0,
GSO: nil,
}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index b7f60178e..3b3b6909b 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -386,7 +386,7 @@ const (
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
if e.hdrSize > 0 {
// Add ethernet header if needed.
eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
@@ -405,9 +405,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
eth.Encode(ethHdr)
}
+ fd := e.fds[pkt.Hash%uint32(len(e.fds))]
if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
vnetHdr := virtioNetHdr{}
- vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr)
if gso != nil {
vnetHdr.hdrLen = uint16(pkt.Header.UsedLength())
if gso.NeedsCsum {
@@ -428,19 +428,20 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
}
}
- return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView())
+ vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr)
+ return rawfile.NonBlockingWrite3(fd, vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView())
}
if pkt.Data.Size() == 0 {
- return rawfile.NonBlockingWrite(e.fds[0], pkt.Header.View())
+ return rawfile.NonBlockingWrite(fd, pkt.Header.View())
}
- return rawfile.NonBlockingWrite3(e.fds[0], pkt.Header.View(), pkt.Data.ToView(), nil)
+ return rawfile.NonBlockingWrite3(fd, pkt.Header.View(), pkt.Data.ToView(), nil)
}
// WritePackets writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
var ethHdrBuf []byte
// hdr + data
iovLen := 2
@@ -467,7 +468,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac
views := pkts[0].Data.Views()
/*
- * Each bondary in views can add one more iovec.
+ * Each boundary in views can add one more iovec.
*
* payload | | | |
* -----------------------------
@@ -551,7 +552,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac
packets := 0
for packets < n {
- sent, err := rawfile.NonBlockingSendMMsg(e.fds[0], mmsgHdrs)
+ fd := e.fds[pkts[packets].Hash%uint32(len(e.fds))]
+ sent, err := rawfile.NonBlockingSendMMsg(fd, mmsgHdrs)
if err != nil {
return packets, err
}
@@ -610,7 +612,7 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
}
// InjectInbound injects an inbound packet.
-func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, pkt)
}
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 2066987eb..3bfb15a8e 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -45,40 +45,46 @@ const (
type packetInfo struct {
raddr tcpip.LinkAddress
proto tcpip.NetworkProtocolNumber
- contents tcpip.PacketBuffer
+ contents stack.PacketBuffer
}
type context struct {
- t *testing.T
- fds [2]int
- ep stack.LinkEndpoint
- ch chan packetInfo
- done chan struct{}
+ t *testing.T
+ readFDs []int
+ writeFDs []int
+ ep stack.LinkEndpoint
+ ch chan packetInfo
+ done chan struct{}
}
func newContext(t *testing.T, opt *Options) *context {
- fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
+ firstFDPair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
+ if err != nil {
+ t.Fatalf("Socketpair failed: %v", err)
+ }
+ secondFDPair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
if err != nil {
t.Fatalf("Socketpair failed: %v", err)
}
- done := make(chan struct{}, 1)
+ done := make(chan struct{}, 2)
opt.ClosedFunc = func(*tcpip.Error) {
done <- struct{}{}
}
- opt.FDs = []int{fds[1]}
+ opt.FDs = []int{firstFDPair[1], secondFDPair[1]}
ep, err := New(opt)
if err != nil {
t.Fatalf("Failed to create FD endpoint: %v", err)
}
c := &context{
- t: t,
- fds: fds,
- ep: ep,
- ch: make(chan packetInfo, 100),
- done: done,
+ t: t,
+ readFDs: []int{firstFDPair[0], secondFDPair[0]},
+ writeFDs: opt.FDs,
+ ep: ep,
+ ch: make(chan packetInfo, 100),
+ done: done,
}
ep.Attach(c)
@@ -87,12 +93,17 @@ func newContext(t *testing.T, opt *Options) *context {
}
func (c *context) cleanup() {
- syscall.Close(c.fds[0])
+ for _, fd := range c.readFDs {
+ syscall.Close(fd)
+ }
+ <-c.done
<-c.done
- syscall.Close(c.fds[1])
+ for _, fd := range c.writeFDs {
+ syscall.Close(fd)
+ }
}
-func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
c.ch <- packetInfo{remote, protocol, pkt}
}
@@ -136,7 +147,7 @@ func TestAddress(t *testing.T) {
}
}
-func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) {
+func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash uint32) {
c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize})
defer c.cleanup()
@@ -168,16 +179,18 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) {
L3HdrLen: header.IPv4MaximumHeaderSize,
}
}
- if err := c.ep.WritePacket(r, gso, proto, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(r, gso, proto, stack.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
+ Hash: hash,
}); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
- // Read from fd, then compare with what we wrote.
+ // Read from the corresponding FD, then compare with what we wrote.
b = make([]byte, mtu)
- n, err := syscall.Read(c.fds[0], b)
+ fd := c.readFDs[hash%uint32(len(c.readFDs))]
+ n, err := syscall.Read(fd, b)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
@@ -238,7 +251,7 @@ func TestWritePacket(t *testing.T) {
t.Run(
fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v", eth, plen, gso),
func(t *testing.T) {
- testWritePacket(t, plen, eth, gso)
+ testWritePacket(t, plen, eth, gso, 0)
},
)
}
@@ -246,6 +259,27 @@ func TestWritePacket(t *testing.T) {
}
}
+func TestHashedWritePacket(t *testing.T) {
+ lengths := []int{0, 100, 1000}
+ eths := []bool{true, false}
+ gsos := []uint32{0, 32768}
+ hashes := []uint32{0, 1}
+ for _, eth := range eths {
+ for _, plen := range lengths {
+ for _, gso := range gsos {
+ for _, hash := range hashes {
+ t.Run(
+ fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v,Hash=%d", eth, plen, gso, hash),
+ func(t *testing.T) {
+ testWritePacket(t, plen, eth, gso, hash)
+ },
+ )
+ }
+ }
+ }
+ }
+}
+
func TestPreserveSrcAddress(t *testing.T) {
baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99")
@@ -261,7 +295,7 @@ func TestPreserveSrcAddress(t *testing.T) {
// WritePacket panics given a prependable with anything less than
// the minimum size of the ethernet header.
hdr := buffer.NewPrependable(header.EthernetMinimumSize)
- if err := c.ep.WritePacket(r, nil /* gso */, proto, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(r, nil /* gso */, proto, stack.PacketBuffer{
Header: hdr,
Data: buffer.VectorisedView{},
}); err != nil {
@@ -270,7 +304,7 @@ func TestPreserveSrcAddress(t *testing.T) {
// Read from the FD, then compare with what we wrote.
b := make([]byte, mtu)
- n, err := syscall.Read(c.fds[0], b)
+ n, err := syscall.Read(c.readFDs[0], b)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
@@ -314,7 +348,7 @@ func TestDeliverPacket(t *testing.T) {
}
// Write packet via the file descriptor.
- if _, err := syscall.Write(c.fds[0], all); err != nil {
+ if _, err := syscall.Write(c.readFDs[0], all); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -324,7 +358,7 @@ func TestDeliverPacket(t *testing.T) {
want := packetInfo{
raddr: raddr,
proto: proto,
- contents: tcpip.PacketBuffer{
+ contents: stack.PacketBuffer{
Data: buffer.View(b).ToVectorisedView(),
LinkHeader: buffer.View(hdr),
},
diff --git a/pkg/tcpip/link/fdbased/endpoint_unsafe.go b/pkg/tcpip/link/fdbased/endpoint_unsafe.go
index 97a477b61..d81858353 100644
--- a/pkg/tcpip/link/fdbased/endpoint_unsafe.go
+++ b/pkg/tcpip/link/fdbased/endpoint_unsafe.go
@@ -24,9 +24,10 @@ import (
const virtioNetHdrSize = int(unsafe.Sizeof(virtioNetHdr{}))
func vnetHdrToByteSlice(hdr *virtioNetHdr) (slice []byte) {
- sh := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
- sh.Data = uintptr(unsafe.Pointer(hdr))
- sh.Len = virtioNetHdrSize
- sh.Cap = virtioNetHdrSize
+ *(*reflect.SliceHeader)(unsafe.Pointer(&slice)) = reflect.SliceHeader{
+ Data: uintptr((unsafe.Pointer(hdr))),
+ Len: virtioNetHdrSize,
+ Cap: virtioNetHdrSize,
+ }
return
}
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index 62ed1e569..fe2bf3b0b 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
@@ -190,7 +191,7 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
}
pkt = pkt[d.e.hdrSize:]
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, tcpip.PacketBuffer{
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, stack.PacketBuffer{
Data: buffer.View(pkt).ToVectorisedView(),
LinkHeader: buffer.View(eth),
})
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index c67d684ce..cb4cbea69 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -139,7 +139,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
}
used := d.capViews(n, BufConfig)
- pkt := tcpip.PacketBuffer{
+ pkt := stack.PacketBuffer{
Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)),
LinkHeader: buffer.View(eth),
}
@@ -296,7 +296,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
}
used := d.capViews(k, int(n), BufConfig)
- pkt := tcpip.PacketBuffer{
+ pkt := stack.PacketBuffer{
Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)),
LinkHeader: buffer.View(eth),
}
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 499cc608f..4039753b7 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -76,7 +76,7 @@ func (*endpoint) Wait() {}
// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
// packets to the network-layer dispatcher.
-func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
views[0] = pkt.Header.View()
views = append(views, pkt.Data.Views()...)
@@ -84,7 +84,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw
// Because we're immediately turning around and writing the packet back
// to the rx path, we intentionally don't preserve the remote and local
// link addresses from the stack.Route we're passed.
- e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, tcpip.PacketBuffer{
+ e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
@@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketBuffer, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
panic("not implemented")
}
@@ -106,7 +106,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
// There should be an ethernet header at the beginning of vv.
linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize])
vv.TrimFront(len(linkHeader))
- e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), tcpip.PacketBuffer{
+ e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{
Data: vv,
LinkHeader: buffer.View(linkHeader),
})
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 445b22c17..f5973066d 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -80,14 +80,14 @@ func (m *InjectableEndpoint) IsAttached() bool {
}
// InjectInbound implements stack.InjectableLinkEndpoint.
-func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, pkt)
}
// WritePackets writes outbound packets to the appropriate
// LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if
// r.RemoteAddress has a route registered in this endpoint.
-func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
endpoint, ok := m.routes[r.RemoteAddress]
if !ok {
return 0, tcpip.ErrNoRoute
@@ -98,7 +98,7 @@ func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts [
// WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint
// based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a
// route registered in this endpoint.
-func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
if endpoint, ok := m.routes[r.RemoteAddress]; ok {
return endpoint.WritePacket(r, gso, protocol, pkt)
}
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
index 63b249837..87c734c1f 100644
--- a/pkg/tcpip/link/muxed/injectable_test.go
+++ b/pkg/tcpip/link/muxed/injectable_test.go
@@ -50,7 +50,7 @@ func TestInjectableEndpointDispatch(t *testing.T) {
hdr.Prepend(1)[0] = 0xFA
packetRoute := stack.Route{RemoteAddress: dstIP}
- endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(),
})
@@ -70,7 +70,7 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) {
hdr := buffer.NewPrependable(1)
hdr.Prepend(1)[0] = 0xFA
packetRoute := stack.Route{RemoteAddress: dstIP}
- endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buffer.NewView(0).ToVectorisedView(),
})
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 655e537c4..6461d0108 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -185,7 +185,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
// Add the ethernet header here.
eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
pkt.LinkHeader = buffer.View(eth)
@@ -214,7 +214,7 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.Netw
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
panic("not implemented")
}
@@ -275,7 +275,7 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
// Send packet up the stack.
eth := header.Ethernet(b[:header.EthernetMinimumSize])
- d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), tcpip.PacketBuffer{
+ d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), stack.PacketBuffer{
Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(),
LinkHeader: buffer.View(eth),
})
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 5c729a439..27ea3f531 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -131,7 +131,7 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress
return c
}
-func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
c.mu.Lock()
c.packets = append(c.packets, packetInfo{
addr: remoteLinkAddr,
@@ -273,7 +273,7 @@ func TestSimpleSend(t *testing.T) {
randomFill(buf)
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -345,7 +345,7 @@ func TestPreserveSrcAddressInSend(t *testing.T) {
hdr := buffer.NewPrependable(header.EthernetMinimumSize)
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{
Header: hdr,
}); err != nil {
t.Fatalf("WritePacket failed: %v", err)
@@ -401,7 +401,7 @@ func TestFillTxQueue(t *testing.T) {
for i := queuePipeSize / 40; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -419,7 +419,7 @@ func TestFillTxQueue(t *testing.T) {
// Next attempt to write must fail.
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != want {
@@ -447,7 +447,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
// Send two packets so that the id slice has at least two slots.
for i := 2; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -470,7 +470,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
ids := make(map[uint64]struct{})
for i := queuePipeSize / 40; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -488,7 +488,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
// Next attempt to write must fail.
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != want {
@@ -514,7 +514,7 @@ func TestFillTxMemory(t *testing.T) {
ids := make(map[uint64]struct{})
for i := queueDataSize / bufferSize; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -533,7 +533,7 @@ func TestFillTxMemory(t *testing.T) {
// Next attempt to write must fail.
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
})
@@ -561,7 +561,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
// until there is only one buffer left.
for i := queueDataSize/bufferSize - 1; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -577,7 +577,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
{
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
uu := buffer.NewView(bufferSize).ToVectorisedView()
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: uu,
}); err != want {
@@ -588,7 +588,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
// Attempt to write the one-buffer packet again. It must succeed.
{
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 3392b7edd..0a6b8945c 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -123,7 +123,7 @@ func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack
// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is
// called by the link-layer endpoint being wrapped when a packet arrives, and
// logs the packet before forwarding to the actual dispatcher.
-func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
logPacket("recv", protocol, pkt.Data.First(), nil)
}
@@ -200,7 +200,7 @@ func (e *endpoint) GSOMaxSize() uint32 {
return 0
}
-func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
logPacket("send", protocol, pkt.Header.View(), gso)
}
@@ -232,7 +232,7 @@ func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumb
// WritePacket implements the stack.LinkEndpoint interface. It is called by
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
e.dumpPacket(gso, protocol, pkt)
return e.lower.WritePacket(r, gso, protocol, pkt)
}
@@ -240,10 +240,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
// WritePackets implements the stack.LinkEndpoint interface. It is called by
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
view := pkts[0].Data.ToView()
for _, pkt := range pkts {
- e.dumpPacket(gso, protocol, tcpip.PacketBuffer{
+ e.dumpPacket(gso, protocol, stack.PacketBuffer{
Header: pkt.Header,
Data: view[pkt.DataOffset:][:pkt.DataSize].ToVectorisedView(),
})
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 6ff47a742..617446ea2 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -98,7 +98,12 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
prefix = "tap"
}
- endpoint, err := attachOrCreateNIC(s, name, prefix)
+ linkCaps := stack.CapabilityNone
+ if isTap {
+ linkCaps |= stack.CapabilityResolutionRequired
+ }
+
+ endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps)
if err != nil {
return syserror.EINVAL
}
@@ -109,7 +114,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
return nil
}
-func attachOrCreateNIC(s *stack.Stack, name, prefix string) (*tunEndpoint, error) {
+func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) {
for {
// 1. Try to attach to an existing NIC.
if name != "" {
@@ -135,6 +140,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string) (*tunEndpoint, error
nicID: id,
name: name,
}
+ endpoint.Endpoint.LinkEPCapabilities = linkCaps
if endpoint.name == "" {
endpoint.name = fmt.Sprintf("%s%d", prefix, id)
}
@@ -207,7 +213,7 @@ func (d *Device) Write(data []byte) (int64, error) {
remote = tcpip.LinkAddress(zeroMAC[:])
}
- pkt := tcpip.PacketBuffer{
+ pkt := stack.PacketBuffer{
Data: buffer.View(data).ToVectorisedView(),
}
if ethHdr != nil {
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index a8de38979..52fe397bf 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -50,7 +50,7 @@ func New(lower stack.LinkEndpoint) *Endpoint {
// It is called by the link-layer endpoint being wrapped when a packet arrives,
// and only forwards to the actual dispatcher if Wait or WaitDispatch haven't
// been called.
-func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
if !e.dispatchGate.Enter() {
return
}
@@ -99,7 +99,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
// WritePacket implements stack.LinkEndpoint.WritePacket. It is called by
// higher-level protocols to write packets. It only forwards packets to the
// lower endpoint if Wait or WaitWrite haven't been called.
-func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
if !e.writeGate.Enter() {
return nil
}
@@ -112,7 +112,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by
// higher-level protocols to write packets. It only forwards packets to the
// lower endpoint if Wait or WaitWrite haven't been called.
-func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
if !e.writeGate.Enter() {
return len(pkts), nil
}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 31b11a27a..88224e494 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -35,7 +35,7 @@ type countedEndpoint struct {
dispatcher stack.NetworkDispatcher
}
-func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
e.dispatchCount++
}
@@ -65,13 +65,13 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
e.writeCount++
return nil
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
e.writeCount += len(pkts)
return len(pkts), nil
}
@@ -89,21 +89,21 @@ func TestWaitWrite(t *testing.T) {
wep := New(ep)
// Write and check that it goes through.
- wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{})
+ wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{})
if want := 1; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
// Wait on dispatches, then try to write. It must go through.
wep.WaitDispatch()
- wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{})
+ wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{})
if want := 2; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
// Wait on writes, then try to write. It must not go through.
wep.WaitWrite()
- wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{})
+ wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{})
if want := 2; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
@@ -120,21 +120,21 @@ func TestWaitDispatch(t *testing.T) {
}
// Dispatch and check that it goes through.
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{})
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{})
if want := 1; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on writes, then try to dispatch. It must go through.
wep.WaitWrite()
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{})
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{})
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on dispatches, then try to dispatch. It must not go through.
wep.WaitDispatch()
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{})
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{})
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index e9fcc89a8..255098372 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -79,20 +79,20 @@ func (e *endpoint) MaxHeaderLength() uint16 {
func (e *endpoint) Close() {}
-func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) {
return 0, tcpip.ErrNotSupported
}
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
-func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
v := pkt.Data.First()
h := header.ARP(v)
if !h.IsValid() {
@@ -113,7 +113,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget())
copy(packet.HardwareAddressTarget(), h.HardwareAddressSender())
copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
- e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{
+ e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{
Header: hdr,
})
fallthrough // also fill the cache from requests
@@ -167,7 +167,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
copy(h.ProtocolAddressSender(), localAddr)
copy(h.ProtocolAddressTarget(), addr)
- return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{
Header: hdr,
})
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 03cf03b6d..b3e239ac7 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -103,7 +103,7 @@ func TestDirectRequest(t *testing.T) {
inject := func(addr tcpip.Address) {
copy(h.ProtocolAddressTarget(), addr)
- c.linkEP.InjectInbound(arp.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEP.InjectInbound(arp.ProtocolNumber, stack.PacketBuffer{
Data: v.ToVectorisedView(),
})
}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index 92f2aa13a..f42abc4bb 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -115,10 +115,12 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf
// Evict reassemblers if we are consuming more memory than highLimit until
// we reach lowLimit.
if f.size > f.highLimit {
- tail := f.rList.Back()
- for f.size > f.lowLimit && tail != nil {
+ for f.size > f.lowLimit {
+ tail := f.rList.Back()
+ if tail == nil {
+ break
+ }
f.release(tail)
- tail = tail.Prev()
}
}
f.mu.Unlock()
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index f4d78f8c6..4950d69fc 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -96,7 +96,7 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff
// DeliverTransportPacket is called by network endpoints after parsing incoming
// packets. This is used by the test object to verify that the results of the
// parsing are expected.
-func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) {
+func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt stack.PacketBuffer) {
t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress)
t.dataCalls++
}
@@ -104,7 +104,7 @@ func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.Trans
// DeliverTransportControlPacket is called by network endpoints after parsing
// incoming control (ICMP) packets. This is used by the test object to verify
// that the results of the parsing are expected.
-func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
t.checkValues(trans, pkt.Data, remote, local)
if typ != t.typ {
t.t.Errorf("typ = %v, want %v", typ, t.typ)
@@ -150,7 +150,7 @@ func (*testObject) Wait() {}
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
-func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
var prot tcpip.TransportProtocolNumber
var srcAddr tcpip.Address
var dstAddr tcpip.Address
@@ -172,7 +172,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
panic("not implemented")
}
@@ -246,7 +246,7 @@ func TestIPv4Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
}); err != nil {
@@ -289,7 +289,7 @@ func TestIPv4Receive(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- ep.HandlePacket(&r, tcpip.PacketBuffer{
+ ep.HandlePacket(&r, stack.PacketBuffer{
Data: view.ToVectorisedView(),
})
if o.dataCalls != 1 {
@@ -379,7 +379,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
o.extra = c.expectedExtra
vv := view[:len(view)-c.trunc].ToVectorisedView()
- ep.HandlePacket(&r, tcpip.PacketBuffer{
+ ep.HandlePacket(&r, stack.PacketBuffer{
Data: vv,
})
if want := c.expectedCount; o.controlCalls != want {
@@ -444,7 +444,7 @@ func TestIPv4FragmentationReceive(t *testing.T) {
}
// Send first segment.
- ep.HandlePacket(&r, tcpip.PacketBuffer{
+ ep.HandlePacket(&r, stack.PacketBuffer{
Data: frag1.ToVectorisedView(),
})
if o.dataCalls != 0 {
@@ -452,7 +452,7 @@ func TestIPv4FragmentationReceive(t *testing.T) {
}
// Send second segment.
- ep.HandlePacket(&r, tcpip.PacketBuffer{
+ ep.HandlePacket(&r, stack.PacketBuffer{
Data: frag2.ToVectorisedView(),
})
if o.dataCalls != 1 {
@@ -487,7 +487,7 @@ func TestIPv6Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
}); err != nil {
@@ -530,7 +530,7 @@ func TestIPv6Receive(t *testing.T) {
t.Fatalf("could not find route: %v", err)
}
- ep.HandlePacket(&r, tcpip.PacketBuffer{
+ ep.HandlePacket(&r, stack.PacketBuffer{
Data: view.ToVectorisedView(),
})
if o.dataCalls != 1 {
@@ -644,7 +644,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Set ICMPv6 checksum.
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{}))
- ep.HandlePacket(&r, tcpip.PacketBuffer{
+ ep.HandlePacket(&r, stack.PacketBuffer{
Data: view[:len(view)-c.trunc].ToVectorisedView(),
})
if want := c.expectedCount; o.controlCalls != want {
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 0fef2b1f1..880ea7de2 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -13,7 +13,6 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 32bf39e43..c4bf1ba5c 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -15,7 +15,6 @@
package ipv4
import (
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -25,7 +24,7 @@ import (
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
-func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
h := header.IPv4(pkt.Data.First())
// We don't use IsValid() here because ICMP only requires that the IP
@@ -53,7 +52,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) {
+func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) {
stats := r.Stats()
received := stats.ICMP.V4PacketsReceived
v := pkt.Data.First()
@@ -85,7 +84,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) {
// It's possible that a raw socket expects to receive this.
h.SetChecksum(wantChecksum)
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, tcpip.PacketBuffer{
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, stack.PacketBuffer{
Data: pkt.Data.Clone(nil),
NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...),
})
@@ -99,7 +98,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) {
pkt.SetChecksum(0)
pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0)))
sent := stats.ICMP.V4PacketsSent
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: vv,
TransportHeader: buffer.View(pkt),
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 4f1742938..b3ee6000e 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
"gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -125,7 +124,7 @@ func (e *endpoint) GSOMaxSize() uint32 {
// packet's stated length matches the length of the header+payload. mtu
// includes the IP header and options. This does not support the DontFragment
// IP flag.
-func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt stack.PacketBuffer) *tcpip.Error {
// This packet is too big, it needs to be fragmented.
ip := header.IPv4(pkt.Header.View())
flags := ip.Flags()
@@ -165,7 +164,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int,
if i > 0 {
newPayload := pkt.Data.Clone(nil)
newPayload.CapLength(innerMTU)
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{
Header: pkt.Header,
Data: newPayload,
NetworkHeader: buffer.View(h),
@@ -184,7 +183,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int,
newPayload := pkt.Data.Clone(nil)
newPayloadLength := outerMTU - pkt.Header.UsedLength()
newPayload.CapLength(newPayloadLength)
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{
Header: pkt.Header,
Data: newPayload,
NetworkHeader: buffer.View(h),
@@ -198,7 +197,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int,
startOfHdr := pkt.Header
startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU)
emptyVV := buffer.NewVectorisedView(0, []buffer.View{})
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{
Header: startOfHdr,
Data: emptyVV,
NetworkHeader: buffer.View(h),
@@ -241,7 +240,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error {
ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
pkt.NetworkHeader = buffer.View(ip)
@@ -253,7 +252,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
views = append(views, pkt.Data.Views()...)
loopedR := r.MakeLoopedRoute()
- e.HandlePacket(&loopedR, tcpip.PacketBuffer{
+ e.HandlePacket(&loopedR, stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
@@ -273,7 +272,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
panic("multiple packets in local loop")
}
@@ -292,7 +291,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
ip := header.IPv4(pkt.Data.First())
@@ -344,7 +343,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuf
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
headerView := pkt.Data.First()
h := header.IPv4(headerView)
if !h.IsValid(pkt.Data.Size()) {
@@ -361,7 +360,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
ipt := e.stack.IPTables()
- if ok := ipt.Check(iptables.Input, pkt); !ok {
+ if ok := ipt.Check(stack.Input, pkt); !ok {
// iptables is telling us to drop the packet.
return
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index e900f1b45..5a864d832 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -113,7 +113,7 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.
// comparePayloads compared the contents of all the packets against the contents
// of the source packet.
-func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketInfo tcpip.PacketBuffer, mtu uint32) {
+func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketInfo stack.PacketBuffer, mtu uint32) {
t.Helper()
// Make a complete array of the sourcePacketInfo packet.
source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize])
@@ -173,7 +173,7 @@ func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketIn
type errorChannel struct {
*channel.Endpoint
- Ch chan tcpip.PacketBuffer
+ Ch chan stack.PacketBuffer
packetCollectorErrors []*tcpip.Error
}
@@ -183,7 +183,7 @@ type errorChannel struct {
func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
return &errorChannel{
Endpoint: channel.New(size, mtu, linkAddr),
- Ch: make(chan tcpip.PacketBuffer, size),
+ Ch: make(chan stack.PacketBuffer, size),
packetCollectorErrors: packetCollectorErrors,
}
}
@@ -202,7 +202,7 @@ func (e *errorChannel) Drain() int {
}
// WritePacket stores outbound packets into the channel.
-func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
select {
case e.Ch <- pkt:
default:
@@ -281,13 +281,13 @@ func TestFragmentation(t *testing.T) {
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
- source := tcpip.PacketBuffer{
+ source := stack.PacketBuffer{
Header: hdr,
// Save the source payload because WritePacket will modify it.
Data: payload.Clone(nil),
}
c := buildContext(t, nil, ft.mtu)
- err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: payload,
})
@@ -295,7 +295,7 @@ func TestFragmentation(t *testing.T) {
t.Errorf("err got %v, want %v", err, nil)
}
- var results []tcpip.PacketBuffer
+ var results []stack.PacketBuffer
L:
for {
select {
@@ -337,7 +337,7 @@ func TestFragmentationErrors(t *testing.T) {
t.Run(ft.description, func(t *testing.T) {
hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
- err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: payload,
})
@@ -459,7 +459,7 @@ func TestInvalidFragments(t *testing.T) {
s.CreateNIC(nicID, sniffer.New(ep))
for _, pkt := range tc.packets {
- ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, tcpip.PacketBuffer{
+ ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}),
})
}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 45dc757c7..8640feffc 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -27,7 +27,7 @@ import (
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
-func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
h := header.IPv6(pkt.Data.First())
// We don't use IsValid() here because ICMP only requires that up to
@@ -62,7 +62,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.PacketBuffer) {
+func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer) {
stats := r.Stats().ICMP
sent := stats.V6PacketsSent
received := stats.V6PacketsReceived
@@ -243,7 +243,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
//
// The IP Hop Limit field has a value of 255, i.e., the packet
// could not possibly have been forwarded by a router.
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
}); err != nil {
sent.Dropped.Increment()
@@ -330,7 +330,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
copy(packet, h)
packet.SetType(header.ICMPv6EchoReply)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: pkt.Data,
}); err != nil {
@@ -463,7 +463,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
})
// TODO(stijlist): count this in ICMP stats.
- return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{
Header: hdr,
})
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 50c4b6474..bae09ed94 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -56,7 +56,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
-func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, tcpip.PacketBuffer) *tcpip.Error {
+func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, stack.PacketBuffer) *tcpip.Error {
return nil
}
@@ -66,7 +66,7 @@ type stubDispatcher struct {
stack.TransportDispatcher
}
-func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, tcpip.PacketBuffer) {
+func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, stack.PacketBuffer) {
}
type stubLinkAddressCache struct {
@@ -187,7 +187,7 @@ func TestICMPCounts(t *testing.T) {
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(&r, tcpip.PacketBuffer{
+ ep.HandlePacket(&r, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
}
@@ -326,7 +326,7 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.
views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()}
size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size()
vv := buffer.NewVectorisedView(size, views)
- args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{
+ args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), stack.PacketBuffer{
Data: vv,
})
}
@@ -561,7 +561,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
}
@@ -738,7 +738,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
}
@@ -916,7 +916,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
})
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 9aef5234b..29e597002 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -112,7 +112,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error {
ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
pkt.NetworkHeader = buffer.View(ip)
@@ -124,7 +124,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
views = append(views, pkt.Data.Views()...)
loopedR := r.MakeLoopedRoute()
- e.HandlePacket(&loopedR, tcpip.PacketBuffer{
+ e.HandlePacket(&loopedR, stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
@@ -139,7 +139,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
panic("not implemented")
}
@@ -161,14 +161,14 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac
// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
// supported by IPv6.
-func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
// TODO(b/146666412): Support IPv6 header-included packets.
return tcpip.ErrNotSupported
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
headerView := pkt.Data.First()
h := header.IPv6(headerView)
if !h.IsValid(pkt.Data.Size()) {
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 1cbfa7278..ed98ef22a 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -55,7 +55,7 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
DstAddr: dst,
})
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -113,7 +113,7 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
DstAddr: dst,
})
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index c9395de52..f924ed9e1 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -135,7 +135,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -238,7 +238,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -304,7 +304,7 @@ func TestHopLimitValidation(t *testing.T) {
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(r, tcpip.PacketBuffer{
+ ep.HandlePacket(r, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
}
@@ -588,7 +588,7 @@ func TestRouterAdvertValidation(t *testing.T) {
t.Fatalf("got rxRA = %d, want = 0", got)
}
- e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 705cf01ee..8d80e9cee 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -18,11 +18,19 @@ go_template_instance(
go_library(
name = "stack",
srcs = [
+ "dhcpv6configurationfromndpra_string.go",
+ "forwarder.go",
"icmp_rate_limit.go",
+ "iptables.go",
+ "iptables_targets.go",
+ "iptables_types.go",
"linkaddrcache.go",
"linkaddrentry_list.go",
"ndp.go",
"nic.go",
+ "packet_buffer.go",
+ "packet_buffer_state.go",
+ "rand.go",
"registration.go",
"route.go",
"stack.go",
@@ -32,6 +40,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/ilist",
+ "//pkg/log",
"//pkg/rand",
"//pkg/sleep",
"//pkg/sync",
@@ -39,7 +48,6 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/waiter",
@@ -63,7 +71,6 @@ go_test(
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
@@ -79,6 +86,7 @@ go_test(
name = "stack_test",
size = "small",
srcs = [
+ "forwarder_test.go",
"linkaddrcache_test.go",
"nic_test.go",
],
diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go
new file mode 100644
index 000000000..8b4213eec
--- /dev/null
+++ b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go
@@ -0,0 +1,39 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type=DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT.
+
+package stack
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[DHCPv6NoConfiguration-0]
+ _ = x[DHCPv6ManagedAddress-1]
+ _ = x[DHCPv6OtherConfigurations-2]
+}
+
+const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAddressDHCPv6OtherConfigurations"
+
+var _DHCPv6ConfigurationFromNDPRA_index = [...]uint8{0, 21, 41, 66}
+
+func (i DHCPv6ConfigurationFromNDPRA) String() string {
+ if i < 0 || i >= DHCPv6ConfigurationFromNDPRA(len(_DHCPv6ConfigurationFromNDPRA_index)-1) {
+ return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _DHCPv6ConfigurationFromNDPRA_name[_DHCPv6ConfigurationFromNDPRA_index[i]:_DHCPv6ConfigurationFromNDPRA_index[i+1]]
+}
diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go
new file mode 100644
index 000000000..6b64cd37f
--- /dev/null
+++ b/pkg/tcpip/stack/forwarder.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // maxPendingResolutions is the maximum number of pending link-address
+ // resolutions.
+ maxPendingResolutions = 64
+ maxPendingPacketsPerResolution = 256
+)
+
+type pendingPacket struct {
+ nic *NIC
+ route *Route
+ proto tcpip.NetworkProtocolNumber
+ pkt PacketBuffer
+}
+
+type forwardQueue struct {
+ sync.Mutex
+
+ // The packets to send once the resolver completes.
+ packets map[<-chan struct{}][]*pendingPacket
+
+ // FIFO of channels used to cancel the oldest goroutine waiting for
+ // link-address resolution.
+ cancelChans []chan struct{}
+}
+
+func newForwardQueue() *forwardQueue {
+ return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)}
+}
+
+func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
+ shouldWait := false
+
+ f.Lock()
+ packets, ok := f.packets[ch]
+ if !ok {
+ shouldWait = true
+ }
+ for len(packets) == maxPendingPacketsPerResolution {
+ p := packets[0]
+ packets = packets[1:]
+ p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ p.route.Release()
+ }
+ if l := len(packets); l >= maxPendingPacketsPerResolution {
+ panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution))
+ }
+ f.packets[ch] = append(packets, &pendingPacket{
+ nic: n,
+ route: r,
+ proto: protocol,
+ pkt: pkt,
+ })
+ f.Unlock()
+
+ if !shouldWait {
+ return
+ }
+
+ // Wait for the link-address resolution to complete.
+ // Start a goroutine with a forwarding-cancel channel so that we can
+ // limit the maximum number of goroutines running concurrently.
+ cancel := f.newCancelChannel()
+ go func() {
+ cancelled := false
+ select {
+ case <-ch:
+ case <-cancel:
+ cancelled = true
+ }
+
+ f.Lock()
+ packets := f.packets[ch]
+ delete(f.packets, ch)
+ f.Unlock()
+
+ for _, p := range packets {
+ if cancelled {
+ p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ } else if _, err := p.route.Resolve(nil); err != nil {
+ p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ } else {
+ p.nic.forwardPacket(p.route, p.proto, p.pkt)
+ }
+ p.route.Release()
+ }
+ }()
+}
+
+// newCancelChannel creates a channel that can cancel a pending forwarding
+// activity. The oldest channel is closed if the number of open channels would
+// exceed maxPendingResolutions.
+func (f *forwardQueue) newCancelChannel() chan struct{} {
+ f.Lock()
+ defer f.Unlock()
+
+ if len(f.cancelChans) == maxPendingResolutions {
+ ch := f.cancelChans[0]
+ f.cancelChans = f.cancelChans[1:]
+ close(ch)
+ }
+ if l := len(f.cancelChans); l >= maxPendingResolutions {
+ panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions))
+ }
+
+ ch := make(chan struct{})
+ f.cancelChans = append(f.cancelChans, ch)
+ return ch
+}
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
new file mode 100644
index 000000000..c45c43d21
--- /dev/null
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -0,0 +1,635 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "encoding/binary"
+ "math"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+const (
+ fwdTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
+ fwdTestNetHeaderLen = 12
+ fwdTestNetDefaultPrefixLen = 8
+
+ // fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests,
+ // except where another value is explicitly used. It is chosen to match
+ // the MTU of loopback interfaces on linux systems.
+ fwdTestNetDefaultMTU = 65536
+)
+
+// fwdTestNetworkEndpoint is a network-layer protocol endpoint.
+// Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only
+// use the first three: destination address, source address, and transport
+// protocol. They're all one byte fields to simplify parsing.
+type fwdTestNetworkEndpoint struct {
+ nicID tcpip.NICID
+ id NetworkEndpointID
+ prefixLen int
+ proto *fwdTestNetworkProtocol
+ dispatcher TransportDispatcher
+ ep LinkEndpoint
+}
+
+func (f *fwdTestNetworkEndpoint) MTU() uint32 {
+ return f.ep.MTU() - uint32(f.MaxHeaderLength())
+}
+
+func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID {
+ return f.nicID
+}
+
+func (f *fwdTestNetworkEndpoint) PrefixLen() int {
+ return f.prefixLen
+}
+
+func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
+ return 123
+}
+
+func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID {
+ return &f.id
+}
+
+func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) {
+ // Consume the network header.
+ b := pkt.Data.First()
+ pkt.Data.TrimFront(fwdTestNetHeaderLen)
+
+ // Dispatch the packet to the transport protocol.
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt)
+}
+
+func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
+ return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen
+}
+
+func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
+ return 0
+}
+
+func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities {
+ return f.ep.Capabilities()
+}
+
+func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error {
+ // Add the protocol's header to the packet and send it to the link
+ // endpoint.
+ b := pkt.Header.Prepend(fwdTestNetHeaderLen)
+ b[0] = r.RemoteAddress[0]
+ b[1] = f.id.LocalAddress[0]
+ b[2] = byte(params.Protocol)
+
+ return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt)
+}
+
+// WritePackets implements LinkEndpoint.WritePackets.
+func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+func (*fwdTestNetworkEndpoint) Close() {}
+
+// fwdTestNetworkProtocol is a network-layer protocol that implements Address
+// resolution.
+type fwdTestNetworkProtocol struct {
+ addrCache *linkAddrCache
+ addrResolveDelay time.Duration
+ onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address)
+ onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
+}
+
+func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
+ return fwdTestNetNumber
+}
+
+func (f *fwdTestNetworkProtocol) MinimumPacketSize() int {
+ return fwdTestNetHeaderLen
+}
+
+func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int {
+ return fwdTestNetDefaultPrefixLen
+}
+
+func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
+}
+
+func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) {
+ return &fwdTestNetworkEndpoint{
+ nicID: nicID,
+ id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
+ prefixLen: addrWithPrefix.PrefixLen,
+ proto: f,
+ dispatcher: dispatcher,
+ ep: ep,
+ }, nil
+}
+
+func (f *fwdTestNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+func (f *fwdTestNetworkProtocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+func (f *fwdTestNetworkProtocol) Close() {}
+
+func (f *fwdTestNetworkProtocol) Wait() {}
+
+func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error {
+ if f.addrCache != nil && f.onLinkAddressResolved != nil {
+ time.AfterFunc(f.addrResolveDelay, func() {
+ f.onLinkAddressResolved(f.addrCache, addr)
+ })
+ }
+ return nil
+}
+
+func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if f.onResolveStaticAddress != nil {
+ return f.onResolveStaticAddress(addr)
+ }
+ return "", false
+}
+
+func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return fwdTestNetNumber
+}
+
+// fwdTestPacketInfo holds all the information about an outbound packet.
+type fwdTestPacketInfo struct {
+ RemoteLinkAddress tcpip.LinkAddress
+ LocalLinkAddress tcpip.LinkAddress
+ Pkt PacketBuffer
+}
+
+type fwdTestLinkEndpoint struct {
+ dispatcher NetworkDispatcher
+ mtu uint32
+ linkAddr tcpip.LinkAddress
+
+ // C is where outbound packets are queued.
+ C chan fwdTestPacketInfo
+}
+
+// InjectInbound injects an inbound packet.
+func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
+ e.InjectLinkAddr(protocol, "", pkt)
+}
+
+// InjectLinkAddr injects an inbound packet with a remote link address.
+func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt PacketBuffer) {
+ e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt)
+}
+
+// Attach saves the stack network-layer dispatcher for use later when packets
+// are injected.
+func (e *fwdTestLinkEndpoint) Attach(dispatcher NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *fwdTestLinkEndpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
+// during construction.
+func (e *fwdTestLinkEndpoint) MTU() uint32 {
+ return e.mtu
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities {
+ caps := LinkEndpointCapabilities(0)
+ return caps | CapabilityResolutionRequired
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 {
+ return 1 << 15
+}
+
+// MaxHeaderLength returns the maximum size of the link layer header. Given it
+// doesn't have a header, it just returns 0.
+func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
+ return e.linkAddr
+}
+
+func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error {
+ p := fwdTestPacketInfo{
+ RemoteLinkAddress: r.RemoteLinkAddress,
+ LocalLinkAddress: r.LocalLinkAddress,
+ Pkt: pkt,
+ }
+
+ select {
+ case e.C <- p:
+ default:
+ }
+
+ return nil
+}
+
+// WritePackets stores outbound packets into the channel.
+func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ n := 0
+ for _, pkt := range pkts {
+ e.WritePacket(r, gso, protocol, pkt)
+ n++
+ }
+
+ return n, nil
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ p := fwdTestPacketInfo{
+ Pkt: PacketBuffer{Data: vv},
+ }
+
+ select {
+ case e.C <- p:
+ default:
+ }
+
+ return nil
+}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*fwdTestLinkEndpoint) Wait() {}
+
+func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
+ // Create a stack with the network protocol and two NICs.
+ s := New(Options{
+ NetworkProtocols: []NetworkProtocol{proto},
+ })
+
+ proto.addrCache = s.linkAddrCache
+
+ // Enable forwarding.
+ s.SetForwarding(true)
+
+ // NIC 1 has the link address "a", and added the network address 1.
+ ep1 = &fwdTestLinkEndpoint{
+ C: make(chan fwdTestPacketInfo, 300),
+ mtu: fwdTestNetDefaultMTU,
+ linkAddr: "a",
+ }
+ if err := s.CreateNIC(1, ep1); err != nil {
+ t.Fatal("CreateNIC #1 failed:", err)
+ }
+ if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil {
+ t.Fatal("AddAddress #1 failed:", err)
+ }
+
+ // NIC 2 has the link address "b", and added the network address 2.
+ ep2 = &fwdTestLinkEndpoint{
+ C: make(chan fwdTestPacketInfo, 300),
+ mtu: fwdTestNetDefaultMTU,
+ linkAddr: "b",
+ }
+ if err := s.CreateNIC(2, ep2); err != nil {
+ t.Fatal("CreateNIC #2 failed:", err)
+ }
+ if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil {
+ t.Fatal("AddAddress #2 failed:", err)
+ }
+
+ // Route all packets to NIC 2.
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}})
+ }
+
+ return ep1, ep2
+}
+
+func TestForwardingWithStaticResolver(t *testing.T) {
+ // Create a network protocol with a static resolver.
+ proto := &fwdTestNetworkProtocol{
+ onResolveStaticAddress:
+ // The network address 3 is resolved to the link address "c".
+ func(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == "\x03" {
+ return "c", true
+ }
+ return "", false
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[0] = 3
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ default:
+ t.Fatal("packet not forwarded")
+ }
+
+ // Test that the static address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+}
+
+func TestForwardingWithFakeResolver(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ // Any address will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[0] = 3
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+}
+
+func TestForwardingWithNoResolver(t *testing.T) {
+ // Create a network protocol without a resolver.
+ proto := &fwdTestNetworkProtocol{}
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[0] = 3
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ select {
+ case <-ep2.C:
+ t.Fatal("Packet should not be forwarded")
+ case <-time.After(time.Second):
+ }
+}
+
+func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ // Only packets to address 3 will be resolved to the
+ // link address "c".
+ if addr == "\x03" {
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ }
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // Inject an inbound packet to address 4 on NIC 1. This packet should
+ // not be forwarded.
+ buf := buffer.NewView(30)
+ buf[0] = 4
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf = buffer.NewView(30)
+ buf[0] = 3
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ b := p.Pkt.Header.View()
+ if b[0] != 3 {
+ t.Fatalf("got b[0] = %d, want = 3", b[0])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+}
+
+func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // Inject two inbound packets to address 3 on NIC 1.
+ for i := 0; i < 2; i++ {
+ buf := buffer.NewView(30)
+ buf[0] = 3
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ }
+
+ for i := 0; i < 2; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ b := p.Pkt.Header.View()
+ if b[0] != 3 {
+ t.Fatalf("got b[0] = %d, want = 3", b[0])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+}
+
+func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
+ // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
+ buf := buffer.NewView(30)
+ buf[0] = 3
+ // Set the packet sequence number.
+ binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ }
+
+ for i := 0; i < maxPendingPacketsPerResolution; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ b := p.Pkt.Header.View()
+ if b[0] != 3 {
+ t.Fatalf("got b[0] = %d, want = 3", b[0])
+ }
+ // The first 5 packets should not be forwarded so the the
+ // sequemnce number should start with 5.
+ want := uint16(i + 5)
+ if n := binary.BigEndian.Uint16(b[fwdTestNetHeaderLen:]); n != want {
+ t.Fatalf("got the packet #%d, want = #%d", n, want)
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+}
+
+func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ for i := 0; i < maxPendingResolutions+5; i++ {
+ // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1.
+ // Each packet has a different destination address (3 to
+ // maxPendingResolutions + 7).
+ buf := buffer.NewView(30)
+ buf[0] = byte(3 + i)
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ }
+
+ for i := 0; i < maxPendingResolutions; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ // The first 5 packets (address 3 to 7) should not be forwarded
+ // because their address resolutions are interrupted.
+ b := p.Pkt.Header.View()
+ if b[0] < 8 {
+ t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+}
diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/stack/iptables.go
index dbaccbb36..37907ae24 100644
--- a/pkg/tcpip/iptables/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -12,14 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package iptables supports packet filtering and manipulation via the iptables
-// tool.
-package iptables
+package stack
import (
"fmt"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -135,6 +132,27 @@ func EmptyFilterTable() Table {
}
}
+// EmptyNatTable returns a Table with no rules and the filter table chains
+// mapped to HookUnset.
+func EmptyNatTable() Table {
+ return Table{
+ Rules: []Rule{},
+ BuiltinChains: map[Hook]int{
+ Prerouting: HookUnset,
+ Input: HookUnset,
+ Output: HookUnset,
+ Postrouting: HookUnset,
+ },
+ Underflows: map[Hook]int{
+ Prerouting: HookUnset,
+ Input: HookUnset,
+ Output: HookUnset,
+ Postrouting: HookUnset,
+ },
+ UserChains: map[string]int{},
+ }
+}
+
// A chainVerdict is what a table decides should be done with a packet.
type chainVerdict int
@@ -155,7 +173,7 @@ const (
// dropped.
//
// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool {
+func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool {
// Go through each table containing the hook.
for _, tablename := range it.Priorities[hook] {
table := it.Tables[tablename]
@@ -192,7 +210,7 @@ func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool {
}
// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) checkChain(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) chainVerdict {
+func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
@@ -237,12 +255,17 @@ func (it *IPTables) checkChain(hook Hook, pkt tcpip.PacketBuffer, table Table, r
}
// Precondition: pk.NetworkHeader is set.
-func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) {
+func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
- // First check whether the packet matches the IP header filter.
- // TODO(gvisor.dev/issue/170): Support other fields of the filter.
- if rule.Filter.Protocol != 0 && rule.Filter.Protocol != header.IPv4(pkt.NetworkHeader).TransportProtocol() {
+ // If pkt.NetworkHeader hasn't been set yet, it will be contained in
+ // pkt.Data.First().
+ if pkt.NetworkHeader == nil {
+ pkt.NetworkHeader = pkt.Data.First()
+ }
+
+ // Check whether the packet matches the IP header filter.
+ if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader)) {
// Continue on to the next rule.
return RuleJump, ruleIdx + 1
}
@@ -263,3 +286,26 @@ func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ru
// All the matchers matched, so run the target.
return rule.Target.Action(pkt)
}
+
+func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool {
+ // TODO(gvisor.dev/issue/170): Support other fields of the filter.
+ // Check the transport protocol.
+ if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() {
+ return false
+ }
+
+ // Check the destination IP.
+ dest := hdr.DestinationAddress()
+ matches := true
+ for i := range filter.Dst {
+ if dest[i]&filter.DstMask[i] != filter.Dst[i] {
+ matches = false
+ break
+ }
+ }
+ if matches == filter.DstInvert {
+ return false
+ }
+
+ return true
+}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
new file mode 100644
index 000000000..7b4543caf
--- /dev/null
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -0,0 +1,144 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// AcceptTarget accepts packets.
+type AcceptTarget struct{}
+
+// Action implements Target.Action.
+func (AcceptTarget) Action(packet PacketBuffer) (RuleVerdict, int) {
+ return RuleAccept, 0
+}
+
+// DropTarget drops packets.
+type DropTarget struct{}
+
+// Action implements Target.Action.
+func (DropTarget) Action(packet PacketBuffer) (RuleVerdict, int) {
+ return RuleDrop, 0
+}
+
+// ErrorTarget logs an error and drops the packet. It represents a target that
+// should be unreachable.
+type ErrorTarget struct{}
+
+// Action implements Target.Action.
+func (ErrorTarget) Action(packet PacketBuffer) (RuleVerdict, int) {
+ log.Debugf("ErrorTarget triggered.")
+ return RuleDrop, 0
+}
+
+// UserChainTarget marks a rule as the beginning of a user chain.
+type UserChainTarget struct {
+ Name string
+}
+
+// Action implements Target.Action.
+func (UserChainTarget) Action(PacketBuffer) (RuleVerdict, int) {
+ panic("UserChainTarget should never be called.")
+}
+
+// ReturnTarget returns from the current chain. If the chain is a built-in, the
+// hook's underflow should be called.
+type ReturnTarget struct{}
+
+// Action implements Target.Action.
+func (ReturnTarget) Action(PacketBuffer) (RuleVerdict, int) {
+ return RuleReturn, 0
+}
+
+// RedirectTarget redirects the packet by modifying the destination port/IP.
+// Min and Max values for IP and Ports in the struct indicate the range of
+// values which can be used to redirect.
+type RedirectTarget struct {
+ // TODO(gvisor.dev/issue/170): Other flags need to be added after
+ // we support them.
+ // RangeProtoSpecified flag indicates single port is specified to
+ // redirect.
+ RangeProtoSpecified bool
+
+ // Min address used to redirect.
+ MinIP tcpip.Address
+
+ // Max address used to redirect.
+ MaxIP tcpip.Address
+
+ // Min port used to redirect.
+ MinPort uint16
+
+ // Max port used to redirect.
+ MaxPort uint16
+}
+
+// Action implements Target.Action.
+// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
+// implementation only works for PREROUTING and calls pkt.Clone(), neither
+// of which should be the case.
+func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) {
+ newPkt := pkt.Clone()
+
+ // Set network header.
+ headerView := newPkt.Data.First()
+ netHeader := header.IPv4(headerView)
+ newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize]
+
+ hlen := int(netHeader.HeaderLength())
+ tlen := int(netHeader.TotalLength())
+ newPkt.Data.TrimFront(hlen)
+ newPkt.Data.CapLength(tlen - hlen)
+
+ // TODO(gvisor.dev/issue/170): Change destination address to
+ // loopback or interface address on which the packet was
+ // received.
+
+ // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if
+ // we need to change dest address (for OUTPUT chain) or ports.
+ switch protocol := netHeader.TransportProtocol(); protocol {
+ case header.UDPProtocolNumber:
+ var udpHeader header.UDP
+ if newPkt.TransportHeader != nil {
+ udpHeader = header.UDP(newPkt.TransportHeader)
+ } else {
+ if len(pkt.Data.First()) < header.UDPMinimumSize {
+ return RuleDrop, 0
+ }
+ udpHeader = header.UDP(newPkt.Data.First())
+ }
+ udpHeader.SetDestinationPort(rt.MinPort)
+ case header.TCPProtocolNumber:
+ var tcpHeader header.TCP
+ if newPkt.TransportHeader != nil {
+ tcpHeader = header.TCP(newPkt.TransportHeader)
+ } else {
+ if len(pkt.Data.First()) < header.TCPMinimumSize {
+ return RuleDrop, 0
+ }
+ tcpHeader = header.TCP(newPkt.TransportHeader)
+ }
+ // TODO(gvisor.dev/issue/170): Need to recompute checksum
+ // and implement nat connection tracking to support TCP.
+ tcpHeader.SetDestinationPort(rt.MinPort)
+ default:
+ return RuleDrop, 0
+ }
+
+ return RuleAccept, 0
+}
diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/stack/iptables_types.go
index 7d032fd23..2ffb55f2a 100644
--- a/pkg/tcpip/iptables/types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package iptables
+package stack
import (
"gvisor.dev/gvisor/pkg/tcpip"
@@ -144,6 +144,18 @@ type Rule struct {
type IPHeaderFilter struct {
// Protocol matches the transport protocol.
Protocol tcpip.TransportProtocolNumber
+
+ // Dst matches the destination IP address.
+ Dst tcpip.Address
+
+ // DstMask masks bits of the destination IP address when comparing with
+ // Dst.
+ DstMask tcpip.Address
+
+ // DstInvert inverts the meaning of the destination IP check, i.e. when
+ // true the filter will match packets that fail the destination
+ // comparison.
+ DstInvert bool
}
// A Matcher is the interface for matching packets.
@@ -156,7 +168,7 @@ type Matcher interface {
// used for suspicious packets.
//
// Precondition: packet.NetworkHeader is set.
- Match(hook Hook, packet tcpip.PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
+ Match(hook Hook, packet PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
}
// A Target is the interface for taking an action for a packet.
@@ -164,5 +176,5 @@ type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
- Action(packet tcpip.PacketBuffer) (RuleVerdict, int)
+ Action(packet PacketBuffer) (RuleVerdict, int)
}
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index f651871ce..630fdefc5 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -361,16 +361,16 @@ type ndpState struct {
// The default routers discovered through Router Advertisements.
defaultRouters map[tcpip.Address]defaultRouterState
+ // The timer used to send the next router solicitation message.
+ rtrSolicitTimer *time.Timer
+
// The on-link prefixes discovered through Router Advertisements' Prefix
// Information option.
onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState
- // The timer used to send the next router solicitation message.
- // If routers are being solicited, rtrSolicitTimer MUST NOT be nil.
- rtrSolicitTimer *time.Timer
-
- // The addresses generated by SLAAC.
- autoGenAddresses map[tcpip.Address]autoGenAddressState
+ // The SLAAC prefixes discovered through Router Advertisements' Prefix
+ // Information option.
+ slaacPrefixes map[tcpip.Subnet]slaacPrefixState
// The last learned DHCPv6 configuration from an NDP RA.
dhcpv6Configuration DHCPv6ConfigurationFromNDPRA
@@ -402,18 +402,16 @@ type onLinkPrefixState struct {
invalidationTimer tcpip.CancellableTimer
}
-// autoGenAddressState holds data associated with an address generated via
-// SLAAC.
-type autoGenAddressState struct {
- // A reference to the referencedNetworkEndpoint that this autoGenAddressState
- // is holding state for.
- ref *referencedNetworkEndpoint
-
+// slaacPrefixState holds state associated with a SLAAC prefix.
+type slaacPrefixState struct {
deprecationTimer tcpip.CancellableTimer
invalidationTimer tcpip.CancellableTimer
// Nonzero only when the address is not valid forever.
validUntil time.Time
+
+ // The prefix's permanent address endpoint.
+ ref *referencedNetworkEndpoint
}
// startDuplicateAddressDetection performs Duplicate Address Detection.
@@ -566,7 +564,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error {
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
TOS: DefaultTOS,
- }, tcpip.PacketBuffer{Header: hdr},
+ }, PacketBuffer{Header: hdr},
); err != nil {
sent.Dropped.Increment()
return err
@@ -899,23 +897,15 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
prefix := pi.Subnet()
- // Check if we already have an auto-generated address for prefix.
- for addr, addrState := range ndp.autoGenAddresses {
- refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: addrState.ref.ep.PrefixLen()}
- if refAddrWithPrefix.Subnet() != prefix {
- continue
- }
-
- // At this point, we know we are refreshing a SLAAC generated IPv6 address
- // with the prefix prefix. Do the work as outlined by RFC 4862 section
- // 5.5.3.e.
- ndp.refreshAutoGenAddressLifetimes(addr, pl, vl)
+ // Check if we already maintain SLAAC state for prefix.
+ if _, ok := ndp.slaacPrefixes[prefix]; ok {
+ // As per RFC 4862 section 5.5.3.e, refresh prefix's SLAAC lifetimes.
+ ndp.refreshSLAACPrefixLifetimes(prefix, pl, vl)
return
}
- // We do not already have an address with the prefix prefix. Do the
- // work as outlined by RFC 4862 section 5.5.3.d if n is configured
- // to auto-generate global addresses by SLAAC.
+ // prefix is a new SLAAC prefix. Do the work as outlined by RFC 4862 section
+ // 5.5.3.d if ndp is configured to auto-generate new addresses via SLAAC.
if !ndp.configs.AutoGenGlobalAddresses {
return
}
@@ -927,6 +917,8 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
// for prefix.
//
// pl is the new preferred lifetime. vl is the new valid lifetime.
+//
+// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
// If we do not already have an address for this prefix and the valid
// lifetime is 0, no need to do anything further, as per RFC 4862
@@ -942,9 +934,59 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
return
}
+ // If the preferred lifetime is zero, then the prefix should be considered
+ // deprecated.
+ deprecated := pl == 0
+ ref := ndp.addSLAACAddr(prefix, deprecated)
+ if ref == nil {
+ // We were unable to generate a permanent address for prefix so do nothing
+ // further as there is no reason to maintain state for a SLAAC prefix we
+ // cannot generate a permanent address for.
+ return
+ }
+
+ state := slaacPrefixState{
+ deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ log.Fatalf("ndp: must have a slaacPrefixes entry for the SLAAC prefix %s", prefix)
+ }
+
+ ndp.deprecateSLAACAddress(prefixState.ref)
+ }),
+ invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
+ ndp.invalidateSLAACPrefix(prefix, true)
+ }),
+ ref: ref,
+ }
+
+ // Setup the initial timers to deprecate and invalidate prefix.
+
+ if !deprecated && pl < header.NDPInfiniteLifetime {
+ state.deprecationTimer.Reset(pl)
+ }
+
+ if vl < header.NDPInfiniteLifetime {
+ state.invalidationTimer.Reset(vl)
+ state.validUntil = time.Now().Add(vl)
+ }
+
+ ndp.slaacPrefixes[prefix] = state
+}
+
+// addSLAACAddr adds a SLAAC address for prefix.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) addSLAACAddr(prefix tcpip.Subnet, deprecated bool) *referencedNetworkEndpoint {
addrBytes := []byte(prefix.ID())
if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
- addrBytes = header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], prefix, oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), 0 /* dadCounter */, oIID.SecretKey)
+ addrBytes = header.AppendOpaqueInterfaceIdentifier(
+ addrBytes[:header.IIDOffsetInIPv6Address],
+ prefix,
+ oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name),
+ 0, /* dadCounter */
+ oIID.SecretKey,
+ )
} else {
// Only attempt to generate an interface-specific IID if we have a valid
// link address.
@@ -953,137 +995,103 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
// LinkEndpoint.LinkAddress) before reaching this point.
linkAddr := ndp.nic.linkEP.LinkAddress()
if !header.IsValidUnicastEthernetAddress(linkAddr) {
- return
+ return nil
}
// Generate an address within prefix from the modified EUI-64 of ndp's NIC's
// Ethernet MAC address.
header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
}
- addr := tcpip.Address(addrBytes)
- addrWithPrefix := tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: validPrefixLenForAutoGen,
+
+ generatedAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: validPrefixLenForAutoGen,
+ },
}
// If the nic already has this address, do nothing further.
- if ndp.nic.hasPermanentAddrLocked(addr) {
- return
+ if ndp.nic.hasPermanentAddrLocked(generatedAddr.AddressWithPrefix.Address) {
+ return nil
}
// Inform the integrator that we have a new SLAAC address.
ndpDisp := ndp.nic.stack.ndpDisp
if ndpDisp == nil {
- return
+ return nil
}
- if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) {
+
+ if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), generatedAddr.AddressWithPrefix) {
// Informed by the integrator not to add the address.
- return
+ return nil
}
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addrWithPrefix,
- }
- // If the preferred lifetime is zero, then the address should be considered
- // deprecated.
- deprecated := pl == 0
- ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated)
+ ref, err := ndp.nic.addAddressLocked(generatedAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated)
if err != nil {
- log.Fatalf("ndp: error when adding address %s: %s", protocolAddr, err)
- }
-
- state := autoGenAddressState{
- ref: ref,
- deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
- addrState, ok := ndp.autoGenAddresses[addr]
- if !ok {
- log.Fatalf("ndp: must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr)
- }
- addrState.ref.deprecated = true
- ndp.notifyAutoGenAddressDeprecated(addr)
- }),
- invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
- ndp.invalidateAutoGenAddress(addr)
- }),
- }
-
- // Setup the initial timers to deprecate and invalidate this newly generated
- // address.
-
- if !deprecated && pl < header.NDPInfiniteLifetime {
- state.deprecationTimer.Reset(pl)
+ log.Fatalf("ndp: error when adding address %+v: %s", generatedAddr, err)
}
- if vl < header.NDPInfiniteLifetime {
- state.invalidationTimer.Reset(vl)
- state.validUntil = time.Now().Add(vl)
- }
-
- ndp.autoGenAddresses[addr] = state
+ return ref
}
-// refreshAutoGenAddressLifetimes refreshes the lifetime of a SLAAC generated
-// address addr.
+// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix.
//
// pl is the new preferred lifetime. vl is the new valid lifetime.
-func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl time.Duration) {
- addrState, ok := ndp.autoGenAddresses[addr]
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl time.Duration) {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
- log.Fatalf("ndp: SLAAC state not found to refresh lifetimes for %s", addr)
+ log.Fatalf("ndp: SLAAC prefix state not found to refresh lifetimes for %s", prefix)
}
- defer func() { ndp.autoGenAddresses[addr] = addrState }()
+ defer func() { ndp.slaacPrefixes[prefix] = prefixState }()
- // If the preferred lifetime is zero, then the address should be considered
- // deprecated.
+ // If the preferred lifetime is zero, then the prefix should be deprecated.
deprecated := pl == 0
- wasDeprecated := addrState.ref.deprecated
- addrState.ref.deprecated = deprecated
-
- // Only send the deprecation event if the deprecated status for addr just
- // changed from non-deprecated to deprecated.
- if !wasDeprecated && deprecated {
- ndp.notifyAutoGenAddressDeprecated(addr)
+ if deprecated {
+ ndp.deprecateSLAACAddress(prefixState.ref)
+ } else {
+ prefixState.ref.deprecated = false
}
- // If addr was preferred for some finite lifetime before, stop the deprecation
- // timer so it can be reset.
- addrState.deprecationTimer.StopLocked()
+ // If prefix was preferred for some finite lifetime before, stop the
+ // deprecation timer so it can be reset.
+ prefixState.deprecationTimer.StopLocked()
- // Reset the deprecation timer if addr has a finite preferred lifetime.
+ // Reset the deprecation timer if prefix has a finite preferred lifetime.
if !deprecated && pl < header.NDPInfiniteLifetime {
- addrState.deprecationTimer.Reset(pl)
+ prefixState.deprecationTimer.Reset(pl)
}
- // As per RFC 4862 section 5.5.3.e, the valid lifetime of the address
- //
+ // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix:
//
// 1) If the received Valid Lifetime is greater than 2 hours or greater than
- // RemainingLifetime, set the valid lifetime of the address to the
+ // RemainingLifetime, set the valid lifetime of the prefix to the
// advertised Valid Lifetime.
//
// 2) If RemainingLifetime is less than or equal to 2 hours, ignore the
// advertised Valid Lifetime.
//
- // 3) Otherwise, reset the valid lifetime of the address to 2 hours.
+ // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours.
// Handle the infinite valid lifetime separately as we do not keep a timer in
// this case.
if vl >= header.NDPInfiniteLifetime {
- addrState.invalidationTimer.StopLocked()
- addrState.validUntil = time.Time{}
+ prefixState.invalidationTimer.StopLocked()
+ prefixState.validUntil = time.Time{}
return
}
var effectiveVl time.Duration
var rl time.Duration
- // If the address was originally set to be valid forever, assume the remaining
+ // If the prefix was originally set to be valid forever, assume the remaining
// time to be the maximum possible value.
- if addrState.validUntil == (time.Time{}) {
+ if prefixState.validUntil == (time.Time{}) {
rl = header.NDPInfiniteLifetime
} else {
- rl = time.Until(addrState.validUntil)
+ rl = time.Until(prefixState.validUntil)
}
if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
@@ -1094,58 +1102,66 @@ func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl t
effectiveVl = MinPrefixInformationValidLifetimeForUpdate
}
- addrState.invalidationTimer.StopLocked()
- addrState.invalidationTimer.Reset(effectiveVl)
- addrState.validUntil = time.Now().Add(effectiveVl)
-}
-
-// notifyAutoGenAddressDeprecated notifies the stack's NDP dispatcher that addr
-// has been deprecated.
-func (ndp *ndpState) notifyAutoGenAddressDeprecated(addr tcpip.Address) {
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: validPrefixLenForAutoGen,
- })
- }
+ prefixState.invalidationTimer.StopLocked()
+ prefixState.invalidationTimer.Reset(effectiveVl)
+ prefixState.validUntil = time.Now().Add(effectiveVl)
}
-// invalidateAutoGenAddress invalidates an auto-generated address.
+// deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP
+// dispatcher that ref has been deprecated.
+//
+// deprecateSLAACAddress does nothing if ref is already deprecated.
//
// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) {
- if !ndp.cleanupAutoGenAddrResourcesAndNotify(addr) {
+func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) {
+ if ref.deprecated {
return
}
- ndp.nic.removePermanentAddressLocked(addr)
+ ref.deprecated = true
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{
+ Address: ref.ep.ID().LocalAddress,
+ PrefixLen: ref.ep.PrefixLen(),
+ })
+ }
}
-// cleanupAutoGenAddrResourcesAndNotify cleans up an invalidated auto-generated
-// address's resources from ndp. If the stack has an NDP dispatcher, it will
-// be notified that addr has been invalidated.
-//
-// Returns true if ndp had resources for addr to cleanup.
+// invalidateSLAACPrefix invalidates a SLAAC prefix.
//
// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool {
- state, ok := ndp.autoGenAddresses[addr]
+func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, removeAddr bool) {
+ state, ok := ndp.slaacPrefixes[prefix]
if !ok {
- return false
+ return
}
state.deprecationTimer.StopLocked()
state.invalidationTimer.StopLocked()
- delete(ndp.autoGenAddresses, addr)
+ delete(ndp.slaacPrefixes, prefix)
+
+ addr := state.ref.ep.ID().LocalAddress
+
+ if removeAddr {
+ if err := ndp.nic.removePermanentAddressLocked(addr); err != nil {
+ log.Fatalf("ndp: removePermanentAddressLocked(%s): %s", addr, err)
+ }
+ }
if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{
Address: addr,
- PrefixLen: validPrefixLenForAutoGen,
+ PrefixLen: state.ref.ep.PrefixLen(),
})
}
+}
- return true
+// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC
+// address's resources from ndp.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix) {
+ ndp.invalidateSLAACPrefix(addr.Subnet(), false)
}
// cleanupState cleans up ndp's state.
@@ -1163,21 +1179,21 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo
// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupState(hostOnly bool) {
linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet()
- linkLocalAddrs := 0
- for addr := range ndp.autoGenAddresses {
+ linkLocalPrefixes := 0
+ for prefix := range ndp.slaacPrefixes {
// RFC 4862 section 5 states that routers are also expected to generate a
// link-local address so we do not invalidate them if we are cleaning up
// host-only state.
- if hostOnly && linkLocalSubnet.Contains(addr) {
- linkLocalAddrs++
+ if hostOnly && prefix == linkLocalSubnet {
+ linkLocalPrefixes++
continue
}
- ndp.invalidateAutoGenAddress(addr)
+ ndp.invalidateSLAACPrefix(prefix, true)
}
- if got := len(ndp.autoGenAddresses); got != linkLocalAddrs {
- log.Fatalf("ndp: still have non-linklocal auto-generated addresses after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalAddrs)
+ if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes {
+ log.Fatalf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes)
}
for prefix := range ndp.onLinkPrefixes {
@@ -1220,9 +1236,15 @@ func (ndp *ndpState) startSolicitingRouters() {
}
ndp.rtrSolicitTimer = time.AfterFunc(delay, func() {
- // Send an RS message with the unspecified source address.
- ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing)
- r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ // As per RFC 4861 section 4.1, the source of the RS is an address assigned
+ // to the sending interface, or the unspecified address if no address is
+ // assigned to the sending interface.
+ ref := ndp.nic.primaryIPv6Endpoint(header.IPv6AllRoutersMulticastAddress)
+ if ref == nil {
+ ref = ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing)
+ }
+ localAddr := ref.ep.ID().LocalAddress
+ r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
defer r.Release()
// Route should resolve immediately since
@@ -1234,10 +1256,25 @@ func (ndp *ndpState) startSolicitingRouters() {
log.Fatalf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID())
}
- payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize
+ // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
+ // link-layer address option if the source address of the NDP RS is
+ // specified. This option MUST NOT be included if the source address is
+ // unspecified.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ var optsSerializer header.NDPOptionsSerializer
+ if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) {
+ optsSerializer = header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress),
+ }
+ }
+ payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length())
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize)
pkt := header.ICMPv6(hdr.Prepend(payloadSize))
pkt.SetType(header.ICMPv6RouterSolicit)
+ rs := header.NDPRouterSolicit(pkt.NDPPayload())
+ rs.Options().Serialize(optsSerializer)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
sent := r.Stats().ICMP.V6PacketsSent
@@ -1246,7 +1283,7 @@ func (ndp *ndpState) startSolicitingRouters() {
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
TOS: DefaultTOS,
- }, tcpip.PacketBuffer{Header: hdr},
+ }, PacketBuffer{Header: hdr},
); err != nil {
sent.Dropped.Increment()
log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err)
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 6e9306d09..06edd05b6 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -602,7 +602,7 @@ func TestDADFail(t *testing.T) {
// Receive a packet to simulate multiple nodes owning or
// attempting to own the same address.
hdr := test.makeBuf(addr1)
- e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -639,8 +639,9 @@ func TestDADStop(t *testing.T) {
const nicID = 1
tests := []struct {
- name string
- stopFn func(t *testing.T, s *stack.Stack)
+ name string
+ stopFn func(t *testing.T, s *stack.Stack)
+ skipFinalAddrCheck bool
}{
// Tests to make sure that DAD stops when an address is removed.
{
@@ -661,6 +662,19 @@ func TestDADStop(t *testing.T) {
}
},
},
+
+ // Tests to make sure that DAD stops when the NIC is removed.
+ {
+ name: "Remove NIC",
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ if err := s.RemoveNIC(nicID); err != nil {
+ t.Fatalf("RemoveNIC(%d): %s", nicID, err)
+ }
+ },
+ // The NIC is removed so we can't check its addresses after calling
+ // stopFn.
+ skipFinalAddrCheck: true,
+ },
}
for _, test := range tests {
@@ -710,12 +724,15 @@ func TestDADStop(t *testing.T) {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+
+ if !test.skipFinalAddrCheck {
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
}
// Should not have sent more than 1 NS message.
@@ -901,7 +918,7 @@ func TestSetNDPConfigurations(t *testing.T) {
// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options
// and DHCPv6 configurations specified.
-func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer {
+func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) stack.PacketBuffer {
icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length())
hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
@@ -936,14 +953,14 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
DstAddr: header.IPv6AllNodesMulticastAddress,
})
- return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()}
+ return stack.PacketBuffer{Data: hdr.View().ToVectorisedView()}
}
// raBufWithOpts returns a valid NDP Router Advertisement with options.
//
// Note, raBufWithOpts does not populate any of the RA fields other than the
// Router Lifetime.
-func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer {
+func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) stack.PacketBuffer {
return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer)
}
@@ -952,7 +969,7 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ
//
// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the
// DHCPv6 related ones.
-func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) tcpip.PacketBuffer {
+func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) stack.PacketBuffer {
return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{})
}
@@ -960,7 +977,7 @@ func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bo
//
// Note, raBuf does not populate any of the RA fields other than the
// Router Lifetime.
-func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer {
+func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer {
return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{})
}
@@ -969,7 +986,7 @@ func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer {
//
// Note, raBufWithPI does not populate any of the RA fields other than the
// Router Lifetime.
-func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) tcpip.PacketBuffer {
+func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) stack.PacketBuffer {
flags := uint8(0)
if onLink {
// The OnLink flag is the 7th bit in the flags byte.
@@ -2983,11 +3000,12 @@ func TestCleanupNDPState(t *testing.T) {
cleanupFn func(t *testing.T, s *stack.Stack)
keepAutoGenLinkLocal bool
maxAutoGenAddrEvents int
+ skipFinalAddrCheck bool
}{
// A NIC should still keep its auto-generated link-local address when
// becoming a router.
{
- name: "Forwarding Enable",
+ name: "Enable forwarding",
cleanupFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
s.SetForwarding(true)
@@ -2998,7 +3016,7 @@ func TestCleanupNDPState(t *testing.T) {
// A NIC should cleanup all NDP state when it is disabled.
{
- name: "NIC Disable",
+ name: "Disable NIC",
cleanupFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
@@ -3012,6 +3030,26 @@ func TestCleanupNDPState(t *testing.T) {
keepAutoGenLinkLocal: false,
maxAutoGenAddrEvents: 6,
},
+
+ // A NIC should cleanup all NDP state when it is removed.
+ {
+ name: "Remove NIC",
+ cleanupFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+
+ if err := s.RemoveNIC(nicID1); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID1, err)
+ }
+ if err := s.RemoveNIC(nicID2); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID2, err)
+ }
+ },
+ keepAutoGenLinkLocal: false,
+ maxAutoGenAddrEvents: 6,
+ // The NICs are removed so we can't check their addresses after calling
+ // stopFn.
+ skipFinalAddrCheck: true,
+ },
}
for _, test := range tests {
@@ -3230,35 +3268,37 @@ func TestCleanupNDPState(t *testing.T) {
t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff)
}
- // Make sure the auto-generated addresses got removed.
- nicinfo = s.NICInfo()
- nic1Addrs = nicinfo[nicID1].ProtocolAddresses
- nic2Addrs = nicinfo[nicID2].ProtocolAddresses
- if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal {
- if test.keepAutoGenLinkLocal {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
- } else {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ if !test.skipFinalAddrCheck {
+ // Make sure the auto-generated addresses got removed.
+ nicinfo = s.NICInfo()
+ nic1Addrs = nicinfo[nicID1].ProtocolAddresses
+ nic2Addrs = nicinfo[nicID2].ProtocolAddresses
+ if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ }
}
- }
- if containsV6Addr(nic1Addrs, e1Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
- }
- if containsV6Addr(nic1Addrs, e1Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
- }
- if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal {
- if test.keepAutoGenLinkLocal {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
- } else {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ if containsV6Addr(nic1Addrs, e1Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic1Addrs, e1Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ }
+ }
+ if containsV6Addr(nic2Addrs, e2Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
+ }
+ if containsV6Addr(nic2Addrs, e2Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
}
- }
- if containsV6Addr(nic2Addrs, e2Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
- }
- if containsV6Addr(nic2Addrs, e2Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
}
// Should not get any more events (invalidation timers should have been
@@ -3384,6 +3424,10 @@ func TestRouterSolicitation(t *testing.T) {
tests := []struct {
name string
linkHeaderLen uint16
+ linkAddr tcpip.LinkAddress
+ nicAddr tcpip.Address
+ expectedSrcAddr tcpip.Address
+ expectedNDPOpts []header.NDPOption
maxRtrSolicit uint8
rtrSolicitInt time.Duration
effectiveRtrSolicitInt time.Duration
@@ -3392,6 +3436,7 @@ func TestRouterSolicitation(t *testing.T) {
}{
{
name: "Single RS with delay",
+ expectedSrcAddr: header.IPv6Any,
maxRtrSolicit: 1,
rtrSolicitInt: time.Second,
effectiveRtrSolicitInt: time.Second,
@@ -3401,6 +3446,8 @@ func TestRouterSolicitation(t *testing.T) {
{
name: "Two RS with delay",
linkHeaderLen: 1,
+ nicAddr: llAddr1,
+ expectedSrcAddr: llAddr1,
maxRtrSolicit: 2,
rtrSolicitInt: time.Second,
effectiveRtrSolicitInt: time.Second,
@@ -3408,8 +3455,14 @@ func TestRouterSolicitation(t *testing.T) {
effectiveMaxRtrSolicitDelay: 500 * time.Millisecond,
},
{
- name: "Single RS without delay",
- linkHeaderLen: 2,
+ name: "Single RS without delay",
+ linkHeaderLen: 2,
+ linkAddr: linkAddr1,
+ nicAddr: llAddr1,
+ expectedSrcAddr: llAddr1,
+ expectedNDPOpts: []header.NDPOption{
+ header.NDPSourceLinkLayerAddressOption(linkAddr1),
+ },
maxRtrSolicit: 1,
rtrSolicitInt: time.Second,
effectiveRtrSolicitInt: time.Second,
@@ -3419,6 +3472,8 @@ func TestRouterSolicitation(t *testing.T) {
{
name: "Two RS without delay and invalid zero interval",
linkHeaderLen: 3,
+ linkAddr: linkAddr1,
+ expectedSrcAddr: header.IPv6Any,
maxRtrSolicit: 2,
rtrSolicitInt: 0,
effectiveRtrSolicitInt: 4 * time.Second,
@@ -3427,6 +3482,8 @@ func TestRouterSolicitation(t *testing.T) {
},
{
name: "Three RS without delay",
+ linkAddr: linkAddr1,
+ expectedSrcAddr: header.IPv6Any,
maxRtrSolicit: 3,
rtrSolicitInt: 500 * time.Millisecond,
effectiveRtrSolicitInt: 500 * time.Millisecond,
@@ -3435,6 +3492,8 @@ func TestRouterSolicitation(t *testing.T) {
},
{
name: "Two RS with invalid negative delay",
+ linkAddr: linkAddr1,
+ expectedSrcAddr: header.IPv6Any,
maxRtrSolicit: 2,
rtrSolicitInt: time.Second,
effectiveRtrSolicitInt: time.Second,
@@ -3457,7 +3516,7 @@ func TestRouterSolicitation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
e := channelLinkWithHeaderLength{
- Endpoint: channel.New(int(test.maxRtrSolicit), 1280, linkAddr1),
+ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
headerLength: test.linkHeaderLen,
}
e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
@@ -3481,10 +3540,10 @@ func TestRouterSolicitation(t *testing.T) {
checker.IPv6(t,
p.Pkt.Header.View(),
- checker.SrcAddr(header.IPv6Any),
+ checker.SrcAddr(test.expectedSrcAddr),
checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
checker.TTL(header.NDPHopLimit),
- checker.NDPRS(),
+ checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
)
if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want {
@@ -3510,13 +3569,19 @@ func TestRouterSolicitation(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- // Make sure each RS got sent at the right
- // times.
+ if addr := test.nicAddr; addr != "" {
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
+ }
+ }
+
+ // Make sure each RS is sent at the right time.
remaining := test.maxRtrSolicit
if remaining > 0 {
waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout)
remaining--
}
+
for ; remaining > 0; remaining-- {
waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout)
waitForPkt(defaultAsyncEventTimeout)
@@ -3550,17 +3615,19 @@ func TestStopStartSolicitingRouters(t *testing.T) {
tests := []struct {
name string
startFn func(t *testing.T, s *stack.Stack)
- stopFn func(t *testing.T, s *stack.Stack)
+ // first is used to tell stopFn that it is being called for the first time
+ // after router solicitations were last enabled.
+ stopFn func(t *testing.T, s *stack.Stack, first bool)
}{
// Tests that when forwarding is enabled or disabled, router solicitations
// are stopped or started, respectively.
{
- name: "Forwarding enabled and disabled",
+ name: "Enable and disable forwarding",
startFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
s.SetForwarding(false)
},
- stopFn: func(t *testing.T, s *stack.Stack) {
+ stopFn: func(t *testing.T, s *stack.Stack, _ bool) {
t.Helper()
s.SetForwarding(true)
},
@@ -3569,7 +3636,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Tests that when a NIC is enabled or disabled, router solicitations
// are started or stopped, respectively.
{
- name: "NIC disabled and enabled",
+ name: "Enable and disable NIC",
startFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
@@ -3577,7 +3644,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
}
},
- stopFn: func(t *testing.T, s *stack.Stack) {
+ stopFn: func(t *testing.T, s *stack.Stack, _ bool) {
t.Helper()
if err := s.DisableNIC(nicID); err != nil {
@@ -3585,6 +3652,25 @@ func TestStopStartSolicitingRouters(t *testing.T) {
}
},
},
+
+ // Tests that when a NIC is removed, router solicitations are stopped. We
+ // cannot start router solications on a removed NIC.
+ {
+ name: "Remove NIC",
+ stopFn: func(t *testing.T, s *stack.Stack, first bool) {
+ t.Helper()
+
+ // Only try to remove the NIC the first time stopFn is called since it's
+ // impossible to remove an already removed NIC.
+ if !first {
+ return
+ }
+
+ if err := s.RemoveNIC(nicID); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID, err)
+ }
+ },
+ },
}
for _, test := range tests {
@@ -3623,7 +3709,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
}
// Stop soliciting routers.
- test.stopFn(t, s)
+ test.stopFn(t, s, true /* first */)
ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
@@ -3637,13 +3723,18 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Stopping router solicitations after it has already been stopped should
// do nothing.
- test.stopFn(t, s)
+ test.stopFn(t, s, false /* first */)
ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
}
+ // If test.startFn is nil, there is no way to restart router solications.
+ if test.startFn == nil {
+ return
+ }
+
// Start soliciting routers.
test.startFn(t, s)
waitForPkt(delay + defaultAsyncEventTimeout)
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 46d3a6646..b6fa647ea 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -15,6 +15,7 @@
package stack
import (
+ "fmt"
"log"
"reflect"
"sort"
@@ -54,7 +55,7 @@ type NIC struct {
primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
addressRanges []tcpip.Subnet
- mcastJoins map[NetworkEndpointID]int32
+ mcastJoins map[NetworkEndpointID]uint32
// packetEPs is protected by mu, but the contained PacketEndpoint
// values are not.
packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
@@ -121,15 +122,15 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
}
nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint)
nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint)
- nic.mu.mcastJoins = make(map[NetworkEndpointID]int32)
+ nic.mu.mcastJoins = make(map[NetworkEndpointID]uint32)
nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint)
nic.mu.ndp = ndpState{
- nic: nic,
- configs: stack.ndpConfigs,
- dad: make(map[tcpip.Address]dadState),
- defaultRouters: make(map[tcpip.Address]defaultRouterState),
- onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
- autoGenAddresses: make(map[tcpip.Address]autoGenAddressState),
+ nic: nic,
+ configs: stack.ndpConfigs,
+ dad: make(map[tcpip.Address]dadState),
+ defaultRouters: make(map[tcpip.Address]defaultRouterState),
+ onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
+ slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
}
// Register supported packet endpoint protocols.
@@ -165,8 +166,17 @@ func (n *NIC) disable() *tcpip.Error {
}
n.mu.Lock()
- defer n.mu.Unlock()
+ err := n.disableLocked()
+ n.mu.Unlock()
+ return err
+}
+// disableLocked disables n.
+//
+// It undoes the work done by enable.
+//
+// n MUST be locked.
+func (n *NIC) disableLocked() *tcpip.Error {
if !n.mu.enabled {
return nil
}
@@ -189,7 +199,7 @@ func (n *NIC) disable() *tcpip.Error {
}
// The NIC may have already left the multicast group.
- if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
return err
}
}
@@ -305,24 +315,33 @@ func (n *NIC) remove() *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
- // Detach from link endpoint, so no packet comes in.
- n.linkEP.Attach(nil)
+ n.disableLocked()
+
+ // TODO(b/151378115): come up with a better way to pick an error than the
+ // first one.
+ var err *tcpip.Error
+
+ // Forcefully leave multicast groups.
+ for nid := range n.mu.mcastJoins {
+ if tempErr := n.leaveGroupLocked(nid.LocalAddress, true /* force */); tempErr != nil && err == nil {
+ err = tempErr
+ }
+ }
// Remove permanent and permanentTentative addresses, so no packet goes out.
- var errs []*tcpip.Error
for nid, ref := range n.mu.endpoints {
switch ref.getKind() {
case permanentTentative, permanent:
- if err := n.removePermanentAddressLocked(nid.LocalAddress); err != nil {
- errs = append(errs, err)
+ if tempErr := n.removePermanentAddressLocked(nid.LocalAddress); tempErr != nil && err == nil {
+ err = tempErr
}
}
}
- if len(errs) > 0 {
- return errs[0]
- }
- return nil
+ // Detach from link endpoint, so no packet comes in.
+ n.linkEP.Attach(nil)
+
+ return err
}
// becomeIPv6Router transitions n into an IPv6 router.
@@ -451,7 +470,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn
cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs))
for _, r := range primaryAddrs {
// If r is not valid for outgoing connections, it is not a valid endpoint.
- if !r.isValidForOutgoing() {
+ if !r.isValidForOutgoingRLocked() {
continue
}
@@ -969,6 +988,7 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
for i, ref := range refs {
if ref == r {
n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ refs[len(refs)-1] = nil
break
}
}
@@ -996,8 +1016,7 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr)
if isIPv6Unicast {
- // If we are removing a tentative IPv6 unicast address, stop
- // DAD.
+ // If we are removing a tentative IPv6 unicast address, stop DAD.
if kind == permanentTentative {
n.mu.ndp.stopDuplicateAddressDetection(addr)
}
@@ -1005,7 +1024,10 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
// If we are removing an address generated via SLAAC, cleanup
// its SLAAC resources and notify the integrator.
if r.configType == slaac {
- n.mu.ndp.cleanupAutoGenAddrResourcesAndNotify(addr)
+ n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: r.ep.PrefixLen(),
+ })
}
}
@@ -1019,9 +1041,12 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
// If we are removing an IPv6 unicast address, leave the solicited-node
// multicast address.
+ //
+ // We ignore the tcpip.ErrBadLocalAddress error because the solicited-node
+ // multicast group may be left by user action.
if isIPv6Unicast {
snmc := header.SolicitedNodeAddr(addr)
- if err := n.leaveGroupLocked(snmc); err != nil {
+ if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
return err
}
}
@@ -1081,26 +1106,31 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
- return n.leaveGroupLocked(addr)
+ return n.leaveGroupLocked(addr, false /* force */)
}
// leaveGroupLocked decrements the count for the given multicast address, and
// when it reaches zero removes the endpoint for this address. n MUST be locked
// before leaveGroupLocked is called.
-func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+//
+// If force is true, then the count for the multicast addres is ignored and the
+// endpoint will be removed immediately.
+func (n *NIC) leaveGroupLocked(addr tcpip.Address, force bool) *tcpip.Error {
id := NetworkEndpointID{addr}
- joins := n.mu.mcastJoins[id]
- switch joins {
- case 0:
+ joins, ok := n.mu.mcastJoins[id]
+ if !ok {
// There are no joins with this address on this NIC.
return tcpip.ErrBadLocalAddress
- case 1:
- // This is the last one, clean up.
- if err := n.removePermanentAddressLocked(addr); err != nil {
- return err
- }
}
- n.mu.mcastJoins[id] = joins - 1
+
+ joins--
+ if force || joins == 0 {
+ // There are no outstanding joins or we are forced to leave, clean up.
+ delete(n.mu.mcastJoins, id)
+ return n.removePermanentAddressLocked(addr)
+ }
+
+ n.mu.mcastJoins[id] = joins
return nil
}
@@ -1113,9 +1143,10 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool {
return joins != 0
}
-func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) {
+func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt PacketBuffer) {
r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
r.RemoteLinkAddress = remotelinkAddr
+
ref.ep.HandlePacket(&r, pkt)
ref.decRef()
}
@@ -1126,7 +1157,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address,
// Note that the ownership of the slice backing vv is retained by the caller.
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
-func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
n.mu.RLock()
enabled := n.mu.enabled
// If the NIC is not yet enabled, don't receive any packets.
@@ -1186,6 +1217,16 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
n.stack.stats.IP.InvalidSourceAddressesReceived.Increment()
return
}
+
+ // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet.
+ if protocol == header.IPv4ProtocolNumber {
+ ipt := n.stack.IPTables()
+ if ok := ipt.Check(Prerouting, pkt); !ok {
+ // iptables is telling us to drop the packet.
+ return
+ }
+ }
+
if ref := n.getRef(protocol, dst); ref != nil {
handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt)
return
@@ -1201,10 +1242,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
return
}
- defer r.Release()
-
- r.LocalLinkAddress = n.linkEP.LinkAddress()
- r.RemoteLinkAddress = remote
// Found a NIC.
n := r.ref.nic
@@ -1213,24 +1250,33 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef()
n.mu.RUnlock()
if ok {
+ r.LocalLinkAddress = n.linkEP.LinkAddress()
+ r.RemoteLinkAddress = remote
r.RemoteAddress = src
// TODO(b/123449044): Update the source NIC as well.
ref.ep.HandlePacket(&r, pkt)
ref.decRef()
- } else {
- // n doesn't have a destination endpoint.
- // Send the packet out of n.
- pkt.Header = buffer.NewPrependableFromView(pkt.Data.First())
- pkt.Data.RemoveFirst()
-
- // TODO(b/128629022): use route.WritePacket.
- if err := n.linkEP.WritePacket(&r, nil /* gso */, protocol, pkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- } else {
- n.stats.Tx.Packets.Increment()
- n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size()))
+ r.Release()
+ return
+ }
+
+ // n doesn't have a destination endpoint.
+ // Send the packet out of n.
+ // TODO(b/128629022): move this logic to route.WritePacket.
+ if ch, err := r.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt)
+ // forwarder will release route.
+ return
}
+ n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
+ r.Release()
+ return
}
+
+ // The link-address resolution finished immediately.
+ n.forwardPacket(&r, protocol, pkt)
+ r.Release()
return
}
@@ -1240,9 +1286,38 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
}
}
+func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
+ // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+
+ firstData := pkt.Data.First()
+ pkt.Data.RemoveFirst()
+
+ if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 {
+ pkt.Header = buffer.NewPrependableFromView(firstData)
+ } else {
+ firstDataLen := len(firstData)
+
+ // pkt.Header should have enough capacity to hold n.linkEP's headers.
+ pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen)
+
+ // TODO(b/151227689): avoid copying the packet when forwarding
+ if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen {
+ panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen))
+ }
+ }
+
+ if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return
+ }
+
+ n.stats.Tx.Packets.Increment()
+ n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size()))
+}
+
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) {
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -1288,7 +1363,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// DeliverTransportControlPacket delivers control packets to the appropriate
// transport protocol endpoint.
-func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) {
state, ok := n.stack.transportProtocols[trans]
if !ok {
return
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index edaee3b86..d672fc157 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -17,7 +17,6 @@ package stack
import (
"testing"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -45,7 +44,7 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
t.FailNow()
}
- nic.DeliverNetworkPacket(nil, "", "", 0, tcpip.PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()})
+ nic.DeliverNetworkPacket(nil, "", "", 0, PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()})
if got := nic.stats.DisabledRx.Packets.Value(); got != 1 {
t.Errorf("got DisabledRx.Packets = %d, want = 1", got)
diff --git a/pkg/tcpip/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index ab24372e7..9505a4e92 100644
--- a/pkg/tcpip/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package tcpip
+package stack
import "gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -55,6 +55,10 @@ type PacketBuffer struct {
LinkHeader buffer.View
NetworkHeader buffer.View
TransportHeader buffer.View
+
+ // Hash is the transport layer hash of this packet. A value of zero
+ // indicates no valid hash has been set.
+ Hash uint32
}
// Clone makes a copy of pk. It clones the Data field, which creates a new
diff --git a/pkg/tcpip/packet_buffer_state.go b/pkg/tcpip/stack/packet_buffer_state.go
index ad3cc24fa..0c6b7924c 100644
--- a/pkg/tcpip/packet_buffer_state.go
+++ b/pkg/tcpip/stack/packet_buffer_state.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package tcpip
+package stack
import "gvisor.dev/gvisor/pkg/tcpip/buffer"
diff --git a/pkg/tcpip/stack/rand.go b/pkg/tcpip/stack/rand.go
new file mode 100644
index 000000000..421fb5c15
--- /dev/null
+++ b/pkg/tcpip/stack/rand.go
@@ -0,0 +1,40 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ mathrand "math/rand"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// lockedRandomSource provides a threadsafe rand.Source.
+type lockedRandomSource struct {
+ mu sync.Mutex
+ src mathrand.Source
+}
+
+func (r *lockedRandomSource) Int63() (n int64) {
+ r.mu.Lock()
+ n = r.src.Int63()
+ r.mu.Unlock()
+ return n
+}
+
+func (r *lockedRandomSource) Seed(seed int64) {
+ r.mu.Lock()
+ r.src.Seed(seed)
+ r.mu.Unlock()
+}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index f9fd8f18f..ac043b722 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -67,12 +67,12 @@ type TransportEndpoint interface {
// this transport endpoint. It sets pkt.TransportHeader.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer)
+ HandlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer)
// HandleControlPacket is called by the stack when new control (e.g.
// ICMP) packets arrive to this transport endpoint.
// HandleControlPacket takes ownership of pkt.
- HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer)
+ HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer)
// Abort initiates an expedited endpoint teardown. It puts the endpoint
// in a closed state and frees all resources associated with it. This
@@ -100,7 +100,7 @@ type RawTransportEndpoint interface {
// layer up.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt tcpip.PacketBuffer)
+ HandlePacket(r *Route, pkt PacketBuffer)
}
// PacketEndpoint is the interface that needs to be implemented by packet
@@ -118,7 +118,7 @@ type PacketEndpoint interface {
// should construct its own ethernet header for applications.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer)
+ HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt PacketBuffer)
}
// TransportProtocol is the interface that needs to be implemented by transport
@@ -150,7 +150,7 @@ type TransportProtocol interface {
// stats purposes only).
//
// HandleUnknownDestinationPacket takes ownership of pkt.
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt PacketBuffer) bool
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -180,7 +180,7 @@ type TransportDispatcher interface {
// pkt.NetworkHeader must be set before calling DeliverTransportPacket.
//
// DeliverTransportPacket takes ownership of pkt.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer)
+ DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer)
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
@@ -189,7 +189,7 @@ type TransportDispatcher interface {
// DeliverTransportControlPacket.
//
// DeliverTransportControlPacket takes ownership of pkt.
- DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer)
+ DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer)
}
// PacketLooping specifies where an outbound packet should be sent.
@@ -242,15 +242,15 @@ type NetworkEndpoint interface {
// WritePacket writes a packet to the given destination address and
// protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have
// already been set.
- WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error
// WritePackets writes packets to the given destination address and
// protocol. pkts must not be zero length.
- WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error)
+ WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error)
// WriteHeaderIncludedPacket writes a packet that includes a network
// header to the given destination address.
- WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error
+ WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
@@ -265,7 +265,7 @@ type NetworkEndpoint interface {
// this network endpoint. It sets pkt.NetworkHeader.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt tcpip.PacketBuffer)
+ HandlePacket(r *Route, pkt PacketBuffer)
// Close is called when the endpoint is reomved from a stack.
Close()
@@ -322,7 +322,7 @@ type NetworkDispatcher interface {
// packets sent via loopback), and won't have the field set.
//
// DeliverNetworkPacket takes ownership of pkt.
- DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer)
+ DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -354,7 +354,7 @@ const (
// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
// ethernet, loopback, raw) and used by network layer protocols to send packets
// out through the implementer's data link endpoint. When a link header exists,
-// it sets each tcpip.PacketBuffer's LinkHeader field before passing it up the
+// it sets each PacketBuffer's LinkHeader field before passing it up the
// stack.
type LinkEndpoint interface {
// MTU is the maximum transmission unit for this endpoint. This is
@@ -385,7 +385,7 @@ type LinkEndpoint interface {
// To participate in transparent bridging, a LinkEndpoint implementation
// should call eth.Encode with header.EthernetFields.SrcAddr set to
// r.LocalLinkAddress if it is provided.
- WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error
// WritePackets writes packets with the given protocol through the
// given route. pkts must not be zero length.
@@ -393,7 +393,7 @@ type LinkEndpoint interface {
// Right now, WritePackets is used only when the software segmentation
// offload is enabled. If it will be used for something else, it may
// require to change syscall filters.
- WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+ WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
// WriteRawPacket writes a packet directly to the link. The packet
// should already have an ethernet header.
@@ -401,6 +401,9 @@ type LinkEndpoint interface {
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
+ //
+ // Attach will be called with a nil dispatcher if the receiver's associated
+ // NIC is being removed.
Attach(dispatcher NetworkDispatcher)
// IsAttached returns whether a NetworkDispatcher is attached to the
@@ -423,7 +426,7 @@ type InjectableLinkEndpoint interface {
LinkEndpoint
// InjectInbound injects an inbound packet.
- InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer)
+ InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer)
// InjectOutbound writes a fully formed outbound packet directly to the
// link.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index f565aafb2..9fbe8a411 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -153,7 +153,7 @@ func (r *Route) IsResolutionRequired() bool {
}
// WritePacket writes the packet through the given route.
-func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error {
if !r.ref.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
@@ -169,7 +169,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.Pack
}
// WritePackets writes the set of packets through the given route.
-func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) {
+func (r *Route) WritePackets(gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) {
if !r.ref.isValidForOutgoing() {
return 0, tcpip.ErrInvalidEndpointState
}
@@ -190,7 +190,7 @@ func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params Network
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (r *Route) WriteHeaderIncludedPacket(pkt tcpip.PacketBuffer) *tcpip.Error {
+func (r *Route) WriteHeaderIncludedPacket(pkt PacketBuffer) *tcpip.Error {
if !r.ref.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index ebb6c5e3b..41398a1b6 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -20,7 +20,9 @@
package stack
import (
+ "bytes"
"encoding/binary"
+ mathrand "math/rand"
"sync/atomic"
"time"
@@ -31,7 +33,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/waiter"
@@ -51,7 +52,7 @@ const (
type transportProtocolState struct {
proto TransportProtocol
- defaultHandler func(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool
+ defaultHandler func(r *Route, id TransportEndpointID, pkt PacketBuffer) bool
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
@@ -428,7 +429,7 @@ type Stack struct {
// tables are the iptables packet filtering and manipulation rules. The are
// protected by tablesMu.`
- tables iptables.IPTables
+ tables IPTables
// resumableEndpoints is a list of endpoints that need to be resumed if the
// stack is being restored.
@@ -462,6 +463,14 @@ type Stack struct {
// opaqueIIDOpts hold the options for generating opaque interface identifiers
// (IIDs) as outlined by RFC 7217.
opaqueIIDOpts OpaqueInterfaceIdentifierOptions
+
+ // forwarder holds the packets that wait for their link-address resolutions
+ // to complete, and forwards them when each resolution is done.
+ forwarder *forwardQueue
+
+ // randomGenerator is an injectable pseudo random generator that can be
+ // used when a random number is required.
+ randomGenerator *mathrand.Rand
}
// UniqueID is an abstract generator of unique identifiers.
@@ -522,9 +531,16 @@ type Options struct {
// this is non-nil.
RawFactory RawFactory
- // OpaqueIIDOpts hold the options for generating opaque interface identifiers
- // (IIDs) as outlined by RFC 7217.
+ // OpaqueIIDOpts hold the options for generating opaque interface
+ // identifiers (IIDs) as outlined by RFC 7217.
OpaqueIIDOpts OpaqueInterfaceIdentifierOptions
+
+ // RandSource is an optional source to use to generate random
+ // numbers. If omitted it defaults to a Source seeded by the data
+ // returned by rand.Read().
+ //
+ // RandSource must be thread-safe.
+ RandSource mathrand.Source
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -551,11 +567,13 @@ type TransportEndpointInfo struct {
RegisterNICID tcpip.NICID
}
-// AddrNetProto unwraps the specified address if it is a V4-mapped V6 address
-// and returns the network protocol number to be used to communicate with the
-// specified address. It returns an error if the passed address is incompatible
-// with the receiver.
-func (e *TransportEndpointInfo) AddrNetProto(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+// AddrNetProtoLocked unwraps the specified address if it is a V4-mapped V6
+// address and returns the network protocol number to be used to communicate
+// with the specified address. It returns an error if the passed address is
+// incompatible with the receiver.
+//
+// Preconditon: the parent endpoint mu must be held while calling this method.
+func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
netProto := e.NetProto
switch len(addr.Addr) {
case header.IPv4AddressSize:
@@ -618,6 +636,13 @@ func New(opts Options) *Stack {
opts.UniqueID = new(uniqueIDGenerator)
}
+ randSrc := opts.RandSource
+ if randSrc == nil {
+ // Source provided by mathrand.NewSource is not thread-safe so
+ // we wrap it in a simple thread-safe version.
+ randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())}
+ }
+
// Make sure opts.NDPConfigs contains valid values only.
opts.NDPConfigs.validate()
@@ -639,6 +664,8 @@ func New(opts Options) *Stack {
uniqueIDGenerator: opts.UniqueID,
ndpDisp: opts.NDPDisp,
opaqueIIDOpts: opts.OpaqueIIDOpts,
+ forwarder: newForwardQueue(),
+ randomGenerator: mathrand.New(randSrc),
}
// Add specified network protocols.
@@ -731,7 +758,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber,
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
-func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, tcpip.PacketBuffer) bool) {
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, PacketBuffer) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
@@ -1694,7 +1721,7 @@ func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool,
}
// IPTables returns the stack's iptables.
-func (s *Stack) IPTables() iptables.IPTables {
+func (s *Stack) IPTables() IPTables {
s.tablesMu.RLock()
t := s.tables
s.tablesMu.RUnlock()
@@ -1702,7 +1729,7 @@ func (s *Stack) IPTables() iptables.IPTables {
}
// SetIPTables sets the stack's iptables.
-func (s *Stack) SetIPTables(ipt iptables.IPTables) {
+func (s *Stack) SetIPTables(ipt IPTables) {
s.tablesMu.Lock()
s.tables = ipt
s.tablesMu.Unlock()
@@ -1812,6 +1839,12 @@ func (s *Stack) Seed() uint32 {
return s.seed
}
+// Rand returns a reference to a pseudo random generator that can be used
+// to generate random numbers as required.
+func (s *Stack) Rand() *mathrand.Rand {
+ return s.randomGenerator
+}
+
func generateRandUint32() uint32 {
b := make([]byte, 4)
if _, err := rand.Read(b); err != nil {
@@ -1819,3 +1852,16 @@ func generateRandUint32() uint32 {
}
return binary.LittleEndian.Uint32(b)
}
+
+func generateRandInt64() int64 {
+ b := make([]byte, 8)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ buf := bytes.NewReader(b)
+ var v int64
+ if err := binary.Read(buf, binary.LittleEndian, &v); err != nil {
+ panic(err)
+ }
+ return v
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index e15db40fb..555fcd92f 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -90,7 +90,7 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
return &f.id
}
-func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
+func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
@@ -126,7 +126,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return f.ep.Capabilities()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
@@ -141,7 +141,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
views[0] = pkt.Header.View()
views = append(views, pkt.Data.Views()...)
- f.HandlePacket(r, tcpip.PacketBuffer{
+ f.HandlePacket(r, stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
}
@@ -153,11 +153,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
panic("not implemented")
}
-func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -255,7 +255,7 @@ type linkEPWithMockedAttach struct {
// Attach implements stack.LinkEndpoint.Attach.
func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) {
l.LinkEndpoint.Attach(d)
- l.attached = true
+ l.attached = d != nil
}
func (l *linkEPWithMockedAttach) isAttached() bool {
@@ -287,7 +287,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet with wrong address is not delivered.
buf[0] = 3
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 0 {
@@ -299,7 +299,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to first endpoint.
buf[0] = 1
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -311,7 +311,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to second endpoint.
buf[0] = 2
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -322,7 +322,7 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is not delivered if protocol number is wrong.
- ep.InjectInbound(fakeNetNumber-1, tcpip.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber-1, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -334,7 +334,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet that is too small is dropped.
buf.CapLength(2)
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -356,7 +356,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
func send(r stack.Route, payload buffer.View) *tcpip.Error {
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
})
@@ -414,7 +414,7 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b
func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) {
t.Helper()
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if got := fakeNet.PacketCount(localAddrByte); got != want {
@@ -566,7 +566,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) {
t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err)
}
if !e.isAttached() {
- t.Fatalf("link endpoint not attached to a network disatcher")
+ t.Fatal("link endpoint not attached to a network dispatcher")
}
})
}
@@ -631,196 +631,240 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) {
checkNIC(false)
}
-func TestRoutesWithDisabledNIC(t *testing.T) {
- const unspecifiedNIC = 0
- const nicID1 = 1
- const nicID2 = 2
-
+func TestRemoveUnknownNIC(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
- ep1 := channel.New(0, defaultMTU, "")
- if err := s.CreateNIC(nicID1, ep1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID)
}
+}
- addr1 := tcpip.Address("\x01")
- if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
- }
+func TestRemoveNIC(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- ep2 := channel.New(0, defaultMTU, "")
- if err := s.CreateNIC(nicID2, ep2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ e := linkEPWithMockedAttach{
+ LinkEndpoint: loopback.New(),
+ }
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- addr2 := tcpip.Address("\x02")
- if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ // NIC should be present in NICInfo and attached to a NetworkDispatcher.
+ allNICInfo := s.NICInfo()
+ if _, ok := allNICInfo[nicID]; !ok {
+ t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
+ }
+ if !e.isAttached() {
+ t.Fatal("link endpoint not attached to a network dispatcher")
}
- // Set a route table that sends all packets with odd destination
- // addresses through the first NIC, and all even destination address
- // through the second one.
- {
- subnet0, err := tcpip.NewSubnet("\x00", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- subnet1, err := tcpip.NewSubnet("\x01", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{
- {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
- {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
- })
+ // Removing a NIC should remove it from NICInfo and e should be detached from
+ // the NetworkDispatcher.
+ if err := s.RemoveNIC(nicID); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID, err)
+ }
+ if nicInfo, ok := s.NICInfo()[nicID]; ok {
+ t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo)
}
+ if e.isAttached() {
+ t.Error("link endpoint for removed NIC still attached to a network dispatcher")
+ }
+}
- // Test routes to odd address.
- testRoute(t, s, unspecifiedNIC, "", "\x05", addr1)
- testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1)
- testRoute(t, s, nicID1, addr1, "\x05", addr1)
+func TestRouteWithDownNIC(t *testing.T) {
+ tests := []struct {
+ name string
+ downFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error
+ upFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error
+ }{
+ {
+ name: "Disabled NIC",
+ downFn: (*stack.Stack).DisableNIC,
+ upFn: (*stack.Stack).EnableNIC,
+ },
+
+ // Once a NIC is removed, it cannot be brought up.
+ {
+ name: "Removed NIC",
+ downFn: (*stack.Stack).RemoveNIC,
+ },
+ }
- // Test routes to even address.
- testRoute(t, s, unspecifiedNIC, "", "\x06", addr2)
- testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2)
- testRoute(t, s, nicID2, addr2, "\x06", addr2)
-
- // Disabling NIC1 should result in no routes to odd addresses. Routes to even
- // addresses should continue to be available as NIC2 is still enabled.
- if err := s.DisableNIC(nicID1); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
- }
- nic1Dst := tcpip.Address("\x05")
- testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
- testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
- testNoRoute(t, s, nicID1, addr1, nic1Dst)
- nic2Dst := tcpip.Address("\x06")
- testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2)
- testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2)
- testRoute(t, s, nicID2, addr2, nic2Dst, addr2)
-
- // Disabling NIC2 should result in no routes to even addresses. No route
- // should be available to any address as routes to odd addresses were made
- // unavailable by disabling NIC1 above.
- if err := s.DisableNIC(nicID2); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
- }
- testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
- testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
- testNoRoute(t, s, nicID1, addr1, nic1Dst)
- testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
- testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
- testNoRoute(t, s, nicID2, addr2, nic2Dst)
-
- // Enabling NIC1 should make routes to odd addresses available again. Routes
- // to even addresses should continue to be unavailable as NIC2 is still
- // disabled.
- if err := s.EnableNIC(nicID1); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
- }
- testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1)
- testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1)
- testRoute(t, s, nicID1, addr1, nic1Dst, addr1)
- testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
- testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
- testNoRoute(t, s, nicID2, addr2, nic2Dst)
-}
-
-func TestRouteWritePacketWithDisabledNIC(t *testing.T) {
const unspecifiedNIC = 0
const nicID1 = 1
const nicID2 = 2
+ const addr1 = tcpip.Address("\x01")
+ const addr2 = tcpip.Address("\x02")
+ const nic1Dst = tcpip.Address("\x05")
+ const nic2Dst = tcpip.Address("\x06")
+
+ setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
+ ep1 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
- ep1 := channel.New(1, defaultMTU, "")
- if err := s.CreateNIC(nicID1, ep1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
+ if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ }
- addr1 := tcpip.Address("\x01")
- if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
- }
+ ep2 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
- ep2 := channel.New(1, defaultMTU, "")
- if err := s.CreateNIC(nicID2, ep2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
- }
+ if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ }
- addr2 := tcpip.Address("\x02")
- if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ // Set a route table that sends all packets with odd destination
+ // addresses through the first NIC, and all even destination address
+ // through the second one.
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
+ {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
+ })
+ }
+
+ return s, ep1, ep2
}
- // Set a route table that sends all packets with odd destination
- // addresses through the first NIC, and all even destination address
- // through the second one.
- {
- subnet0, err := tcpip.NewSubnet("\x00", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- subnet1, err := tcpip.NewSubnet("\x01", "\x01")
- if err != nil {
- t.Fatal(err)
+ // Tests that routes through a down NIC are not used when looking up a route
+ // for a destination.
+ t.Run("Find", func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, _, _ := setup(t)
+
+ // Test routes to odd address.
+ testRoute(t, s, unspecifiedNIC, "", "\x05", addr1)
+ testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1)
+ testRoute(t, s, nicID1, addr1, "\x05", addr1)
+
+ // Test routes to even address.
+ testRoute(t, s, unspecifiedNIC, "", "\x06", addr2)
+ testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2)
+ testRoute(t, s, nicID2, addr2, "\x06", addr2)
+
+ // Bringing NIC1 down should result in no routes to odd addresses. Routes to
+ // even addresses should continue to be available as NIC2 is still up.
+ if err := test.downFn(s, nicID1); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID1, err)
+ }
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2)
+ testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2)
+ testRoute(t, s, nicID2, addr2, nic2Dst, addr2)
+
+ // Bringing NIC2 down should result in no routes to even addresses. No
+ // route should be available to any address as routes to odd addresses
+ // were made unavailable by bringing NIC1 down above.
+ if err := test.downFn(s, nicID2); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID2, err)
+ }
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+
+ if upFn := test.upFn; upFn != nil {
+ // Bringing NIC1 up should make routes to odd addresses available
+ // again. Routes to even addresses should continue to be unavailable
+ // as NIC2 is still down.
+ if err := upFn(s, nicID1); err != nil {
+ t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
+ }
+ testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1)
+ testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1)
+ testRoute(t, s, nicID1, addr1, nic1Dst, addr1)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+ }
+ })
}
- s.SetRouteTable([]tcpip.Route{
- {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
- {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
- })
- }
+ })
- nic1Dst := tcpip.Address("\x05")
- r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err)
- }
- defer r1.Release()
+ // Tests that writing a packet using a Route through a down NIC fails.
+ t.Run("WritePacket", func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, ep1, ep2 := setup(t)
- nic2Dst := tcpip.Address("\x06")
- r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err)
- }
- defer r2.Release()
+ r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err)
+ }
+ defer r1.Release()
- // If we failed to get routes r1 or r2, we cannot proceed with the test.
- if t.Failed() {
- t.FailNow()
- }
+ r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err)
+ }
+ defer r2.Release()
- buf := buffer.View([]byte{1})
- testSend(t, r1, ep1, buf)
- testSend(t, r2, ep2, buf)
+ // If we failed to get routes r1 or r2, we cannot proceed with the test.
+ if t.Failed() {
+ t.FailNow()
+ }
- // Writes with Routes that use the disabled NIC1 should fail.
- if err := s.DisableNIC(nicID1); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
- }
- testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
- testSend(t, r2, ep2, buf)
+ buf := buffer.View([]byte{1})
+ testSend(t, r1, ep1, buf)
+ testSend(t, r2, ep2, buf)
- // Writes with Routes that use the disabled NIC2 should fail.
- if err := s.DisableNIC(nicID2); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
- }
- testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
- testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+ // Writes with Routes that use NIC1 after being brought down should fail.
+ if err := test.downFn(s, nicID1); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID1, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testSend(t, r2, ep2, buf)
- // Writes with Routes that use the re-enabled NIC1 should succeed.
- // TODO(b/147015577): Should we instead completely invalidate all Routes that
- // were bound to a disabled NIC at some point?
- if err := s.EnableNIC(nicID1); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
- }
- testSend(t, r1, ep1, buf)
- testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+ // Writes with Routes that use NIC2 after being brought down should fail.
+ if err := test.downFn(s, nicID2); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID2, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+
+ if upFn := test.upFn; upFn != nil {
+ // Writes with Routes that use NIC1 after being brought up should
+ // succeed.
+ //
+ // TODO(b/147015577): Should we instead completely invalidate all
+ // Routes that were bound to a NIC that was brought down at some
+ // point?
+ if err := upFn(s, nicID1); err != nil {
+ t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
+ }
+ testSend(t, r1, ep1, buf)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+ }
+ })
+ }
+ })
}
func TestRoutes(t *testing.T) {
@@ -2213,7 +2257,7 @@ func TestNICStats(t *testing.T) {
// Send a packet to address 1.
buf := buffer.NewView(30)
- ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
@@ -2240,56 +2284,84 @@ func TestNICStats(t *testing.T) {
}
func TestNICForwarding(t *testing.T) {
- // Create a stack with the fake network protocol, two NICs, each with
- // an address.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- s.SetForwarding(true)
+ const nicID1 = 1
+ const nicID2 = 2
+ const dstAddr = tcpip.Address("\x03")
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatal("CreateNIC #1 failed:", err)
- }
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress #1 failed:", err)
+ tests := []struct {
+ name string
+ headerLen uint16
+ }{
+ {
+ name: "Zero header length",
+ },
+ {
+ name: "Non-zero header length",
+ headerLen: 16,
+ },
}
- ep2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, ep2); err != nil {
- t.Fatal("CreateNIC #2 failed:", err)
- }
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress #2 failed:", err)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ s.SetForwarding(true)
- // Route all packets to address 3 to NIC 2.
- {
- subnet, err := tcpip.NewSubnet("\x03", "\xff")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}})
- }
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err)
+ }
- // Send a packet to address 3.
- buf := buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
+ ep2 := channelLinkWithHeaderLength{
+ Endpoint: channel.New(10, defaultMTU, ""),
+ headerLength: test.headerLen,
+ }
+ if err := s.CreateNIC(nicID2, &ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+ if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil {
+ t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err)
+ }
- if _, ok := ep2.Read(); !ok {
- t.Fatal("Packet not forwarded")
- }
+ // Route all packets to dstAddr to NIC 2.
+ {
+ subnet, err := tcpip.NewSubnet(dstAddr, "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}})
+ }
- // Test that forwarding increments Tx stats correctly.
- if got, want := s.NICInfo()[2].Stats.Tx.Packets.Value(), uint64(1); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
- }
+ // Send a packet to dstAddr.
+ buf := buffer.NewView(30)
+ buf[0] = dstAddr[0]
+ ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
- if got, want := s.NICInfo()[2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
- t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
+ pkt, ok := ep2.Read()
+ if !ok {
+ t.Fatal("packet not forwarded")
+ }
+
+ // Test that the link's MaxHeaderLength is honoured.
+ if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want {
+ t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want)
+ }
+
+ // Test that forwarding increments Tx stats correctly.
+ if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want {
+ t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
+ t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
+ }
+ })
}
}
@@ -3010,6 +3082,50 @@ func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) {
}
}
+// TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval tests that removing an IPv6
+// address after leaving its solicited node multicast address does not result in
+// an error.
+func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ })
+ e := channel.New(10, 1280, linkAddr1)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err)
+ }
+
+ // The NIC should have joined addr1's solicited node multicast address.
+ snmc := header.SolicitedNodeAddr(addr1)
+ in, err := s.IsInGroup(nicID, snmc)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err)
+ }
+ if !in {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, snmc)
+ }
+
+ if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, snmc); err != nil {
+ t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, snmc, err)
+ }
+ in, err = s.IsInGroup(nicID, snmc)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err)
+ }
+ if in {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, snmc)
+ }
+
+ if err := s.RemoveAddress(nicID, addr1); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err)
+ }
+}
+
func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) {
const nicID = 1
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 778c0a4d6..c55e3e8bc 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -15,9 +15,9 @@
package stack
import (
+ "container/heap"
"fmt"
"math/rand"
- "sort"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -85,7 +85,7 @@ func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) {
+func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) {
epsByNic.mu.RLock()
mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
@@ -116,7 +116,7 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) {
epsByNic.mu.RLock()
defer epsByNic.mu.RUnlock()
@@ -141,16 +141,17 @@ func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto t
epsByNic.mu.Lock()
defer epsByNic.mu.Unlock()
- if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok {
- // There was already a bind.
- return multiPortEp.singleRegisterEndpoint(t, reusePort)
+ multiPortEp, ok := epsByNic.endpoints[bindToDevice]
+ if !ok {
+ multiPortEp = &multiPortEndpoint{
+ demux: d,
+ netProto: netProto,
+ transProto: transProto,
+ reuse: reusePort,
+ }
+ epsByNic.endpoints[bindToDevice] = multiPortEp
}
- // This is a new binding.
- multiPortEp := &multiPortEndpoint{demux: d, netProto: netProto, transProto: transProto}
- multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
- multiPortEp.reuse = reusePort
- epsByNic.endpoints[bindToDevice] = multiPortEp
return multiPortEp.singleRegisterEndpoint(t, reusePort)
}
@@ -183,7 +184,7 @@ type transportDemuxer struct {
// the dispatcher to delivery packets to the QueuePacket method instead of
// calling HandlePacket directly on the endpoint.
type queuedTransportProtocol interface {
- QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt tcpip.PacketBuffer)
+ QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt PacketBuffer)
}
func newTransportDemuxer(stack *Stack) *transportDemuxer {
@@ -222,6 +223,35 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
return nil
}
+type transportEndpointHeap []TransportEndpoint
+
+var _ heap.Interface = (*transportEndpointHeap)(nil)
+
+func (h *transportEndpointHeap) Len() int {
+ return len(*h)
+}
+
+func (h *transportEndpointHeap) Less(i, j int) bool {
+ return (*h)[i].UniqueID() < (*h)[j].UniqueID()
+}
+
+func (h *transportEndpointHeap) Swap(i, j int) {
+ (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
+}
+
+func (h *transportEndpointHeap) Push(x interface{}) {
+ *h = append(*h, x.(TransportEndpoint))
+}
+
+func (h *transportEndpointHeap) Pop() interface{} {
+ old := *h
+ n := len(old)
+ x := old[n-1]
+ old[n-1] = nil
+ *h = old[:n-1]
+ return x
+}
+
// multiPortEndpoint is a container for TransportEndpoints which are bound to
// the same pair of address and port. endpointsArr always has at least one
// element.
@@ -237,15 +267,14 @@ type multiPortEndpoint struct {
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
- endpointsArr []TransportEndpoint
- endpointsMap map[TransportEndpoint]int
+ endpoints transportEndpointHeap
// reuse indicates if more than one endpoint is allowed.
reuse bool
}
func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
ep.mu.RLock()
- eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
+ eps := append([]TransportEndpoint(nil), ep.endpoints...)
ep.mu.RUnlock()
return eps
}
@@ -262,8 +291,8 @@ func reciprocalScale(val, n uint32) uint32 {
// ports then uses it to select a socket. In this case, all packets from one
// address will be sent to same endpoint.
func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
- if len(mpep.endpointsArr) == 1 {
- return mpep.endpointsArr[0]
+ if len(mpep.endpoints) == 1 {
+ return mpep.endpoints[0]
}
payload := []byte{
@@ -279,29 +308,26 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
h.Write([]byte(id.RemoteAddress))
hash := h.Sum32()
- idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr)))
- return mpep.endpointsArr[idx]
+ idx := reciprocalScale(hash, uint32(len(mpep.endpoints)))
+ return mpep.endpoints[idx]
}
-func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) {
+func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt PacketBuffer) {
ep.mu.RLock()
queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
- for i, endpoint := range ep.endpointsArr {
- // HandlePacket takes ownership of pkt, so each endpoint needs
- // its own copy except for the final one.
- if i == len(ep.endpointsArr)-1 {
- if mustQueue {
- queuedProtocol.QueuePacket(r, endpoint, id, pkt)
- break
- }
- endpoint.HandlePacket(r, id, pkt)
- break
- }
+ // HandlePacket takes ownership of pkt, so each endpoint needs
+ // its own copy except for the final one.
+ for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] {
if mustQueue {
queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone())
- continue
+ } else {
+ endpoint.HandlePacket(r, id, pkt.Clone())
}
- endpoint.HandlePacket(r, id, pkt.Clone())
+ }
+ if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue {
+ queuedProtocol.QueuePacket(r, endpoint, id, pkt)
+ } else {
+ endpoint.HandlePacket(r, id, pkt)
}
ep.mu.RUnlock() // Don't use defer for performance reasons.
}
@@ -312,26 +338,15 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePo
ep.mu.Lock()
defer ep.mu.Unlock()
- if len(ep.endpointsArr) > 0 {
+ if len(ep.endpoints) != 0 {
// If it was previously bound, we need to check if we can bind again.
if !ep.reuse || !reusePort {
return tcpip.ErrPortInUse
}
}
- // A new endpoint is added into endpointsArr and its index there is saved in
- // endpointsMap. This will allow us to remove endpoint from the array fast.
- ep.endpointsMap[t] = len(ep.endpointsArr)
- ep.endpointsArr = append(ep.endpointsArr, t)
+ heap.Push(&ep.endpoints, t)
- // ep.endpointsArr is sorted by endpoint unique IDs, so that endpoints
- // can be restored in the same order.
- sort.Slice(ep.endpointsArr, func(i, j int) bool {
- return ep.endpointsArr[i].UniqueID() < ep.endpointsArr[j].UniqueID()
- })
- for i, e := range ep.endpointsArr {
- ep.endpointsMap[e] = i
- }
return nil
}
@@ -340,21 +355,13 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
ep.mu.Lock()
defer ep.mu.Unlock()
- idx, ok := ep.endpointsMap[t]
- if !ok {
- return false
- }
- delete(ep.endpointsMap, t)
- l := len(ep.endpointsArr)
- if l > 1 {
- // The last endpoint in endpointsArr is moved instead of the deleted one.
- lastEp := ep.endpointsArr[l-1]
- ep.endpointsArr[idx] = lastEp
- ep.endpointsMap[lastEp] = idx
- ep.endpointsArr = ep.endpointsArr[0 : l-1]
- return false
+ for i, endpoint := range ep.endpoints {
+ if endpoint == t {
+ heap.Remove(&ep.endpoints, i)
+ break
+ }
}
- return true
+ return len(ep.endpoints) == 0
}
func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
@@ -371,17 +378,14 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
eps.mu.Lock()
defer eps.mu.Unlock()
- if epsByNic, ok := eps.endpoints[id]; ok {
- // There was already a binding.
- return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
- }
-
- // This is a new binding.
- epsByNic := &endpointsByNic{
- endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
- seed: rand.Uint32(),
+ epsByNic, ok := eps.endpoints[id]
+ if !ok {
+ epsByNic = &endpointsByNic{
+ endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
+ seed: rand.Uint32(),
+ }
+ eps.endpoints[id] = epsByNic
}
- eps.endpoints[id] = epsByNic
return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
}
@@ -396,84 +400,60 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN
}
}
-var loopbackSubnet = func() tcpip.Subnet {
- sn, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00")
- if err != nil {
- panic(err)
- }
- return sn
-}()
-
// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if
// the packet no longer needs to be handled.
-func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
}
- eps.mu.RLock()
-
- // Determine which transport endpoint or endpoints to deliver this packet to.
// If the packet is a UDP broadcast or multicast, then find all matching
- // transport endpoints. If the packet is a TCP packet with a non-unicast
- // source or destination address, then do nothing further and instruct
- // the caller to do the same.
- var destEps []*endpointsByNic
- switch protocol {
- case header.UDPProtocolNumber:
- if isMulticastOrBroadcast(id.LocalAddress) {
- destEps = d.findAllEndpointsLocked(eps, id)
- break
- }
-
- if ep := d.findEndpointLocked(eps, id); ep != nil {
- destEps = append(destEps, ep)
+ // transport endpoints.
+ if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
+ eps.mu.RLock()
+ destEPs := d.findAllEndpointsLocked(eps, id)
+ eps.mu.RUnlock()
+ // Fail if we didn't find at least one matching transport endpoint.
+ if len(destEPs) == 0 {
+ r.Stats().UDP.UnknownPortErrors.Increment()
+ return false
}
-
- case header.TCPProtocolNumber:
- if !(isUnicast(r.LocalAddress) && isUnicast(r.RemoteAddress)) {
- // TCP can only be used to communicate between a single
- // source and a single destination; the addresses must
- // be unicast.
- eps.mu.RUnlock()
- r.Stats().TCP.InvalidSegmentsReceived.Increment()
- return true
+ // handlePacket takes ownership of pkt, so each endpoint needs its own
+ // copy except for the final one.
+ for _, ep := range destEPs[:len(destEPs)-1] {
+ ep.handlePacket(r, id, pkt.Clone())
}
+ destEPs[len(destEPs)-1].handlePacket(r, id, pkt)
+ return true
+ }
- fallthrough
-
- default:
- if ep := d.findEndpointLocked(eps, id); ep != nil {
- destEps = append(destEps, ep)
- }
+ // If the packet is a TCP packet with a non-unicast source or destination
+ // address, then do nothing further and instruct the caller to do the same.
+ if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) {
+ // TCP can only be used to communicate between a single source and a
+ // single destination; the addresses must be unicast.
+ r.Stats().TCP.InvalidSegmentsReceived.Increment()
+ return true
}
+ eps.mu.RLock()
+ ep := d.findEndpointLocked(eps, id)
eps.mu.RUnlock()
-
- // Fail if we didn't find at least one matching transport endpoint.
- if len(destEps) == 0 {
- // UDP packet could not be delivered to an unknown destination port.
+ if ep == nil {
if protocol == header.UDPProtocolNumber {
r.Stats().UDP.UnknownPortErrors.Increment()
}
return false
}
-
- // HandlePacket takes ownership of pkt, so each endpoint needs its own
- // copy except for the final one.
- for _, ep := range destEps[:len(destEps)-1] {
- ep.handlePacket(r, id, pkt.Clone())
- }
- destEps[len(destEps)-1].handlePacket(r, id, pkt)
-
+ ep.handlePacket(r, id, pkt)
return true
}
// deliverRawPacket attempts to deliver the given packet and returns whether it
// was delivered successfully.
-func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) bool {
+func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
@@ -497,7 +477,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
// deliverControlPacket attempts to deliver the given control packet. Returns
// true if it found an endpoint, false otherwise.
-func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{net, trans}]
if !ok {
return false
@@ -519,11 +499,17 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco
return true
}
-func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic {
- var matchedEPs []*endpointsByNic
+// iterEndpointsLocked yields all endpointsByNic in eps that match id, in
+// descending order of match quality. If a call to yield returns false,
+// iterEndpointsLocked stops iteration and returns immediately.
+//
+// Preconditions: eps.mu must be locked.
+func (d *transportDemuxer) iterEndpointsLocked(eps *transportEndpoints, id TransportEndpointID, yield func(*endpointsByNic) bool) {
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
- matchedEPs = append(matchedEPs, ep)
+ if !yield(ep) {
+ return
+ }
}
// Try to find a match with the id minus the local address.
@@ -531,7 +517,9 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id Tr
nid.LocalAddress = ""
if ep, ok := eps.endpoints[nid]; ok {
- matchedEPs = append(matchedEPs, ep)
+ if !yield(ep) {
+ return
+ }
}
// Try to find a match with the id minus the remote part.
@@ -539,14 +527,26 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id Tr
nid.RemoteAddress = ""
nid.RemotePort = 0
if ep, ok := eps.endpoints[nid]; ok {
- matchedEPs = append(matchedEPs, ep)
+ if !yield(ep) {
+ return
+ }
}
// Try to find a match with only the local port.
nid.LocalAddress = ""
if ep, ok := eps.endpoints[nid]; ok {
- matchedEPs = append(matchedEPs, ep)
+ if !yield(ep) {
+ return
+ }
}
+}
+
+func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic {
+ var matchedEPs []*endpointsByNic
+ d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool {
+ matchedEPs = append(matchedEPs, ep)
+ return true
+ })
return matchedEPs
}
@@ -584,10 +584,12 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
// findEndpointLocked returns the endpoint that most closely matches the given
// id.
func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic {
- if matchedEPs := d.findAllEndpointsLocked(eps, id); len(matchedEPs) > 0 {
- return matchedEPs[0]
- }
- return nil
+ var matchedEP *endpointsByNic
+ d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool {
+ matchedEP = ep
+ return false
+ })
+ return matchedEP
}
// registerRawEndpoint registers the given endpoint with the dispatcher such
@@ -601,8 +603,8 @@ func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNum
}
eps.mu.Lock()
- defer eps.mu.Unlock()
eps.rawEndpoints = append(eps.rawEndpoints, ep)
+ eps.mu.Unlock()
return nil
}
@@ -616,13 +618,16 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
}
eps.mu.Lock()
- defer eps.mu.Unlock()
for i, rawEP := range eps.rawEndpoints {
if rawEP == ep {
- eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...)
- return
+ lastIdx := len(eps.rawEndpoints) - 1
+ eps.rawEndpoints[i] = eps.rawEndpoints[lastIdx]
+ eps.rawEndpoints[lastIdx] = nil
+ eps.rawEndpoints = eps.rawEndpoints[:lastIdx]
+ break
}
}
+ eps.mu.Unlock()
}
func isMulticastOrBroadcast(addr tcpip.Address) bool {
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 5e9237de9..84311bcc8 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -150,7 +150,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -167,8 +167,18 @@ func TestTransportDemuxerRegister(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
- if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want {
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ tEP, ok := ep.(stack.TransportEndpoint)
+ if !ok {
+ t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
+ }
+ if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, false, 0), test.want; got != want {
t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want)
}
})
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 5d1da2f8b..8ca9ac3cf 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -87,7 +86,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
if err != nil {
return 0, nil, err
}
- if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: buffer.View(v).ToVectorisedView(),
}); err != nil {
@@ -214,7 +213,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ tcpip.PacketBuffer) {
+func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ stack.PacketBuffer) {
// Increment the number of received packets.
f.proto.packetCount++
if f.acceptQueue != nil {
@@ -231,7 +230,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE
}
}
-func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, tcpip.PacketBuffer) {
+func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, stack.PacketBuffer) {
// Increment the number of received control packets.
f.proto.controlCount++
}
@@ -242,8 +241,8 @@ func (f *fakeTransportEndpoint) State() uint32 {
func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {}
-func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) {
- return iptables.IPTables{}, nil
+func (f *fakeTransportEndpoint) IPTables() (stack.IPTables, error) {
+ return stack.IPTables{}, nil
}
func (f *fakeTransportEndpoint) Resume(*stack.Stack) {}
@@ -288,7 +287,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool {
return true
}
@@ -368,7 +367,7 @@ func TestTransportReceive(t *testing.T) {
// Make sure packet with wrong protocol is not delivered.
buf[0] = 1
buf[2] = 0
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.packetCount != 0 {
@@ -379,7 +378,7 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 3
buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.packetCount != 0 {
@@ -390,7 +389,7 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 2
buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.packetCount != 1 {
@@ -445,7 +444,7 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 0
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = 0
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.controlCount != 0 {
@@ -456,7 +455,7 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 3
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.controlCount != 0 {
@@ -467,7 +466,7 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 2
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.controlCount != 1 {
@@ -622,7 +621,7 @@ func TestTransportForwarding(t *testing.T) {
req[0] = 1
req[1] = 3
req[2] = byte(fakeTransNumber)
- ep2.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep2.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: req.ToVectorisedView(),
})
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index ac18ec5b1..9ce625c17 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -31,7 +31,6 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 426da1ee6..613b12ead 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -135,7 +134,7 @@ func (e *endpoint) Close() {
func (e *endpoint) ModerateRecvBuf(copied int) {}
// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (iptables.IPTables, error) {
+func (e *endpoint) IPTables() (stack.IPTables, error) {
return e.stack.IPTables(), nil
}
@@ -291,15 +290,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
nicID = e.BindNICID
}
- toCopy := *to
- to = &toCopy
- netProto, err := e.checkV4Mapped(to)
+ dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
return 0, nil, err
}
- // Find the enpoint.
- r, err := e.stack.FindRoute(nicID, e.BindAddr, to.Addr, netProto, false /* multicastLoop */)
+ // Find the endpoint.
+ r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */)
if err != nil {
return 0, nil, err
}
@@ -443,7 +440,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: data.ToVectorisedView(),
TransportHeader: buffer.View(icmpv4),
@@ -473,20 +470,21 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: dataVV,
TransportHeader: buffer.View(icmpv6),
})
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, false /* v6only */)
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */)
if err != nil {
- return 0, err
+ return tcpip.FullAddress{}, 0, err
}
- *addr = unwrapped
- return netProto, nil
+ return unwrapped, netProto, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect.
@@ -517,7 +515,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
@@ -630,7 +628,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
@@ -734,7 +732,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
// Only accept echo replies.
switch e.NetProto {
case header.IPv4ProtocolNumber:
@@ -796,7 +794,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
}
// State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 113d92901..3c47692b2 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -104,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
+func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool {
return true
}
diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD
index d22de6b26..b989b1209 100644
--- a/pkg/tcpip/transport/packet/BUILD
+++ b/pkg/tcpip/transport/packet/BUILD
@@ -31,7 +31,6 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/waiter",
],
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 5722815e9..df49d0995 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -29,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -76,6 +75,7 @@ type endpoint struct {
sndBufSize int
closed bool
stats tcpip.TransportEndpointStats `state:"nosave"`
+ bound bool
}
// NewEndpoint returns a new packet endpoint.
@@ -99,8 +99,8 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
}
// Abort implements stack.TransportEndpoint.Abort.
-func (e *endpoint) Abort() {
- e.Close()
+func (ep *endpoint) Abort() {
+ ep.Close()
}
// Close implements tcpip.Endpoint.Close.
@@ -125,6 +125,7 @@ func (ep *endpoint) Close() {
}
ep.closed = true
+ ep.bound = false
ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -132,7 +133,7 @@ func (ep *endpoint) Close() {
func (ep *endpoint) ModerateRecvBuf(copied int) {}
// IPTables implements tcpip.Endpoint.IPTables.
-func (ep *endpoint) IPTables() (iptables.IPTables, error) {
+func (ep *endpoint) IPTables() (stack.IPTables, error) {
return ep.stack.IPTables(), nil
}
@@ -216,7 +217,24 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
// sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex."
// - packet(7).
- return tcpip.ErrNotSupported
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if ep.bound {
+ return tcpip.ErrAlreadyBound
+ }
+
+ // Unregister endpoint with all the nics.
+ ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+
+ // Bind endpoint to receive packets from specific interface.
+ if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
+ return err
+ }
+
+ ep.bound = true
+
+ return nil
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
@@ -280,7 +298,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// HandlePacket implements stack.PacketEndpoint.HandlePacket.
-func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
ep.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index c9baf4600..2eab09088 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -32,7 +32,6 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/packet",
"//pkg/waiter",
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 2ef5fac76..536dafd1e 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -30,7 +30,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -161,7 +160,7 @@ func (e *endpoint) Close() {
func (e *endpoint) ModerateRecvBuf(copied int) {}
// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (iptables.IPTables, error) {
+func (e *endpoint) IPTables() (stack.IPTables, error) {
return e.stack.IPTables(), nil
}
@@ -342,7 +341,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
switch e.NetProto {
case header.IPv4ProtocolNumber:
if !e.associated {
- if err := route.WriteHeaderIncludedPacket(tcpip.PacketBuffer{
+ if err := route.WriteHeaderIncludedPacket(stack.PacketBuffer{
Data: buffer.View(payloadBytes).ToVectorisedView(),
}); err != nil {
return 0, nil, err
@@ -350,7 +349,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
break
}
hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
- if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: buffer.View(payloadBytes).ToVectorisedView(),
}); err != nil {
@@ -574,7 +573,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
-func (e *endpoint) HandlePacket(route *stack.Route, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandlePacket(route *stack.Route, pkt stack.PacketBuffer) {
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 272e8f570..7f94f9646 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -32,6 +32,7 @@ go_library(
srcs = [
"accept.go",
"connect.go",
+ "connect_unsafe.go",
"cubic.go",
"cubic_state.go",
"dispatcher.go",
@@ -65,12 +66,10 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
- "//pkg/tmutex",
"//pkg/waiter",
"@com_github_google_btree//:go_default_library",
],
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 13e383ffc..375ca21f6 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -221,7 +221,8 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
}
// createConnectingEndpoint creates a new endpoint in a connecting state, with
-// the connection parameters given by the arguments.
+// the connection parameters given by the arguments. The endpoint is returned
+// with n.mu held.
func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
@@ -236,27 +237,13 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
n.rcvBufSize = int(l.rcvWnd)
n.amss = mssForRoute(&n.route)
+ n.setEndpointState(StateConnecting)
n.maybeEnableTimestamp(rcvdSynOpts)
n.maybeEnableSACKPermitted(rcvdSynOpts)
n.initGSO()
- // Now inherit any socket options that should be inherited from the
- // listening endpoint.
- // In case of Forwarder listenEP will be nil and hence this check.
- if l.listenEP != nil {
- l.listenEP.propagateInheritableOptions(n)
- }
-
- // Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil {
- n.Close()
- return nil, err
- }
-
- n.isRegistered = true
-
// Create sender and receiver.
//
// The receiver at least temporarily has a zero receive window scale,
@@ -268,11 +255,27 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
// window to grow to a really large value.
n.rcvAutoParams.prevCopied = n.initialReceiveWindow()
+ // Lock the endpoint before registering to ensure that no out of
+ // band changes are possible due to incoming packets etc till
+ // the endpoint is done initializing.
+ n.mu.Lock()
+
+ // Register new endpoint so that packets are routed to it.
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil {
+ n.mu.Unlock()
+ n.Close()
+ return nil, err
+ }
+
+ n.isRegistered = true
+
return n, nil
}
// createEndpointAndPerformHandshake creates a new endpoint in connected state
// and then performs the TCP 3-way handshake.
+//
+// The new endpoint is returned with e.mu held.
func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
@@ -288,9 +291,25 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
l.listenEP.mu.Lock()
if l.listenEP.EndpointState() != StateListen {
l.listenEP.mu.Unlock()
+ // Ensure we release any registrations done by the newly
+ // created endpoint.
+ ep.mu.Unlock()
+ ep.Close()
+
+ // Wake up any waiters. This is strictly not required normally
+ // as a socket that was never accepted can't really have any
+ // registered waiters except when stack.Wait() is called which
+ // waits for all registered endpoints to stop and expects an
+ // EventHUp.
+ ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
return nil, tcpip.ErrConnectionAborted
}
l.addPendingEndpoint(ep)
+
+ // Propagate any inheritable options from the listening endpoint
+ // to the newly created endpoint.
+ l.listenEP.propagateInheritableOptionsLocked(ep)
+
deferAccept = l.listenEP.deferAccept
l.listenEP.mu.Unlock()
}
@@ -298,6 +317,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// Perform the 3-way handshake.
h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
+ ep.mu.Unlock()
ep.Close()
// Wake up any waiters. This is strictly not required normally
// as a socket that was never accepted can't really have any
@@ -311,9 +331,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
}
return nil, err
}
- ep.mu.Lock()
ep.isConnectNotified = true
- ep.mu.Unlock()
// Update the receive window scaling. We can't do it before the
// handshake because it's possible that the peer doesn't support window
@@ -347,30 +365,38 @@ func (l *listenContext) closeAllPendingEndpoints() {
}
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
-// endpoint has transitioned out of the listen state, the new endpoint is closed
-// instead.
+// endpoint has transitioned out of the listen state (acceptedChan is nil),
+// the new endpoint is closed instead.
func (e *endpoint) deliverAccepted(n *endpoint) {
e.mu.Lock()
- state := e.EndpointState()
e.pendingAccepted.Add(1)
- defer e.pendingAccepted.Done()
- acceptedChan := e.acceptedChan
e.mu.Unlock()
+ defer e.pendingAccepted.Done()
- if state == StateListen {
- acceptedChan <- n
- e.waiterQueue.Notify(waiter.EventIn)
- } else {
- n.Close()
+ e.acceptMu.Lock()
+ for {
+ if e.acceptedChan == nil {
+ e.acceptMu.Unlock()
+ n.Close()
+ return
+ }
+ select {
+ case e.acceptedChan <- n:
+ e.acceptMu.Unlock()
+ e.waiterQueue.Notify(waiter.EventIn)
+ return
+ default:
+ e.acceptCond.Wait()
+ }
}
}
-// propagateInheritableOptions propagates any options set on the listening
+// propagateInheritableOptionsLocked propagates any options set on the listening
// endpoint to the newly created endpoint.
-func (e *endpoint) propagateInheritableOptions(n *endpoint) {
- e.mu.Lock()
+//
+// Precondition: e.mu and n.mu must be held.
+func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
n.userTimeout = e.userTimeout
- e.mu.Unlock()
}
// handleSynSegment is called in its own goroutine once the listening endpoint
@@ -381,7 +407,11 @@ func (e *endpoint) propagateInheritableOptions(n *endpoint) {
// cookies to accept connections.
func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
defer decSynRcvdCount()
- defer e.decSynRcvdCount()
+ defer func() {
+ e.mu.Lock()
+ e.decSynRcvdCount()
+ e.mu.Unlock()
+ }()
defer s.decRef()
n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{})
@@ -398,30 +428,24 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
}
func (e *endpoint) incSynRcvdCount() bool {
- e.mu.Lock()
- if e.synRcvdCount >= cap(e.acceptedChan) {
- e.mu.Unlock()
- return false
+ e.acceptMu.Lock()
+ canInc := e.synRcvdCount < cap(e.acceptedChan)
+ e.acceptMu.Unlock()
+ if canInc {
+ e.synRcvdCount++
}
- e.synRcvdCount++
- e.mu.Unlock()
- return true
+ return canInc
}
func (e *endpoint) decSynRcvdCount() {
- e.mu.Lock()
e.synRcvdCount--
- e.mu.Unlock()
}
func (e *endpoint) acceptQueueIsFull() bool {
- e.mu.Lock()
- if l, c := len(e.acceptedChan)+e.synRcvdCount, cap(e.acceptedChan); l >= c {
- e.mu.Unlock()
- return true
- }
- e.mu.Unlock()
- return false
+ e.acceptMu.Lock()
+ full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan)
+ e.acceptMu.Unlock()
+ return full
}
// handleListenSegment is called when a listening endpoint receives a segment
@@ -431,7 +455,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// RFC 793 section 3.4 page 35 (figure 12) outlines that a RST
// must be sent in response to a SYN-ACK while in the listen
// state to prevent completing a handshake from an old SYN.
- e.sendTCP(&s.route, s.id, buffer.VectorisedView{}, e.ttl, e.sendTOS, header.TCPFlagRst, s.ackNumber, 0, 0, nil, nil)
+ e.sendTCP(&s.route, tcpFields{
+ id: s.id,
+ ttl: e.ttl,
+ tos: e.sendTOS,
+ flags: header.TCPFlagRst,
+ seq: s.ackNumber,
+ ack: 0,
+ rcvWnd: 0,
+ }, buffer.VectorisedView{}, nil)
return
}
@@ -479,7 +511,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
TSEcr: opts.TSVal,
MSS: mssForRoute(&s.route),
}
- e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
+ e.sendSynTCP(&s.route, tcpFields{
+ id: s.id,
+ ttl: e.ttl,
+ tos: e.sendTOS,
+ flags: header.TCPFlagSyn | header.TCPFlagAck,
+ seq: cookie,
+ ack: s.sequenceNumber + 1,
+ rcvWnd: ctx.rcvWnd,
+ }, synOpts)
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
}
@@ -558,6 +598,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
return
}
+ // Propagate any inheritable options from the listening endpoint
+ // to the newly created endpoint.
+ e.propagateInheritableOptionsLocked(n)
+
// clear the tsOffset for the newly created
// endpoint as the Timestamp was already
// randomly offset when the original SYN-ACK was
@@ -592,14 +636,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Lock()
v6only := e.v6only
- e.mu.Unlock()
ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto)
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
// handleSynSegment() from attempting to queue new connections
// to the endpoint.
- e.mu.Lock()
e.setEndpointState(StateClose)
// close any endpoints in SYN-RCVD state.
@@ -621,7 +663,10 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
s.AddWaker(&e.notificationWaker, wakerForNotification)
s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
for {
- switch index, _ := s.Fetch(true); index {
+ e.mu.Unlock()
+ index, _ := s.Fetch(true)
+ e.mu.Lock()
+ switch index {
case wakerForNotification:
n := e.fetchNotifications()
if n&notifyClose != 0 {
@@ -634,7 +679,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
s.decRef()
}
close(e.drainDone)
+ e.mu.Unlock()
<-e.undrain
+ e.mu.Lock()
}
case wakerForNewSegment:
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 7730e6445..1d245c2c6 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -61,6 +61,9 @@ const (
)
// handshake holds the state used during a TCP 3-way handshake.
+//
+// NOTE: handshake.ep.mu is held during handshake processing. It is released if
+// we are going to block and reacquired when we start processing an event.
type handshake struct {
ep *endpoint
state handshakeState
@@ -209,9 +212,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea
h.mss = opts.MSS
h.sndWndScale = opts.WS
h.deferAccept = deferAccept
- h.ep.mu.Lock()
h.ep.setEndpointState(StateSynRecv)
- h.ep.mu.Unlock()
}
// checkAck checks if the ACK number, if present, of a segment received during
@@ -241,9 +242,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// RFC 793, page 67, states that "If the RST bit is set [and] If the ACK
// was acceptable then signal the user "error: connection reset", drop
// the segment, enter CLOSED state, delete TCB, and return."
- h.ep.mu.Lock()
h.ep.workerCleanup = true
- h.ep.mu.Unlock()
// Although the RFC above calls out ECONNRESET, Linux actually returns
// ECONNREFUSED here so we do as well.
return tcpip.ErrConnectionRefused
@@ -281,9 +280,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
if s.flagIsSet(header.TCPFlagAck) {
h.state = handshakeCompleted
- h.ep.mu.Lock()
h.ep.transitionToStateEstablishedLocked(h)
- h.ep.mu.Unlock()
h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
return nil
@@ -293,10 +290,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// but resend our own SYN and wait for it to be acknowledged in the
// SYN-RCVD state.
h.state = handshakeSynRcvd
- h.ep.mu.Lock()
ttl := h.ep.ttl
+ amss := h.ep.amss
h.ep.setEndpointState(StateSynRecv)
- h.ep.mu.Unlock()
synOpts := header.TCPSynOptions{
WS: int(h.effectiveRcvWndScale()),
TS: rcvSynOpts.TS,
@@ -307,12 +303,20 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// permits SACK. This is not explicitly defined in the RFC but
// this is the behaviour implemented by Linux.
SACKPermitted: rcvSynOpts.SACKPermitted,
- MSS: h.ep.amss,
+ MSS: amss,
}
if ttl == 0 {
ttl = s.route.DefaultTTL()
}
- h.ep.sendSynTCP(&s.route, h.ep.ID, ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ h.ep.sendSynTCP(&s.route, tcpFields{
+ id: h.ep.ID,
+ ttl: ttl,
+ tos: h.ep.sendTOS,
+ flags: h.flags,
+ seq: h.iss,
+ ack: h.ackNum,
+ rcvWnd: h.rcvWnd,
+ }, synOpts)
return nil
}
@@ -365,7 +369,15 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
- h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ h.ep.sendSynTCP(&s.route, tcpFields{
+ id: h.ep.ID,
+ ttl: h.ep.ttl,
+ tos: h.ep.sendTOS,
+ flags: h.flags,
+ seq: h.iss,
+ ack: h.ackNum,
+ rcvWnd: h.rcvWnd,
+ }, synOpts)
return nil
}
@@ -394,15 +406,14 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
}
h.state = handshakeCompleted
- h.ep.mu.Lock()
h.ep.transitionToStateEstablishedLocked(h)
+
// If the segment has data then requeue it for the receiver
// to process it again once main loop is started.
if s.data.Size() > 0 {
s.incRef()
h.ep.enqueueSegment(s)
}
- h.ep.mu.Unlock()
return nil
}
@@ -488,7 +499,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
}
if n&notifyDrain != 0 {
close(h.ep.drainDone)
+ h.ep.mu.Unlock()
<-h.ep.undrain
+ h.ep.mu.Lock()
}
}
@@ -553,10 +566,23 @@ func (h *handshake) execute() *tcpip.Error {
synOpts.WS = -1
}
}
- h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+
+ h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ id: h.ep.ID,
+ ttl: h.ep.ttl,
+ tos: h.ep.sendTOS,
+ flags: h.flags,
+ seq: h.iss,
+ ack: h.ackNum,
+ rcvWnd: h.rcvWnd,
+ }, synOpts)
for h.state != handshakeCompleted {
- switch index, _ := s.Fetch(true); index {
+ h.ep.mu.Unlock()
+ index, _ := s.Fetch(true)
+ h.ep.mu.Lock()
+ switch index {
+
case wakerForResend:
timeOut *= 2
if timeOut > MaxRTO {
@@ -572,12 +598,20 @@ func (h *handshake) execute() *tcpip.Error {
// the connection with another ACK or data (as ACKs are never
// retransmitted on their own).
if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept {
- h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ id: h.ep.ID,
+ ttl: h.ep.ttl,
+ tos: h.ep.sendTOS,
+ flags: h.flags,
+ seq: h.iss,
+ ack: h.ackNum,
+ rcvWnd: h.rcvWnd,
+ }, synOpts)
}
case wakerForNotification:
n := h.ep.fetchNotifications()
- if n&notifyClose != 0 {
+ if (n&notifyClose)|(n&notifyAbort) != 0 {
return tcpip.ErrAborted
}
if n&notifyDrain != 0 {
@@ -593,7 +627,9 @@ func (h *handshake) execute() *tcpip.Error {
}
}
close(h.ep.drainDone)
+ h.ep.mu.Unlock()
<-h.ep.undrain
+ h.ep.mu.Lock()
}
case wakerForNewSegment:
@@ -617,17 +653,17 @@ func parseSynSegmentOptions(s *segment) header.TCPSynOptions {
var optionPool = sync.Pool{
New: func() interface{} {
- return make([]byte, maxOptionSize)
+ return &[maxOptionSize]byte{}
},
}
func getOptions() []byte {
- return optionPool.Get().([]byte)
+ return (*optionPool.Get().(*[maxOptionSize]byte))[:]
}
func putOptions(options []byte) {
// Reslice to full capacity.
- optionPool.Put(options[0:cap(options)])
+ optionPool.Put(optionsToArray(options))
}
func makeSynOptions(opts header.TCPSynOptions) []byte {
@@ -683,18 +719,33 @@ func makeSynOptions(opts header.TCPSynOptions) []byte {
return options[:offset]
}
-func (e *endpoint) sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
- options := makeSynOptions(opts)
+// tcpFields is a struct to carry different parameters required by the
+// send*TCP variant functions below.
+type tcpFields struct {
+ id stack.TransportEndpointID
+ ttl uint8
+ tos uint8
+ flags byte
+ seq seqnum.Value
+ ack seqnum.Value
+ rcvWnd seqnum.Size
+ opts []byte
+ txHash uint32
+}
+
+func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) *tcpip.Error {
+ tf.opts = makeSynOptions(opts)
// We ignore SYN send errors and let the callers re-attempt send.
- if err := e.sendTCP(r, id, buffer.VectorisedView{}, ttl, tos, flags, seq, ack, rcvWnd, options, nil); err != nil {
+ if err := e.sendTCP(r, tf, buffer.VectorisedView{}, nil); err != nil {
e.stats.SendErrors.SynSendToNetworkFailed.Increment()
}
- putOptions(options)
+ putOptions(tf.opts)
return nil
}
-func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
- if err := sendTCP(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso); err != nil {
+func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error {
+ tf.txHash = e.txHash
+ if err := sendTCP(r, tf, data, gso); err != nil {
e.stats.SendErrors.SegmentSendToNetworkFailed.Increment()
return err
}
@@ -702,8 +753,8 @@ func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data bu
return nil
}
-func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *tcpip.PacketBuffer, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) {
- optLen := len(opts)
+func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) {
+ optLen := len(tf.opts)
hdr := &pkt.Header
packetSize := pkt.DataSize
off := pkt.DataOffset
@@ -711,15 +762,15 @@ func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *tcpip.Packet
tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
pkt.TransportHeader = buffer.View(tcp)
tcp.Encode(&header.TCPFields{
- SrcPort: id.LocalPort,
- DstPort: id.RemotePort,
- SeqNum: uint32(seq),
- AckNum: uint32(ack),
+ SrcPort: tf.id.LocalPort,
+ DstPort: tf.id.RemotePort,
+ SeqNum: uint32(tf.seq),
+ AckNum: uint32(tf.ack),
DataOffset: uint8(header.TCPMinimumSize + optLen),
- Flags: flags,
- WindowSize: uint16(rcvWnd),
+ Flags: tf.flags,
+ WindowSize: uint16(tf.rcvWnd),
})
- copy(tcp[header.TCPMinimumSize:], opts)
+ copy(tcp[header.TCPMinimumSize:], tf.opts)
length := uint16(hdr.UsedLength() + packetSize)
xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
@@ -734,13 +785,12 @@ func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *tcpip.Packet
xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, off, packetSize)
tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
}
-
}
-func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
- optLen := len(opts)
- if rcvWnd > 0xffff {
- rcvWnd = 0xffff
+func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error {
+ optLen := len(tf.opts)
+ if tf.rcvWnd > 0xffff {
+ tf.rcvWnd = 0xffff
}
mss := int(gso.MSS)
@@ -749,7 +799,7 @@ func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.Vect
// Allocate one big slice for all the headers.
hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen
buf := make([]byte, n*hdrSize)
- pkts := make([]tcpip.PacketBuffer, n)
+ pkts := make([]stack.PacketBuffer, n)
for i := range pkts {
pkts[i].Header = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize])
}
@@ -765,14 +815,15 @@ func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.Vect
pkts[i].DataOffset = off
pkts[i].DataSize = packetSize
pkts[i].Data = data
- buildTCPHdr(r, id, &pkts[i], flags, seq, ack, rcvWnd, opts, gso)
+ pkts[i].Hash = tf.txHash
+ buildTCPHdr(r, tf, &pkts[i], gso)
off += packetSize
- seq = seq.Add(seqnum.Size(packetSize))
+ tf.seq = tf.seq.Add(seqnum.Size(packetSize))
}
- if ttl == 0 {
- ttl = r.DefaultTTL()
+ if tf.ttl == 0 {
+ tf.ttl = r.DefaultTTL()
}
- sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos})
+ sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos})
if err != nil {
r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent))
}
@@ -782,33 +833,34 @@ func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.Vect
// sendTCP sends a TCP segment with the provided options via the provided
// network endpoint and under the provided identity.
-func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
- optLen := len(opts)
- if rcvWnd > 0xffff {
- rcvWnd = 0xffff
+func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error {
+ optLen := len(tf.opts)
+ if tf.rcvWnd > 0xffff {
+ tf.rcvWnd = 0xffff
}
if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() {
- return sendTCPBatch(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso)
+ return sendTCPBatch(r, tf, data, gso)
}
- pkt := tcpip.PacketBuffer{
+ pkt := stack.PacketBuffer{
Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen),
DataOffset: 0,
DataSize: data.Size(),
Data: data,
+ Hash: tf.txHash,
}
- buildTCPHdr(r, id, &pkt, flags, seq, ack, rcvWnd, opts, gso)
+ buildTCPHdr(r, tf, &pkt, gso)
- if ttl == 0 {
- ttl = r.DefaultTTL()
+ if tf.ttl == 0 {
+ tf.ttl = r.DefaultTTL()
}
- if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, pkt); err != nil {
+ if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil {
r.Stats().TCP.SegmentSendErrors.Increment()
return err
}
r.Stats().TCP.SegmentsSent.Increment()
- if (flags & header.TCPFlagRst) != 0 {
+ if (tf.flags & header.TCPFlagRst) != 0 {
r.Stats().TCP.ResetsSent.Increment()
}
return nil
@@ -860,7 +912,16 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
- err := e.sendTCP(&e.route, e.ID, data, e.ttl, e.sendTOS, flags, seq, ack, rcvWnd, options, e.gso)
+ err := e.sendTCP(&e.route, tcpFields{
+ id: e.ID,
+ ttl: e.ttl,
+ tos: e.sendTOS,
+ flags: flags,
+ seq: seq,
+ ack: ack,
+ rcvWnd: rcvWnd,
+ opts: options,
+ }, data, e.gso)
putOptions(options)
return err
}
@@ -875,7 +936,6 @@ func (e *endpoint) handleWrite() *tcpip.Error {
first := e.sndQueue.Front()
if first != nil {
e.snd.writeList.PushBackList(&e.sndQueue)
- e.snd.sndNxtList.UpdateForward(e.sndBufInQueue)
e.sndBufInQueue = 0
}
@@ -1009,7 +1069,6 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
// except SYN-SENT, all reset (RST) segments are
// validated by checking their SEQ-fields." So
// we only process it if it's acceptable.
- e.mu.Lock()
switch e.EndpointState() {
// In case of a RST in CLOSE-WAIT linux moves
// the socket to closed state with an error set
@@ -1033,11 +1092,9 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
case StateCloseWait:
e.transitionToStateCloseLocked()
e.HardError = tcpip.ErrAborted
- e.mu.Unlock()
e.notifyProtocolGoroutine(notifyTickleWorker)
return false, nil
default:
- e.mu.Unlock()
// RFC 793, page 37 states that "in all states
// except SYN-SENT, all reset (RST) segments are
// validated by checking their SEQ-fields." So
@@ -1150,9 +1207,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
// Now check if the received segment has caused us to transition
// to a CLOSED state, if yes then terminate processing and do
// not invoke the sender.
- e.mu.RLock()
state := e.state
- e.mu.RUnlock()
if state == StateClose {
// When we get into StateClose while processing from the queue,
// return immediately and let the protocolMainloop handle it.
@@ -1175,9 +1230,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
// keepalive packets periodically when the connection is idle. If we don't hear
// from the other side after a number of tries, we terminate the connection.
func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
- e.mu.RLock()
userTimeout := e.userTimeout
- e.mu.RUnlock()
e.keepalive.Lock()
if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() {
@@ -1241,6 +1294,7 @@ func (e *endpoint) disableKeepaliveTimer() {
// goroutine and is responsible for sending segments and handling received
// segments.
func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error {
+ e.mu.Lock()
var closeTimer *time.Timer
var closeWaker sleep.Waker
@@ -1262,7 +1316,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
e.mu.Unlock()
- e.workMu.Unlock()
// When the protocol loop exits we should wake up our waiters.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -1273,16 +1326,13 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// completion.
initialRcvWnd := e.initialReceiveWindow()
h := newHandshake(e, seqnum.Size(initialRcvWnd))
- e.mu.Lock()
h.ep.setEndpointState(StateSynSent)
- e.mu.Unlock()
if err := h.execute(); err != nil {
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
- e.mu.Lock()
e.setEndpointState(StateError)
e.HardError = err
@@ -1295,9 +1345,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.keepalive.timer.init(&e.keepalive.waker)
defer e.keepalive.timer.cleanup()
- e.mu.Lock()
drained := e.drainDone != nil
- e.mu.Unlock()
if drained {
close(e.drainDone)
<-e.undrain
@@ -1323,10 +1371,8 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// This means the socket is being closed due
// to the TCP-FIN-WAIT2 timeout was hit. Just
// mark the socket as closed.
- e.mu.Lock()
e.transitionToStateCloseLocked()
e.workerCleanup = true
- e.mu.Unlock()
return nil
},
},
@@ -1381,7 +1427,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
if n&notifyClose != 0 && closeTimer == nil {
- e.mu.Lock()
if e.EndpointState() == StateFinWait2 && e.closed {
// The socket has been closed and we are in FIN_WAIT2
// so start the FIN_WAIT2 timer.
@@ -1390,7 +1435,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
})
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
- e.mu.Unlock()
}
if n&notifyKeepaliveChanged != 0 {
@@ -1410,7 +1454,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// Only block the worker if the endpoint
// is not in closed state or error state.
close(e.drainDone)
+ e.mu.Unlock()
<-e.undrain
+ e.mu.Lock()
}
}
@@ -1453,7 +1499,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
e.rcvListMu.Unlock()
- e.mu.Lock()
if e.workerCleanup {
e.notifyProtocolGoroutine(notifyClose)
}
@@ -1461,7 +1506,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
cleanupOnError := func(err *tcpip.Error) {
- e.mu.Lock()
e.workerCleanup = true
if err != nil {
e.resetConnectionLocked(err)
@@ -1473,16 +1517,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
loop:
for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError {
e.mu.Unlock()
- e.workMu.Unlock()
v, _ := s.Fetch(true)
- e.workMu.Lock()
+ e.mu.Lock()
- // We need to double check here because the notification maybe
+ // We need to double check here because the notification may be
// stale by the time we got around to processing it.
- //
- // NOTE: since we now hold the workMu the processors cannot
- // change the state of the endpoint so it's safe to proceed
- // after this check.
switch e.EndpointState() {
case StateError:
// If the endpoint has already transitioned to an ERROR
@@ -1495,21 +1534,17 @@ loop:
case StateTimeWait:
fallthrough
case StateClose:
- e.mu.Lock()
break loop
default:
if err := funcs[v].f(); err != nil {
cleanupOnError(err)
return nil
}
- e.mu.Lock()
}
}
- state := e.EndpointState()
- e.mu.Unlock()
var reuseTW func()
- if state == StateTimeWait {
+ if e.EndpointState() == StateTimeWait {
// Disable close timer as we now entering real TIME_WAIT.
if closeTimer != nil {
closeTimer.Stop()
@@ -1519,14 +1554,11 @@ loop:
s.Done()
// Wake up any waiters before we enter TIME_WAIT.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
- e.mu.Lock()
e.workerCleanup = true
- e.mu.Unlock()
reuseTW = e.doTimeWait()
}
// Mark endpoint as closed.
- e.mu.Lock()
if e.EndpointState() != StateError {
e.transitionToStateCloseLocked()
}
@@ -1632,6 +1664,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
const timeWaitDone = 3
s := sleep.Sleeper{}
+ defer s.Done()
s.AddWaker(&e.newSegmentWaker, newSegment)
s.AddWaker(&e.notificationWaker, notification)
@@ -1641,9 +1674,9 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
defer timeWaitTimer.Stop()
for {
- e.workMu.Unlock()
+ e.mu.Unlock()
v, _ := s.Fetch(true)
- e.workMu.Lock()
+ e.mu.Lock()
switch v {
case newSegment:
extendTimeWait, reuseTW := e.handleTimeWaitSegments()
@@ -1666,7 +1699,9 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
e.handleTimeWaitSegments()
}
close(e.drainDone)
+ e.mu.Unlock()
<-e.undrain
+ e.mu.Lock()
return nil
}
case timeWaitDone:
diff --git a/pkg/tcpip/transport/tcp/connect_unsafe.go b/pkg/tcpip/transport/tcp/connect_unsafe.go
new file mode 100644
index 000000000..cfc304616
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/connect_unsafe.go
@@ -0,0 +1,30 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// optionsToArray converts a slice of capacity >-= maxOptionSize to an array.
+//
+// optionsToArray panics if the capacity of options is smaller than
+// maxOptionSize.
+func optionsToArray(options []byte) *[maxOptionSize]byte {
+ // Reslice to full capacity.
+ options = options[0:maxOptionSize]
+ return (*[maxOptionSize]byte)(unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&options)).Data))
+}
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index d792b07d6..6062ca916 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -18,7 +18,6 @@ import (
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -128,7 +127,7 @@ func (p *processor) handleSegments() {
continue
}
- if !ep.workMu.TryLock() {
+ if !ep.mu.TryLock() {
ep.newSegmentWaker.Assert()
continue
}
@@ -138,12 +137,10 @@ func (p *processor) handleSegments() {
if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose {
// Send any active resets if required.
if err != nil {
- ep.mu.Lock()
ep.resetConnectionLocked(err)
- ep.mu.Unlock()
}
ep.notifyProtocolGoroutine(notifyTickleWorker)
- ep.workMu.Unlock()
+ ep.mu.Unlock()
continue
}
@@ -151,7 +148,7 @@ func (p *processor) handleSegments() {
p.epQ.enqueue(ep)
}
- ep.workMu.Unlock()
+ ep.mu.Unlock()
}
}
}
@@ -189,7 +186,7 @@ func (d *dispatcher) wait() {
}
}
-func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
ep := stackEP.(*endpoint)
s := newSegment(r, id, pkt)
if !s.parse() {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index f1ad19dac..1ebee0cfe 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -18,6 +18,7 @@ import (
"encoding/binary"
"fmt"
"math"
+ "runtime"
"strings"
"sync/atomic"
"time"
@@ -29,11 +30,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tmutex"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -283,6 +282,38 @@ func (*EndpointInfo) IsEndpointInfo() {}
// synchronized. The protocol implementation, however, runs in a single
// goroutine.
//
+// Each endpoint has a few mutexes:
+//
+// e.mu -> Primary mutex for an endpoint must be held for all operations except
+// in e.Readiness where acquiring it will result in a deadlock in epoll
+// implementation.
+//
+// The following three mutexes can be acquired independent of e.mu but if
+// acquired with e.mu then e.mu must be acquired first.
+//
+// e.acceptMu -> protects acceptedChan.
+// e.rcvListMu -> Protects the rcvList and associated fields.
+// e.sndBufMu -> Protects the sndQueue and associated fields.
+// e.lastErrorMu -> Protects the lastError field.
+//
+// LOCKING/UNLOCKING of the endpoint. The locking of an endpoint is different
+// based on the context in which the lock is acquired. In the syscall context
+// e.LockUser/e.UnlockUser should be used and when doing background processing
+// e.mu.Lock/e.mu.Unlock should be used. The distinction is described below
+// in brief.
+//
+// The reason for this locking behaviour is to avoid wakeups to handle packets.
+// In cases where the endpoint is already locked the background processor can
+// queue the packet up and go its merry way and the lock owner will eventually
+// process the backlog when releasing the lock. Similarly when acquiring the
+// lock from say a syscall goroutine we can implement a bit of spinning if we
+// know that the lock is not held by another syscall goroutine. Background
+// processors should never hold the lock for long and we can avoid an expensive
+// sleep/wakeup by spinning for a shortwhile.
+//
+// For more details please see the detailed documentation on
+// e.LockUser/e.UnlockUser methods.
+//
// +stateify savable
type endpoint struct {
EndpointInfo
@@ -299,12 +330,6 @@ type endpoint struct {
// Precondition: epQueue.mu must be held to read/write this field..
pendingProcessing bool `state:"nosave"`
- // workMu is used to arbitrate which goroutine may perform protocol
- // work. Only the main protocol goroutine is expected to call Lock() on
- // it, but other goroutines (e.g., send) may call TryLock() to eagerly
- // perform work without having to wait for the main one to wake up.
- workMu tmutex.Mutex `state:"nosave"`
-
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
@@ -330,15 +355,11 @@ type endpoint struct {
rcvBufSize int
rcvBufUsed int
rcvAutoParams rcvBufAutoTuneParams
- // zeroWindow indicates that the window was closed due to receive buffer
- // space being filled up. This is set by the worker goroutine before
- // moving a segment to the rcvList. This setting is cleared by the
- // endpoint when a Read() call reads enough data for the new window to
- // be non-zero.
- zeroWindow bool
- // The following fields are protected by the mutex.
- mu sync.RWMutex `state:"nosave"`
+ // mu protects all endpoint fields unless documented otherwise. mu must
+ // be acquired before interacting with the endpoint fields.
+ mu sync.Mutex `state:"nosave"`
+ ownedByUser uint32
// state must be read/set using the EndpointState()/setEndpointState() methods.
state EndpointState `state:".(EndpointState)"`
@@ -513,6 +534,23 @@ type endpoint struct {
// to the acceptedChan below terminate before we close acceptedChan.
pendingAccepted sync.WaitGroup `state:"nosave"`
+ // acceptMu protects acceptedChan.
+ acceptMu sync.Mutex `state:"nosave"`
+
+ // acceptCond is a condition variable that can be used to block on when
+ // acceptedChan is full and an endpoint is ready to be delivered.
+ //
+ // This condition variable is required because just blocking on sending
+ // to acceptedChan does not work in cases where endpoint.Listen is
+ // called twice with different backlog values. In such cases the channel
+ // is closed and a new one created. Any pending goroutines blocking on
+ // the write to the channel will panic.
+ //
+ // We use this condition variable to block/unblock goroutines which
+ // tried to deliver an endpoint but couldn't because accept backlog was
+ // full ( See: endpoint.deliverAccepted ).
+ acceptCond *sync.Cond `state:"nosave"`
+
// acceptedChan is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
@@ -561,6 +599,10 @@ type endpoint struct {
// endpoint and at this point the endpoint is only around
// to complete the TCP shutdown.
closed bool
+
+ // txHash is the transport layer hash to be set on outbound packets
+ // emitted by this endpoint.
+ txHash uint32
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
@@ -583,14 +625,93 @@ func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 {
return maxMSS
}
+// LockUser tries to lock e.mu and if it fails it will check if the lock is held
+// by another syscall goroutine. If yes, then it will goto sleep waiting for the
+// lock to be released, if not then it will spin till it acquires the lock or
+// another syscall goroutine acquires it in which case it will goto sleep as
+// described above.
+//
+// The assumption behind spinning here being that background packet processing
+// should not be holding the lock for long and spinning reduces latency as we
+// avoid an expensive sleep/wakeup of of the syscall goroutine).
+func (e *endpoint) LockUser() {
+ for {
+ // Try first if the sock is locked then check if it's owned
+ // by another user goroutine if not then we spin, otherwise
+ // we just goto sleep on the Lock() and wait.
+ if !e.mu.TryLock() {
+ // If socket is owned by the user then just goto sleep
+ // as the lock could be held for a reasonably long time.
+ if atomic.LoadUint32(&e.ownedByUser) == 1 {
+ e.mu.Lock()
+ atomic.StoreUint32(&e.ownedByUser, 1)
+ return
+ }
+ // Spin but yield the processor since the lower half
+ // should yield the lock soon.
+ runtime.Gosched()
+ continue
+ }
+ atomic.StoreUint32(&e.ownedByUser, 1)
+ return
+ }
+}
+
+// UnlockUser will check if there are any segments already queued for processing
+// and process any such segments before unlocking e.mu. This is required because
+// we when packets arrive and endpoint lock is already held then such packets
+// are queued up to be processed. If the lock is held by the endpoint goroutine
+// then it will process these packets but if the lock is instead held by the
+// syscall goroutine then we can have the syscall goroutine process the backlog
+// before unlocking.
+//
+// This avoids an unnecessary wakeup of the endpoint protocol goroutine for the
+// endpoint. It's also required eventually when we get rid of the endpoint
+// protocol goroutine altogether.
+//
+// Precondition: e.LockUser() must have been called before calling e.UnlockUser()
+func (e *endpoint) UnlockUser() {
+ // Lock segment queue before checking so that we avoid a race where
+ // segments can be queued between the time we check if queue is empty
+ // and actually unlock the endpoint mutex.
+ for {
+ e.segmentQueue.mu.Lock()
+ if e.segmentQueue.emptyLocked() {
+ if atomic.SwapUint32(&e.ownedByUser, 0) != 1 {
+ panic("e.UnlockUser() called without calling e.LockUser()")
+ }
+ e.mu.Unlock()
+ e.segmentQueue.mu.Unlock()
+ return
+ }
+ e.segmentQueue.mu.Unlock()
+
+ switch e.EndpointState() {
+ case StateEstablished:
+ if err := e.handleSegments(true /* fastPath */); err != nil {
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ }
+ default:
+ // Since we are waking the endpoint goroutine here just unlock
+ // and let it process the queued segments.
+ e.newSegmentWaker.Assert()
+ if atomic.SwapUint32(&e.ownedByUser, 0) != 1 {
+ panic("e.UnlockUser() called without calling e.LockUser()")
+ }
+ e.mu.Unlock()
+ return
+ }
+ }
+}
+
// StopWork halts packet processing. Only to be used in tests.
func (e *endpoint) StopWork() {
- e.workMu.Lock()
+ e.mu.Lock()
}
// ResumeWork resumes packet processing. Only to be used in tests.
func (e *endpoint) ResumeWork() {
- e.workMu.Unlock()
+ e.mu.Unlock()
}
// setEndpointState updates the state of the endpoint to state atomically. This
@@ -672,6 +793,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
count: 9,
},
uniqueID: s.UniqueID(),
+ txHash: s.Rand().Uint32(),
}
var ss SendBufferSizeOption
@@ -709,9 +831,8 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
}
e.segmentQueue.setLimit(MaxUnprocessedSegments)
- e.workMu.Init()
- e.workMu.Lock()
e.tsOffset = timeStampOffset()
+ e.acceptCond = sync.NewCond(&e.acceptMu)
return e
}
@@ -721,9 +842,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
result := waiter.EventMask(0)
- e.mu.RLock()
- defer e.mu.RUnlock()
-
switch e.EndpointState() {
case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv:
// Ready for nothing.
@@ -735,9 +853,11 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
case StateListen:
// Check if there's anything in the accepted channel.
if (mask & waiter.EventIn) != 0 {
+ e.acceptMu.Lock()
if len(e.acceptedChan) > 0 {
result |= waiter.EventIn
}
+ e.acceptMu.Unlock()
}
}
if e.EndpointState().connected() {
@@ -798,7 +918,21 @@ func (e *endpoint) Abort() {
// If the endpoint disconnected after the check, nothing needs to be
// done, so sending a notification which will potentially be ignored is
// fine.
- if e.EndpointState().connected() {
+ //
+ // If the endpoint connecting finishes after the check, the endpoint
+ // is either in a connected state (where we would notifyAbort anyway),
+ // SYN-RECV (where we would also notifyAbort anyway), or in an error
+ // state where nothing is required and the notification can be safely
+ // ignored.
+ //
+ // Endpoints where a Close during connecting or SYN-RECV state would be
+ // problematic are set to state connecting before being registered (and
+ // thus possible to be Aborted). They are never available in initial
+ // state.
+ //
+ // Endpoints transitioning from initial to connecting state may be
+ // safely either closed or sent notifyAbort.
+ if s := e.EndpointState(); s == StateConnecting || s == StateSynRecv || s.connected() {
e.notifyProtocolGoroutine(notifyAbort)
return
}
@@ -809,25 +943,22 @@ func (e *endpoint) Abort() {
// with it. It must be called only once and with no other concurrent calls to
// the endpoint.
func (e *endpoint) Close() {
- e.mu.Lock()
- closed := e.closed
- e.mu.Unlock()
- if closed {
+ e.LockUser()
+ defer e.UnlockUser()
+ if e.closed {
return
}
// Issue a shutdown so that the peer knows we won't send any more data
// if we're connected, or stop accepting if we're listening.
- e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
- e.closeNoShutdown()
+ e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+ e.closeNoShutdownLocked()
}
// closeNoShutdown closes the endpoint without doing a full shutdown. This is
// used when a connection needs to be aborted with a RST and we want to skip
// a full 4 way TCP shutdown.
-func (e *endpoint) closeNoShutdown() {
- e.mu.Lock()
-
+func (e *endpoint) closeNoShutdownLocked() {
// For listening sockets, we always release ports inline so that they
// are immediately available for reuse after Close() is called. If also
// registered, we unregister as well otherwise the next user would fail
@@ -848,7 +979,6 @@ func (e *endpoint) closeNoShutdown() {
e.closed = true
// Either perform the local cleanup or kick the worker to make sure it
// knows it needs to cleanup.
- tcpip.AddDanglingEndpoint(e)
switch e.EndpointState() {
// Sockets in StateSynRecv state(passive connections) are closed when
// the handshake fails or if the listening socket is closed while
@@ -862,50 +992,38 @@ func (e *endpoint) closeNoShutdown() {
// do nothing.
default:
e.workerCleanup = true
+ tcpip.AddDanglingEndpoint(e)
+ // Worker will remove the dangling endpoint when the endpoint
+ // goroutine terminates.
e.notifyProtocolGoroutine(notifyClose)
}
-
- e.mu.Unlock()
}
// closePendingAcceptableConnections closes all connections that have completed
// handshake but not yet been delivered to the application.
func (e *endpoint) closePendingAcceptableConnectionsLocked() {
- done := make(chan struct{})
- // Spin a goroutine up as ranging on e.acceptedChan will just block when
- // there are no more connections in the channel. Using a non-blocking
- // select does not work as it can potentially select the default case
- // even when there are pending writes but that are not yet written to
- // the channel.
- go func() {
- defer close(done)
- for n := range e.acceptedChan {
- n.notifyProtocolGoroutine(notifyReset)
- // close all connections that have completed but
- // not accepted by the application.
- n.Close()
- }
- }()
- // pendingAccepted(see endpoint.deliverAccepted) tracks the number of
- // endpoints which have completed handshake but are not yet written to
- // the e.acceptedChan. We wait here till the goroutine above can drain
- // all such connections from e.acceptedChan.
- e.pendingAccepted.Wait()
+ e.acceptMu.Lock()
+ if e.acceptedChan == nil {
+ e.acceptMu.Unlock()
+ return
+ }
+
close(e.acceptedChan)
- <-done
e.acceptedChan = nil
+ e.acceptCond.Broadcast()
+ e.acceptMu.Unlock()
+
+ // Wait for all pending endpoints to close.
+ e.pendingAccepted.Wait()
}
// cleanupLocked frees all resources associated with the endpoint. It is called
// after Close() is called and the worker goroutine (if any) is done with its
// work.
func (e *endpoint) cleanupLocked() {
-
// Close all endpoints that might have been accepted by TCP but not by
// the client.
- if e.acceptedChan != nil {
- e.closePendingAcceptableConnectionsLocked()
- }
+ e.closePendingAcceptableConnectionsLocked()
e.workerCleanup = false
@@ -945,6 +1063,9 @@ func (e *endpoint) initialReceiveWindow() int {
// ModerateRecvBuf adjusts the receive buffer and the advertised window
// based on the number of bytes copied to user space.
func (e *endpoint) ModerateRecvBuf(copied int) {
+ e.LockUser()
+ defer e.UnlockUser()
+
e.rcvListMu.Lock()
if e.rcvAutoParams.disabled {
e.rcvListMu.Unlock()
@@ -994,7 +1115,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvBufSize = rcvWnd
availAfter := e.receiveBufferAvailableLocked()
mask := uint32(notifyReceiveWindowChanged)
- if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
mask |= notifyNonZeroReceiveWindow
}
e.notifyProtocolGoroutine(mask)
@@ -1012,13 +1133,13 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
}
// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (iptables.IPTables, error) {
+func (e *endpoint) IPTables() (stack.IPTables, error) {
return e.stack.IPTables(), nil
}
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- e.mu.RLock()
+ e.LockUser()
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data. Also note that a RST being received
// would cause the state to become StateError so we should allow the
@@ -1028,7 +1149,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
he := e.HardError
- e.mu.RUnlock()
+ e.UnlockUser()
if s == StateError {
return buffer.View{}, tcpip.ControlMessages{}, he
}
@@ -1038,8 +1159,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
v, err := e.readLocked()
e.rcvListMu.Unlock()
-
- e.mu.RUnlock()
+ e.UnlockUser()
if err == tcpip.ErrClosedForReceive {
e.stats.ReadErrors.ReadClosed.Increment()
@@ -1071,7 +1191,7 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// enough buffer space, to either fit an aMSS or half a receive buffer
// (whichever smaller), then notify the protocol goroutine to send a
// window update.
- if crossed, above := e.windowCrossedACKThreshold(len(v)); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
@@ -1112,13 +1232,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
// and opts.EndOfRecord are also ignored.
- e.mu.RLock()
+ e.LockUser()
e.sndBufMu.Lock()
avail, err := e.isEndpointWritableLocked()
if err != nil {
e.sndBufMu.Unlock()
- e.mu.RUnlock()
+ e.UnlockUser()
e.stats.WriteErrors.WriteClosed.Increment()
return 0, nil, err
}
@@ -1130,113 +1250,68 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// are copying data in.
if !opts.Atomic {
e.sndBufMu.Unlock()
- e.mu.RUnlock()
+ e.UnlockUser()
}
// Fetch data.
v, perr := p.Payload(avail)
if perr != nil || len(v) == 0 {
- if opts.Atomic { // See above.
+ // Note that perr may be nil if len(v) == 0.
+ if opts.Atomic {
e.sndBufMu.Unlock()
- e.mu.RUnlock()
+ e.UnlockUser()
}
- // Note that perr may be nil if len(v) == 0.
return 0, nil, perr
}
- if opts.Atomic {
+ queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) {
// Add data to the send queue.
s := newSegmentFromView(&e.route, e.ID, v)
e.sndBufUsed += len(v)
e.sndBufInQueue += seqnum.Size(len(v))
e.sndQueue.PushBack(s)
e.sndBufMu.Unlock()
- // Release the endpoint lock to prevent deadlocks due to lock
- // order inversion when acquiring workMu.
- e.mu.RUnlock()
- }
-
- if e.workMu.TryLock() {
- // Since we released locks in between it's possible that the
- // endpoint transitioned to a CLOSED/ERROR states so make
- // sure endpoint is still writable before trying to write.
- if !opts.Atomic { // See above.
- e.mu.RLock()
- e.sndBufMu.Lock()
-
- // Because we released the lock before copying, check state again
- // to make sure the endpoint is still in a valid state for a write.
- avail, err = e.isEndpointWritableLocked()
- if err != nil {
- e.sndBufMu.Unlock()
- e.mu.RUnlock()
- e.stats.WriteErrors.WriteClosed.Increment()
- return 0, nil, err
- }
- // Discard any excess data copied in due to avail being reduced due
- // to a simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
- }
- // Add data to the send queue.
- s := newSegmentFromView(&e.route, e.ID, v)
- e.sndBufUsed += len(v)
- e.sndBufInQueue += seqnum.Size(len(v))
- e.sndQueue.PushBack(s)
- e.sndBufMu.Unlock()
- // Release the endpoint lock to prevent deadlocks due to lock
- // order inversion when acquiring workMu.
- e.mu.RUnlock()
-
- }
// Do the work inline.
e.handleWrite()
- e.workMu.Unlock()
- } else {
- if !opts.Atomic { // See above.
- e.mu.RLock()
- e.sndBufMu.Lock()
+ e.UnlockUser()
+ return int64(len(v)), nil, nil
+ }
- // Because we released the lock before copying, check state again
- // to make sure the endpoint is still in a valid state for a write.
- avail, err = e.isEndpointWritableLocked()
- if err != nil {
- e.sndBufMu.Unlock()
- e.mu.RUnlock()
- e.stats.WriteErrors.WriteClosed.Increment()
- return 0, nil, err
- }
+ if opts.Atomic {
+ // Locks released in queueAndSend()
+ return queueAndSend()
+ }
- // Discard any excess data copied in due to avail being reduced due
- // to a simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
- }
- // Add data to the send queue.
- s := newSegmentFromView(&e.route, e.ID, v)
- e.sndBufUsed += len(v)
- e.sndBufInQueue += seqnum.Size(len(v))
- e.sndQueue.PushBack(s)
- e.sndBufMu.Unlock()
- // Release the endpoint lock to prevent deadlocks due to lock
- // order inversion when acquiring workMu.
- e.mu.RUnlock()
+ // Since we released locks in between it's possible that the
+ // endpoint transitioned to a CLOSED/ERROR states so make
+ // sure endpoint is still writable before trying to write.
+ e.LockUser()
+ e.sndBufMu.Lock()
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.UnlockUser()
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return 0, nil, err
+ }
- }
- // Let the protocol goroutine do the work.
- e.sndWaker.Assert()
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
}
- return int64(len(v)), nil, nil
+ // Locks released in queueAndSend()
+ return queueAndSend()
}
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data.
@@ -1289,9 +1364,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
return num, tcpip.ControlMessages{}, nil
}
-// windowCrossedACKThreshold checks if the receive window to be announced now
-// would be under aMSS or under half receive buffer, whichever smaller. This is
-// useful as a receive side silly window syndrome prevention mechanism. If
+// windowCrossedACKThresholdLocked checks if the receive window to be announced
+// now would be under aMSS or under half receive buffer, whichever smaller. This
+// is useful as a receive side silly window syndrome prevention mechanism. If
// window grows to reasonable value, we should send ACK to the sender to inform
// the rx space is now large. We also want ensure a series of small read()'s
// won't trigger a flood of spurious tiny ACK's.
@@ -1302,7 +1377,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
// crossed will be true if the window size crossed the ACK threshold.
// above will be true if the new window is >= ACK threshold and false
// otherwise.
-func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) (crossed bool, above bool) {
+//
+// Precondition: e.mu and e.rcvListMu must be held.
+func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
newAvail := e.receiveBufferAvailableLocked()
oldAvail := newAvail - deltaBefore
if oldAvail < 0 {
@@ -1325,6 +1402,9 @@ func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) (crossed bool, abo
// SetSockOptBool sets a socket option.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ e.LockUser()
+ defer e.UnlockUser()
+
switch opt {
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
@@ -1332,9 +1412,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- e.mu.Lock()
- defer e.mu.Unlock()
-
// We only allow this to be set when we're in the initial state.
if e.EndpointState() != StateInitial {
return tcpip.ErrInvalidEndpointState
@@ -1365,6 +1442,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
mask := uint32(notifyReceiveWindowChanged)
+ e.LockUser()
e.rcvListMu.Lock()
// Make sure the receive buffer size allows us to send a
@@ -1391,11 +1469,12 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// Immediately send an ACK to uncork the sender silly window
// syndrome prevetion, when our available space grows above aMSS
// or half receive buffer, whichever smaller.
- if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
mask |= notifyNonZeroReceiveWindow
}
- e.rcvListMu.Unlock()
+ e.rcvListMu.Unlock()
+ e.UnlockUser()
e.notifyProtocolGoroutine(mask)
return nil
@@ -1451,15 +1530,15 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
case tcpip.ReuseAddressOption:
- e.mu.Lock()
+ e.LockUser()
e.reuseAddr = v != 0
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case tcpip.ReusePortOption:
- e.mu.Lock()
+ e.LockUser()
e.reusePort = v != 0
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case tcpip.BindToDeviceOption:
@@ -1467,9 +1546,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
if id != 0 && !e.stack.HasNIC(id) {
return tcpip.ErrUnknownDevice
}
- e.mu.Lock()
+ e.LockUser()
e.bindToDevice = id
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case tcpip.QuickAckOption:
@@ -1485,16 +1564,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
return tcpip.ErrInvalidOptionValue
}
- e.mu.Lock()
+ e.LockUser()
e.userMSS = uint16(userMSS)
- e.mu.Unlock()
+ e.UnlockUser()
e.notifyProtocolGoroutine(notifyMSSChanged)
return nil
case tcpip.TTLOption:
- e.mu.Lock()
+ e.LockUser()
e.ttl = uint8(v)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case tcpip.KeepaliveEnabledOption:
@@ -1526,15 +1605,15 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
case tcpip.TCPUserTimeoutOption:
- e.mu.Lock()
+ e.LockUser()
e.userTimeout = time.Duration(v)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case tcpip.BroadcastOption:
- e.mu.Lock()
+ e.LockUser()
e.broadcast = v != 0
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case tcpip.CongestionControlOption:
@@ -1548,22 +1627,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
availCC := strings.Split(string(avail), " ")
for _, cc := range availCC {
if v == tcpip.CongestionControlOption(cc) {
- // Acquire the work mutex as we may need to
- // reinitialize the congestion control state.
- e.mu.Lock()
+ e.LockUser()
state := e.EndpointState()
e.cc = v
- e.mu.Unlock()
switch state {
case StateEstablished:
- e.workMu.Lock()
- e.mu.Lock()
if e.EndpointState() == state {
e.snd.cc = e.snd.initCongestionControl(e.cc)
}
- e.mu.Unlock()
- e.workMu.Unlock()
}
+ e.UnlockUser()
return nil
}
}
@@ -1573,23 +1646,23 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrNoSuchFile
case tcpip.IPv4TOSOption:
- e.mu.Lock()
+ e.LockUser()
// TODO(gvisor.dev/issue/995): ECN is not currently supported,
// ignore the bits for now.
e.sendTOS = uint8(v) & ^uint8(inetECNMask)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case tcpip.IPv6TrafficClassOption:
- e.mu.Lock()
+ e.LockUser()
// TODO(gvisor.dev/issue/995): ECN is not currently supported,
// ignore the bits for now.
e.sendTOS = uint8(v) & ^uint8(inetECNMask)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case tcpip.TCPLingerTimeoutOption:
- e.mu.Lock()
+ e.LockUser()
if v < 0 {
// Same as effectively disabling TCPLinger timeout.
v = 0
@@ -1607,16 +1680,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
v = stkTCPLingerTimeout
}
e.tcpLingerTimeout = time.Duration(v)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case tcpip.TCPDeferAcceptOption:
- e.mu.Lock()
+ e.LockUser()
if time.Duration(v) > MaxRTO {
v = tcpip.TCPDeferAcceptOption(MaxRTO)
}
e.deferAccept = time.Duration(v)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
default:
@@ -1626,8 +1699,8 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// readyReceiveSize returns the number of bytes ready to be received.
func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
// The endpoint cannot be in listen state.
if e.EndpointState() == StateListen {
@@ -1649,9 +1722,9 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
return false, tcpip.ErrUnknownProtocolOption
}
- e.mu.Lock()
+ e.LockUser()
v := e.v6only
- e.mu.Unlock()
+ e.UnlockUser()
return v, nil
}
@@ -1715,9 +1788,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return nil
case *tcpip.ReuseAddressOption:
- e.mu.RLock()
+ e.LockUser()
v := e.reuseAddr
- e.mu.RUnlock()
+ e.UnlockUser()
*o = 0
if v {
@@ -1726,9 +1799,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return nil
case *tcpip.ReusePortOption:
- e.mu.RLock()
+ e.LockUser()
v := e.reusePort
- e.mu.RUnlock()
+ e.UnlockUser()
*o = 0
if v {
@@ -1737,9 +1810,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return nil
case *tcpip.BindToDeviceOption:
- e.mu.RLock()
+ e.LockUser()
*o = tcpip.BindToDeviceOption(e.bindToDevice)
- e.mu.RUnlock()
+ e.UnlockUser()
return nil
case *tcpip.QuickAckOption:
@@ -1750,16 +1823,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return nil
case *tcpip.TTLOption:
- e.mu.Lock()
+ e.LockUser()
*o = tcpip.TTLOption(e.ttl)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case *tcpip.TCPInfoOption:
*o = tcpip.TCPInfoOption{}
- e.mu.RLock()
+ e.LockUser()
snd := e.snd
- e.mu.RUnlock()
+ e.UnlockUser()
if snd != nil {
snd.rtt.Lock()
o.RTT = snd.rtt.srtt
@@ -1798,9 +1871,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return nil
case *tcpip.TCPUserTimeoutOption:
- e.mu.Lock()
+ e.LockUser()
*o = tcpip.TCPUserTimeoutOption(e.userTimeout)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case *tcpip.OutOfBandInlineOption:
@@ -1809,9 +1882,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return nil
case *tcpip.BroadcastOption:
- e.mu.Lock()
+ e.LockUser()
v := e.broadcast
- e.mu.Unlock()
+ e.UnlockUser()
*o = 0
if v {
@@ -1820,33 +1893,33 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return nil
case *tcpip.CongestionControlOption:
- e.mu.Lock()
+ e.LockUser()
*o = e.cc
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case *tcpip.IPv4TOSOption:
- e.mu.RLock()
+ e.LockUser()
*o = tcpip.IPv4TOSOption(e.sendTOS)
- e.mu.RUnlock()
+ e.UnlockUser()
return nil
case *tcpip.IPv6TrafficClassOption:
- e.mu.RLock()
+ e.LockUser()
*o = tcpip.IPv6TrafficClassOption(e.sendTOS)
- e.mu.RUnlock()
+ e.UnlockUser()
return nil
case *tcpip.TCPLingerTimeoutOption:
- e.mu.Lock()
+ e.LockUser()
*o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
case *tcpip.TCPDeferAcceptOption:
- e.mu.Lock()
+ e.LockUser()
*o = tcpip.TCPDeferAcceptOption(e.deferAccept)
- e.mu.Unlock()
+ e.UnlockUser()
return nil
default:
@@ -1854,13 +1927,14 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only)
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
if err != nil {
- return 0, err
+ return tcpip.FullAddress{}, 0, err
}
- *addr = unwrapped
- return netProto, nil
+ return unwrapped, netProto, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect.
@@ -1885,12 +1959,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// yet accepted by the app, they are restored without running the main goroutine
// here.
func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.LockUser()
+ defer e.UnlockUser()
connectingAddr := addr.Addr
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
@@ -2055,9 +2129,13 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
// Shutdown closes the read and/or write end of the endpoint connection to its
// peer.
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
- e.mu.Lock()
+ e.LockUser()
+ defer e.UnlockUser()
+ return e.shutdownLocked(flags)
+}
+
+func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
e.shutdownFlags |= flags
- finQueued := false
switch {
case e.EndpointState().connected():
// Close for read.
@@ -2071,24 +2149,9 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// If we're fully closed and we have unread data we need to abort
// the connection with a RST.
if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 {
- e.mu.Unlock()
- // Try to send an active reset immediately if the
- // work mutex is available.
- if e.workMu.TryLock() {
- e.mu.Lock()
- // We need to double check here to make
- // sure worker has not transitioned the
- // endpoint out of a connected state
- // before trying to send a reset.
- if e.EndpointState().connected() {
- e.resetConnectionLocked(tcpip.ErrConnectionAborted)
- e.notifyProtocolGoroutine(notifyTickleWorker)
- }
- e.mu.Unlock()
- e.workMu.Unlock()
- } else {
- e.notifyProtocolGoroutine(notifyReset)
- }
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ // Wake up worker to terminate loop.
+ e.notifyProtocolGoroutine(notifyTickleWorker)
return nil
}
}
@@ -2096,43 +2159,36 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// Close for write.
if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 {
e.sndBufMu.Lock()
-
if e.sndClosed {
// Already closed.
e.sndBufMu.Unlock()
- break
+ if e.EndpointState() == StateTimeWait {
+ return tcpip.ErrNotConnected
+ }
+ return nil
}
// Queue fin segment.
s := newSegmentFromView(&e.route, e.ID, nil)
e.sndQueue.PushBack(s)
e.sndBufInQueue++
- finQueued = true
// Mark endpoint as closed.
e.sndClosed = true
e.sndBufMu.Unlock()
+ e.handleClose()
}
+ return nil
case e.EndpointState() == StateListen:
// Tell protocolListenLoop to stop.
if flags&tcpip.ShutdownRead != 0 {
e.notifyProtocolGoroutine(notifyClose)
}
+ return nil
+
default:
- e.mu.Unlock()
return tcpip.ErrNotConnected
}
- e.mu.Unlock()
- if finQueued {
- if e.workMu.TryLock() {
- e.handleClose()
- e.workMu.Unlock()
- } else {
- // Tell protocol goroutine to close.
- e.sndCloseWaker.Assert()
- }
- }
- return nil
}
// Listen puts the endpoint in "listen" mode, which allows it to accept
@@ -2147,8 +2203,8 @@ func (e *endpoint) Listen(backlog int) *tcpip.Error {
}
func (e *endpoint) listen(backlog int) *tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.LockUser()
+ defer e.UnlockUser()
// Allow the backlog to be adjusted if the endpoint is not shutting down.
// When the endpoint shuts down, it sets workerCleanup to true, and from
@@ -2157,6 +2213,8 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
if e.EndpointState() == StateListen && !e.workerCleanup {
// Adjust the size of the channel iff we can fix existing
// pending connections into the new one.
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
if len(e.acceptedChan) > backlog {
return tcpip.ErrInvalidEndpointState
}
@@ -2169,6 +2227,11 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
for ep := range origChan {
e.acceptedChan <- ep
}
+
+ // Notify any blocked goroutines that they can attempt to
+ // deliver endpoints again.
+ e.acceptCond.Broadcast()
+
return nil
}
@@ -2198,9 +2261,12 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
// The channel may be non-nil when we're restoring the endpoint, and it
// may be pre-populated with some previously accepted (but not Accepted)
// endpoints.
+ e.acceptMu.Lock()
if e.acceptedChan == nil {
e.acceptedChan = make(chan *endpoint, backlog)
}
+ e.acceptMu.Unlock()
+
e.workerRunning = true
go e.protocolListenLoop( // S/R-SAFE: drained on save.
seqnum.Size(e.receiveBufferAvailable()))
@@ -2210,7 +2276,6 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
// startAcceptedLoop sets up required state and starts a goroutine with the
// main loop for accepted connections.
func (e *endpoint) startAcceptedLoop() {
- e.mu.Lock()
e.workerRunning = true
e.mu.Unlock()
wakerInitDone := make(chan struct{})
@@ -2221,8 +2286,8 @@ func (e *endpoint) startAcceptedLoop() {
// Accept returns a new endpoint if a peer has established a connection
// to an endpoint previously set to listen mode.
func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
// Endpoint must be in listen state before it can accept connections.
if e.EndpointState() != StateListen {
@@ -2230,9 +2295,12 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
}
// Get the new accepted endpoint.
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
var n *endpoint
select {
case n = <-e.acceptedChan:
+ e.acceptCond.Signal()
default:
return nil, nil, tcpip.ErrWouldBlock
}
@@ -2241,8 +2309,8 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
// Bind binds the endpoint to a specific local port and optionally address.
func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.LockUser()
+ defer e.UnlockUser()
return e.bindLocked(addr)
}
@@ -2256,7 +2324,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
}
e.BindAddr = addr.Addr
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
@@ -2320,8 +2388,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// GetLocalAddress returns the address to which the endpoint is bound.
func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
return tcpip.FullAddress{
Addr: e.ID.LocalAddress,
@@ -2332,8 +2400,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
// GetRemoteAddress returns the address to which the endpoint is connected.
func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
if !e.EndpointState().connected() {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
@@ -2346,7 +2414,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}, nil
}
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
// TCP HandlePacket is not required anymore as inbound packets first
// land at the Dispatcher which then can either delivery using the
// worker go routine or directly do the invoke the tcp processing inline
@@ -2365,7 +2433,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool {
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
switch typ {
case stack.ControlPacketTooBig:
e.sndBufMu.Lock()
@@ -2406,7 +2474,7 @@ func (e *endpoint) readyToRead(s *segment) {
e.rcvBufUsed += s.data.Size()
// Increase counter if the receive window falls down below MSS
// or half receive buffer size, whichever smaller.
- if crossed, above := e.windowCrossedACKThreshold(-s.data.Size()); crossed && !above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above {
e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
}
e.rcvList.PushBack(s)
@@ -2414,7 +2482,6 @@ func (e *endpoint) readyToRead(s *segment) {
e.rcvClosed = true
}
e.rcvListMu.Unlock()
-
e.waiterQueue.Notify(waiter.EventIn)
}
@@ -2558,9 +2625,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
s.SegTime = time.Now()
// Copy EndpointID.
- e.mu.Lock()
s.ID = stack.TCPEndpointID(e.ID)
- e.mu.Unlock()
// Copy endpoint rcv state.
e.rcvListMu.Lock()
@@ -2690,10 +2755,10 @@ func (e *endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (e *endpoint) Info() tcpip.EndpointInfo {
- e.mu.RLock()
+ e.LockUser()
// Make a copy of the endpoint info.
ret := e.EndpointInfo
- e.mu.RUnlock()
+ e.UnlockUser()
return &ret
}
@@ -2708,9 +2773,9 @@ func (e *endpoint) Wait() {
e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp)
defer e.waiterQueue.EventUnregister(&waitEntry)
for {
- e.mu.Lock()
+ e.LockUser()
running := e.workerRunning
- e.mu.Unlock()
+ e.UnlockUser()
if !running {
break
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 4a46f0ec5..c3c692555 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -162,8 +162,8 @@ func (e *endpoint) loadState(state EndpointState) {
connectingLoading.Add(1)
}
// Directly update the state here rather than using e.setEndpointState
- // as the endpoint is still being loaded and the stack reference to increment
- // metrics is not yet initialized.
+ // as the endpoint is still being loaded and the stack reference is not
+ // yet initialized.
atomic.StoreUint32((*uint32)(&e.state), uint32(state))
}
@@ -173,6 +173,9 @@ func (e *endpoint) afterLoad() {
// Restore the endpoint to InitialState as it will be moved to
// its origEndpointState during Resume.
e.state = StateInitial
+ // Condition variables and mutexs are not S/R'ed so reinitialize
+ // acceptCond with e.acceptMu.
+ e.acceptCond = sync.NewCond(&e.acceptMu)
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
@@ -180,7 +183,6 @@ func (e *endpoint) afterLoad() {
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
e.segmentQueue.setLimit(MaxUnprocessedSegments)
- e.workMu.Init()
state := e.origEndpointState
switch state {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index c9ee5bf06..a094471b8 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -61,7 +61,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool {
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
s := newSegment(r, id, pkt)
defer s.decRef()
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 73098d904..1377107ca 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -95,7 +95,7 @@ const (
)
type protocol struct {
- mu sync.Mutex
+ mu sync.RWMutex
sackEnabled bool
delayEnabled bool
sendBufferSize SendBufferSizeOption
@@ -140,7 +140,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// to a specific processing queue. Each queue is serviced by its own processor
// goroutine which is responsible for dequeuing and doing full TCP dispatch of
// the packet.
-func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
p.dispatcher.queuePacket(r, ep, id, pkt)
}
@@ -151,7 +151,7 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool {
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
s := newSegment(r, id, pkt)
defer s.decRef()
@@ -191,7 +191,15 @@ func replyWithReset(s *segment) {
flags |= header.TCPFlagAck
ack = s.sequenceNumber.Add(s.logicalLen())
}
- sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */)
+ sendTCP(&s.route, tcpFields{
+ id: s.id,
+ ttl: s.route.DefaultTTL(),
+ tos: stack.DefaultTOS,
+ flags: flags,
+ seq: seq,
+ ack: ack,
+ rcvWnd: 0,
+ }, buffer.VectorisedView{}, nil /* gso */)
}
// SetOption implements stack.TransportProtocol.SetOption.
@@ -273,57 +281,57 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
func (p *protocol) Option(option interface{}) *tcpip.Error {
switch v := option.(type) {
case *SACKEnabled:
- p.mu.Lock()
+ p.mu.RLock()
*v = SACKEnabled(p.sackEnabled)
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *DelayEnabled:
- p.mu.Lock()
+ p.mu.RLock()
*v = DelayEnabled(p.delayEnabled)
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *SendBufferSizeOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = p.sendBufferSize
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *ReceiveBufferSizeOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = p.recvBufferSize
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *tcpip.CongestionControlOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = tcpip.CongestionControlOption(p.congestionControl)
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *tcpip.AvailableCongestionControlOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *tcpip.ModerateReceiveBufferOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer)
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *tcpip.TCPLingerTimeoutOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout)
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *tcpip.TCPTimeWaitTimeoutOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout)
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
default:
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 958f03ac1..caf8977b3 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -168,7 +168,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// We just received a FIN, our next state depends on whether we sent a
// FIN already or not.
- r.ep.mu.Lock()
switch r.ep.EndpointState() {
case StateEstablished:
r.ep.setEndpointState(StateCloseWait)
@@ -183,7 +182,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
case StateFinWait2:
r.ep.setEndpointState(StateTimeWait)
}
- r.ep.mu.Unlock()
// Flush out any pending segments, except the very first one if
// it happens to be the one we're handling now because the
@@ -195,6 +193,10 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
for i := first; i < len(r.pendingRcvdSegments); i++ {
r.pendingRcvdSegments[i].decRef()
+ // Note that slice truncation does not allow garbage collection of
+ // truncated items, thus truncated items must be set to nil to avoid
+ // memory leaks.
+ r.pendingRcvdSegments[i] = nil
}
r.pendingRcvdSegments = r.pendingRcvdSegments[:first]
@@ -204,7 +206,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// Handle ACK (not FIN-ACK, which we handled above) during one of the
// shutdown states.
if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
- r.ep.mu.Lock()
switch r.ep.EndpointState() {
case StateFinWait1:
r.ep.setEndpointState(StateFinWait2)
@@ -218,7 +219,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
case StateLastAck:
r.ep.transitionToStateCloseLocked()
}
- r.ep.mu.Unlock()
}
return true
@@ -332,10 +332,8 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// handleRcvdSegment handles TCP segments directed at the connection managed by
// r as they arrive. It is called by the protocol main loop.
func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
- r.ep.mu.RLock()
state := r.ep.EndpointState()
closed := r.ep.closed
- r.ep.mu.RUnlock()
if state != StateEstablished {
drop, err := r.handleRcvdSegmentClosing(s, state, closed)
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 1c10da5ca..e6fe7985d 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -18,7 +18,6 @@ import (
"sync/atomic"
"time"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
@@ -56,12 +55,12 @@ type segment struct {
options []byte `state:".([]byte)"`
hasNewSACKInfo bool
rcvdTime time.Time `state:".(unixTime)"`
- // xmitTime is the last transmit time of this segment. A zero value
- // indicates that the segment has yet to be transmitted.
- xmitTime time.Time `state:".(unixTime)"`
+ // xmitTime is the last transmit time of this segment.
+ xmitTime time.Time `state:".(unixTime)"`
+ xmitCount uint32
}
-func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) *segment {
+func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) *segment {
s := &segment{
refCnt: 1,
id: id,
diff --git a/pkg/tcpip/transport/tcp/segment_heap.go b/pkg/tcpip/transport/tcp/segment_heap.go
index 9fd061d7d..8d3ddce4b 100644
--- a/pkg/tcpip/transport/tcp/segment_heap.go
+++ b/pkg/tcpip/transport/tcp/segment_heap.go
@@ -14,21 +14,25 @@
package tcp
+import "container/heap"
+
type segmentHeap []*segment
+var _ heap.Interface = (*segmentHeap)(nil)
+
// Len returns the length of h.
-func (h segmentHeap) Len() int {
- return len(h)
+func (h *segmentHeap) Len() int {
+ return len(*h)
}
// Less determines whether the i-th element of h is less than the j-th element.
-func (h segmentHeap) Less(i, j int) bool {
- return h[i].sequenceNumber.LessThan(h[j].sequenceNumber)
+func (h *segmentHeap) Less(i, j int) bool {
+ return (*h)[i].sequenceNumber.LessThan((*h)[j].sequenceNumber)
}
// Swap swaps the i-th and j-th elements of h.
-func (h segmentHeap) Swap(i, j int) {
- h[i], h[j] = h[j], h[i]
+func (h *segmentHeap) Swap(i, j int) {
+ (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
}
// Push adds x as the last element of h.
@@ -41,6 +45,7 @@ func (h *segmentHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
+ old[n-1] = nil
*h = old[:n-1]
return x
}
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
index bd20a7ee9..48a257137 100644
--- a/pkg/tcpip/transport/tcp/segment_queue.go
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -28,10 +28,16 @@ type segmentQueue struct {
used int
}
+// emptyLocked determines if the queue is empty.
+// Preconditions: q.mu must be held.
+func (q *segmentQueue) emptyLocked() bool {
+ return q.used == 0
+}
+
// empty determines if the queue is empty.
func (q *segmentQueue) empty() bool {
q.mu.Lock()
- r := q.used == 0
+ r := q.emptyLocked()
q.mu.Unlock()
return r
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index b74b61e7d..6b7bac37d 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -126,10 +126,6 @@ type sender struct {
// sndNxt is the sequence number of the next segment to be sent.
sndNxt seqnum.Value
- // sndNxtList is the sequence number of the next segment to be added to
- // the send list.
- sndNxtList seqnum.Value
-
// rttMeasureSeqNum is the sequence number being used for the latest RTT
// measurement.
rttMeasureSeqNum seqnum.Value
@@ -229,7 +225,6 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
sndWnd: sndWnd,
sndUna: iss + 1,
sndNxt: iss + 1,
- sndNxtList: iss + 1,
rto: 1 * time.Second,
rttMeasureSeqNum: iss + 1,
lastSendTime: time.Now(),
@@ -455,9 +450,7 @@ func (s *sender) retransmitTimerExpired() bool {
// Give up if we've waited more than a minute since the last resend or
// if a user time out is set and we have exceeded the user specified
// timeout since the first retransmission.
- s.ep.mu.RLock()
uto := s.ep.userTimeout
- s.ep.mu.RUnlock()
if s.firstRetransmittedSegXmitTime.IsZero() {
// We store the original xmitTime of the segment that we are
@@ -713,7 +706,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
default:
s.ep.setEndpointState(StateFinWait1)
}
-
} else {
// We're sending a non-FIN segment.
if seg.flags&header.TCPFlagFin != 0 {
@@ -1229,7 +1221,7 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// sendSegment sends the specified segment.
func (s *sender) sendSegment(seg *segment) *tcpip.Error {
- if !seg.xmitTime.IsZero() {
+ if seg.xmitCount > 0 {
s.ep.stack.Stats().TCP.Retransmits.Increment()
s.ep.stats.SendErrors.Retransmits.Increment()
if s.sndCwnd < s.sndSsthresh {
@@ -1237,6 +1229,7 @@ func (s *sender) sendSegment(seg *segment) *tcpip.Error {
}
}
seg.xmitTime = time.Now()
+ seg.xmitCount++
return s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 5b2b16afa..ce3df7478 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -2236,9 +2236,18 @@ func TestSegmentMerging(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- // Prevent the endpoint from processing packets.
- test.stop(c.EP)
+ // Send tcp.InitialCwnd number of segments to fill up
+ // InitialWindow but don't ACK. That should prevent
+ // anymore packets from going out.
+ for i := 0; i < tcp.InitialCwnd; i++ {
+ view := buffer.NewViewFromBytes([]byte{0})
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write #%d failed: %s", i+1, err)
+ }
+ }
+ // Now send the segments that should get merged as the congestion
+ // window is full and we won't be able to send any more packets.
var allData []byte
for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
allData = append(allData, data...)
@@ -2248,8 +2257,29 @@ func TestSegmentMerging(t *testing.T) {
}
}
- // Let the endpoint process the segments that we just sent.
- test.resume(c.EP)
+ // Check that we get tcp.InitialCwnd packets.
+ for i := 0; i < tcp.InitialCwnd; i++ {
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(header.TCPMinimumSize+1),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+uint32(i)+1),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ }
+
+ // Acknowledge the data.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload.
+ RcvWnd: 30000,
+ })
// Check that data is received.
b := c.GetPacket()
@@ -2257,7 +2287,7 @@ func TestSegmentMerging(t *testing.T) {
checker.PayloadLen(len(allData)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.SeqNum(uint32(c.IRS)+11),
checker.AckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
@@ -2273,7 +2303,7 @@ func TestSegmentMerging(t *testing.T) {
DstPort: c.Port,
Flags: header.TCPFlagAck,
SeqNum: 790,
- AckNum: c.IRS.Add(1 + seqnum.Size(len(allData))),
+ AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))),
RcvWnd: 30000,
})
})
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 1e9a0dea3..d4f6bc635 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -204,6 +204,7 @@ func (c *Context) Cleanup() {
if c.EP != nil {
c.EP.Close()
}
+ c.Stack().Close()
}
// Stack returns a reference to the stack in the Context.
@@ -306,7 +307,7 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
copy(icmp[header.ICMPv4PayloadOffset:], p2)
// Inject packet.
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -362,7 +363,7 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp
// SendSegment sends a TCP segment that has already been built and written to a
// buffer.VectorisedView.
func (c *Context) SendSegment(s buffer.VectorisedView) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
Data: s,
})
}
@@ -370,7 +371,7 @@ func (c *Context) SendSegment(s buffer.VectorisedView) {
// SendPacket builds and sends a TCP segment(with the provided payload & TCP
// headers) in an IPv4 packet via the link layer endpoint.
func (c *Context) SendPacket(payload []byte, h *Headers) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
Data: c.BuildSegment(payload, h),
})
}
@@ -379,7 +380,7 @@ func (c *Context) SendPacket(payload []byte, h *Headers) {
// & TCPheaders) in an IPv4 packet via the link layer endpoint using the
// provided source and destination IPv4 addresses.
func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
})
}
@@ -547,7 +548,7 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp
t.SetChecksum(^t.CalculateChecksum(xsum))
// Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index adc908e24..b5d2d0ba6 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -32,7 +32,6 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/ports",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 1c6a600b8..a3372ac58 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
@@ -234,7 +233,7 @@ func (e *endpoint) Close() {
func (e *endpoint) ModerateRecvBuf(copied int) {}
// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (iptables.IPTables, error) {
+func (e *endpoint) IPTables() (stack.IPTables, error) {
return e.stack.IPTables(), nil
}
@@ -443,19 +442,19 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrBroadcastDisabled
}
- netProto, err := e.checkV4Mapped(to)
+ dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
return 0, nil, err
}
- r, _, err := e.connectRoute(nicID, *to, netProto)
+ r, _, err := e.connectRoute(nicID, dst, netProto)
if err != nil {
return 0, nil, err
}
defer r.Release()
route = &r
- dstPort = to.Port
+ dstPort = dst.Port
}
if route.IsResolutionRequired() {
@@ -566,7 +565,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
defer e.mu.Unlock()
fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
- netProto, err := e.checkV4Mapped(&fa)
+ fa, netProto, err := e.checkV4MappedLocked(fa)
if err != nil {
return err
}
@@ -913,7 +912,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
if useDefaultTTL {
ttl = r.DefaultTTL()
}
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, tcpip.PacketBuffer{
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, stack.PacketBuffer{
Header: hdr,
Data: data,
TransportHeader: buffer.View(udp),
@@ -927,13 +926,14 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
return nil
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only)
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
if err != nil {
- return 0, err
+ return tcpip.FullAddress{}, 0, err
}
- *addr = unwrapped
- return netProto, nil
+ return unwrapped, netProto, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect.
@@ -981,10 +981,6 @@ func (e *endpoint) Disconnect() *tcpip.Error {
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- netProto, err := e.checkV4Mapped(&addr)
- if err != nil {
- return err
- }
if addr.Port == 0 {
// We don't support connecting to port zero.
return tcpip.ErrInvalidEndpointState
@@ -1012,6 +1008,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+ addr, netProto, err := e.checkV4MappedLocked(addr)
+ if err != nil {
+ return err
+ }
+
r, nicID, err := e.connectRoute(nicID, addr, netProto)
if err != nil {
return err
@@ -1139,7 +1140,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
@@ -1258,7 +1259,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
// Get the header then trim it from the view.
hdr := header.UDP(pkt.Data.First())
if int(hdr.Length()) > pkt.Data.Size() {
@@ -1325,7 +1326,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
}
// State implements tcpip.Endpoint.State.
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 43fb047ed..466bd9381 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -69,6 +69,9 @@ func (e *endpoint) afterLoad() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
e.stack = s
for _, m := range e.multicastMemberships {
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index fc706ede2..a674ceb68 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -43,7 +43,7 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder {
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool {
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
f.handler(&ForwarderRequest{
stack: f.stack,
route: r,
@@ -61,7 +61,7 @@ type ForwarderRequest struct {
stack *stack.Stack
route *stack.Route
id stack.TransportEndpointID
- pkt tcpip.PacketBuffer
+ pkt stack.PacketBuffer
}
// ID returns the 4-tuple (src address, src port, dst address, dst port) that
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 8df089d22..6e31a9bac 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -66,7 +66,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool {
+func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
// Get the header then trim it from the view.
hdr := header.UDP(pkt.Data.First())
if int(hdr.Length()) > pkt.Data.Size() {
@@ -135,7 +135,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
pkt.SetType(header.ICMPv4DstUnreachable)
pkt.SetCode(header.ICMPv4PortUnreachable)
pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
- r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: payload,
})
@@ -172,7 +172,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
pkt.SetType(header.ICMPv6DstUnreachable)
pkt.SetCode(header.ICMPv6PortUnreachable)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload))
- r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: payload,
})
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 34b7c2360..0905726c1 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -439,7 +439,7 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
NetworkHeader: buffer.View(ip),
TransportHeader: buffer.View(u),
@@ -486,7 +486,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool
// Inject packet.
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
NetworkHeader: buffer.View(ip),
TransportHeader: buffer.View(u),
diff --git a/runsc/boot/config.go b/runsc/boot/config.go
index 35391030f..7ea5bfade 100644
--- a/runsc/boot/config.go
+++ b/runsc/boot/config.go
@@ -158,6 +158,9 @@ type Config struct {
// DebugLog is the path to log debug information to, if not empty.
DebugLog string
+ // PanicLog is the path to log GO's runtime messages, if not empty.
+ PanicLog string
+
// DebugLogFormat is the log format for debug.
DebugLogFormat string
@@ -269,6 +272,7 @@ func (c *Config) ToFlags() []string {
"--log=" + c.LogFilename,
"--log-format=" + c.LogFormat,
"--debug-log=" + c.DebugLog,
+ "--panic-log=" + c.PanicLog,
"--debug-log-format=" + c.DebugLogFormat,
"--file-access=" + c.FileAccess.String(),
"--overlay=" + strconv.FormatBool(c.Overlay),
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
index 17e774e0c..8125d5061 100644
--- a/runsc/boot/controller.go
+++ b/runsc/boot/controller.go
@@ -101,11 +101,14 @@ const (
// Profiling related commands (see pprof.go for more details).
const (
- StartCPUProfile = "Profile.StartCPUProfile"
- StopCPUProfile = "Profile.StopCPUProfile"
- HeapProfile = "Profile.HeapProfile"
- StartTrace = "Profile.StartTrace"
- StopTrace = "Profile.StopTrace"
+ StartCPUProfile = "Profile.StartCPUProfile"
+ StopCPUProfile = "Profile.StopCPUProfile"
+ HeapProfile = "Profile.HeapProfile"
+ GoroutineProfile = "Profile.GoroutineProfile"
+ BlockProfile = "Profile.BlockProfile"
+ MutexProfile = "Profile.MutexProfile"
+ StartTrace = "Profile.StartTrace"
+ StopTrace = "Profile.StopTrace"
)
// Logging related commands (see logging.go for more details).
diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go
index a4627905e..06b9f888a 100644
--- a/runsc/boot/filter/config.go
+++ b/runsc/boot/filter/config.go
@@ -284,12 +284,21 @@ var allowedSyscalls = seccomp.SyscallRules{
{seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RDWR)},
},
syscall.SYS_SIGALTSTACK: {},
+ unix.SYS_STATX: {},
syscall.SYS_SYNC_FILE_RANGE: {},
syscall.SYS_TGKILL: []seccomp.Rule{
{
seccomp.AllowValue(uint64(os.Getpid())),
},
},
+ syscall.SYS_UTIMENSAT: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(0), /* null pathname */
+ seccomp.AllowAny{},
+ seccomp.AllowValue(0), /* flags */
+ },
+ },
syscall.SYS_WRITE: {},
// The only user in rawfile.NonBlockingWrite3 always passes iovcnt with
// values 2 or 3. Three iovec-s are passed, when the PACKET_VNET_HDR
diff --git a/runsc/cmd/boot.go b/runsc/cmd/boot.go
index 0f3da69a0..0938944a6 100644
--- a/runsc/cmd/boot.go
+++ b/runsc/cmd/boot.go
@@ -23,6 +23,7 @@ import (
"github.com/google/subcommands"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/boot/platforms"
@@ -82,8 +83,13 @@ type Boot struct {
// sandbox (e.g. gofer) and sent through this FD.
mountsFD int
- // pidns is set if the sanadbox is in its own pid namespace.
+ // pidns is set if the sandbox is in its own pid namespace.
pidns bool
+
+ // attached is set to true to kill the sandbox process when the parent process
+ // terminates. This flag is set when the command execve's itself because
+ // parent death signal doesn't propagate through execve when uid/gid changes.
+ attached bool
}
// Name implements subcommands.Command.Name.
@@ -118,6 +124,7 @@ func (b *Boot) SetFlags(f *flag.FlagSet) {
f.IntVar(&b.userLogFD, "user-log-fd", 0, "file descriptor to write user logs to. 0 means no logging.")
f.IntVar(&b.startSyncFD, "start-sync-fd", -1, "required FD to used to synchronize sandbox startup")
f.IntVar(&b.mountsFD, "mounts-fd", -1, "mountsFD is the file descriptor to read list of mounts after they have been resolved (direct paths, no symlinks).")
+ f.BoolVar(&b.attached, "attached", false, "if attached is true, kills the sandbox process when the parent process terminates")
}
// Execute implements subcommands.Command.Execute. It starts a sandbox in a
@@ -133,29 +140,32 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
conf := args[0].(*boot.Config)
+ if b.attached {
+ // Ensure this process is killed after parent process terminates when
+ // attached mode is enabled. In the unfortunate event that the parent
+ // terminates before this point, this process leaks.
+ if err := unix.Prctl(unix.PR_SET_PDEATHSIG, uintptr(unix.SIGKILL), 0, 0, 0); err != nil {
+ Fatalf("error setting parent death signal: %v", err)
+ }
+ }
+
if b.setUpRoot {
if err := setUpChroot(b.pidns); err != nil {
Fatalf("error setting up chroot: %v", err)
}
- if !b.applyCaps {
- // Remove --setup-root arg to call myself.
- var args []string
- for _, arg := range os.Args {
- if !strings.Contains(arg, "setup-root") {
- args = append(args, arg)
- }
- }
- if !conf.Rootless {
- // Note that we've already read the spec from the spec FD, and
- // we will read it again after the exec call. This works
- // because the ReadSpecFromFile function seeks to the beginning
- // of the file before reading.
- if err := callSelfAsNobody(args); err != nil {
- Fatalf("%v", err)
- }
- panic("callSelfAsNobody must never return success")
+ if !b.applyCaps && !conf.Rootless {
+ // Remove --apply-caps arg to call myself. It has already been done.
+ args := prepareArgs(b.attached, "setup-root")
+
+ // Note that we've already read the spec from the spec FD, and
+ // we will read it again after the exec call. This works
+ // because the ReadSpecFromFile function seeks to the beginning
+ // of the file before reading.
+ if err := callSelfAsNobody(args); err != nil {
+ Fatalf("%v", err)
}
+ panic("callSelfAsNobody must never return success")
}
}
@@ -181,13 +191,9 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
caps.Permitted = append(caps.Permitted, c)
}
- // Remove --apply-caps arg to call myself.
- var args []string
- for _, arg := range os.Args {
- if !strings.Contains(arg, "setup-root") && !strings.Contains(arg, "apply-caps") {
- args = append(args, arg)
- }
- }
+ // Remove --apply-caps and --setup-root arg to call myself. Both have
+ // already been done.
+ args := prepareArgs(b.attached, "setup-root", "apply-caps")
// Note that we've already read the spec from the spec FD, and
// we will read it again after the exec call. This works
@@ -258,3 +264,22 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
l.Destroy()
return subcommands.ExitSuccess
}
+
+func prepareArgs(attached bool, exclude ...string) []string {
+ var args []string
+ for _, arg := range os.Args {
+ for _, excl := range exclude {
+ if strings.Contains(arg, excl) {
+ goto skip
+ }
+ }
+ args = append(args, arg)
+ if attached && arg == "boot" {
+ // Strategicaly place "--attached" after the command. This is needed
+ // to ensure the new process is killed when the parent process terminates.
+ args = append(args, "--attached")
+ }
+ skip:
+ }
+ return args
+}
diff --git a/runsc/cmd/chroot.go b/runsc/cmd/chroot.go
index b5a0ce17d..189244765 100644
--- a/runsc/cmd/chroot.go
+++ b/runsc/cmd/chroot.go
@@ -50,7 +50,7 @@ func pivotRoot(root string) error {
// new_root, so after umounting the old_root, we will see only
// the new_root in "/".
if err := syscall.PivotRoot(".", "."); err != nil {
- return fmt.Errorf("error changing root filesystem: %v", err)
+ return fmt.Errorf("pivot_root failed, make sure that the root mount has a parent: %v", err)
}
if err := syscall.Unmount(".", syscall.MNT_DETACH); err != nil {
diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go
index 79965460e..b5de2588b 100644
--- a/runsc/cmd/debug.go
+++ b/runsc/cmd/debug.go
@@ -32,17 +32,20 @@ import (
// Debug implements subcommands.Command for the "debug" command.
type Debug struct {
- pid int
- stacks bool
- signal int
- profileHeap string
- profileCPU string
- trace string
- strace string
- logLevel string
- logPackets string
- duration time.Duration
- ps bool
+ pid int
+ stacks bool
+ signal int
+ profileHeap string
+ profileCPU string
+ profileGoroutine string
+ profileBlock string
+ profileMutex string
+ trace string
+ strace string
+ logLevel string
+ logPackets string
+ duration time.Duration
+ ps bool
}
// Name implements subcommands.Command.
@@ -66,6 +69,9 @@ func (d *Debug) SetFlags(f *flag.FlagSet) {
f.BoolVar(&d.stacks, "stacks", false, "if true, dumps all sandbox stacks to the log")
f.StringVar(&d.profileHeap, "profile-heap", "", "writes heap profile to the given file.")
f.StringVar(&d.profileCPU, "profile-cpu", "", "writes CPU profile to the given file.")
+ f.StringVar(&d.profileGoroutine, "profile-goroutine", "", "writes goroutine profile to the given file.")
+ f.StringVar(&d.profileBlock, "profile-block", "", "writes block profile to the given file.")
+ f.StringVar(&d.profileMutex, "profile-mutex", "", "writes mutex profile to the given file.")
f.DurationVar(&d.duration, "duration", time.Second, "amount of time to wait for CPU and trace profiles")
f.StringVar(&d.trace, "trace", "", "writes an execution trace to the given file.")
f.IntVar(&d.signal, "signal", -1, "sends signal to the sandbox")
@@ -147,6 +153,42 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
}
log.Infof("Heap profile written to %q", d.profileHeap)
}
+ if d.profileGoroutine != "" {
+ f, err := os.Create(d.profileGoroutine)
+ if err != nil {
+ return Errorf(err.Error())
+ }
+ defer f.Close()
+
+ if err := c.Sandbox.GoroutineProfile(f); err != nil {
+ return Errorf(err.Error())
+ }
+ log.Infof("Goroutine profile written to %q", d.profileGoroutine)
+ }
+ if d.profileBlock != "" {
+ f, err := os.Create(d.profileBlock)
+ if err != nil {
+ return Errorf(err.Error())
+ }
+ defer f.Close()
+
+ if err := c.Sandbox.BlockProfile(f); err != nil {
+ return Errorf(err.Error())
+ }
+ log.Infof("Block profile written to %q", d.profileBlock)
+ }
+ if d.profileMutex != "" {
+ f, err := os.Create(d.profileMutex)
+ if err != nil {
+ return Errorf(err.Error())
+ }
+ defer f.Close()
+
+ if err := c.Sandbox.MutexProfile(f); err != nil {
+ return Errorf(err.Error())
+ }
+ log.Infof("Mutex profile written to %q", d.profileMutex)
+ }
delay := false
if d.profileCPU != "" {
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
index 6e06f3c0f..02e5af3d3 100644
--- a/runsc/cmd/gofer.go
+++ b/runsc/cmd/gofer.go
@@ -335,7 +335,7 @@ func setupRootFS(spec *specs.Spec, conf *boot.Config) error {
if !conf.TestOnlyAllowRunAsCurrentUserWithoutChroot {
if err := pivotRoot("/proc"); err != nil {
- Fatalf("faild to change the root file system: %v", err)
+ Fatalf("failed to change the root file system: %v", err)
}
if err := os.Chdir("/"); err != nil {
Fatalf("failed to change working directory")
diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go
index c2518d52b..651615d4c 100644
--- a/runsc/container/console_test.go
+++ b/runsc/container/console_test.go
@@ -333,13 +333,13 @@ func TestJobControlSignalRootContainer(t *testing.T) {
// file. Writes after a certain point will block unless we drain the
// PTY, so we must continually copy from it.
//
- // We log the output to stdout for debugabilitly, and also to a buffer,
+ // We log the output to stderr for debugabilitly, and also to a buffer,
// since we wait on particular output from bash below. We use a custom
// blockingBuffer which is thread-safe and also blocks on Read calls,
// which makes this a suitable Reader for WaitUntilRead.
ptyBuf := newBlockingBuffer()
tee := io.TeeReader(ptyMaster, ptyBuf)
- go io.Copy(os.Stdout, tee)
+ go io.Copy(os.Stderr, tee)
// Start the container.
if err := c.Start(conf); err != nil {
diff --git a/runsc/container/container.go b/runsc/container/container.go
index 68782c4be..c9839044c 100644
--- a/runsc/container/container.go
+++ b/runsc/container/container.go
@@ -17,6 +17,7 @@ package container
import (
"context"
+ "errors"
"fmt"
"io/ioutil"
"os"
@@ -1066,18 +1067,10 @@ func runInCgroup(cg *cgroup.Cgroup, fn func() error) error {
// adjustGoferOOMScoreAdj sets the oom_store_adj for the container's gofer.
func (c *Container) adjustGoferOOMScoreAdj() error {
- if c.GoferPid != 0 && c.Spec.Process.OOMScoreAdj != nil {
- if err := setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj); err != nil {
- // Ignore NotExist error because it can be returned when the sandbox
- // exited while OOM score was being adjusted.
- if !os.IsNotExist(err) {
- return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err)
- }
- log.Warningf("Gofer process (%d) not found setting oom_score_adj", c.GoferPid)
- }
+ if c.GoferPid == 0 || c.Spec.Process.OOMScoreAdj == nil {
+ return nil
}
-
- return nil
+ return setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj)
}
// adjustSandboxOOMScoreAdj sets the oom_score_adj for the sandbox.
@@ -1154,29 +1147,29 @@ func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool)
}
// Set the lowest of all containers oom_score_adj to the sandbox.
- if err := setOOMScoreAdj(s.Pid, lowScore); err != nil {
- // Ignore NotExist error because it can be returned when the sandbox
- // exited while OOM score was being adjusted.
- if !os.IsNotExist(err) {
- return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", s.ID, err)
- }
- log.Warningf("Sandbox process (%d) not found setting oom_score_adj", s.Pid)
- }
-
- return nil
+ return setOOMScoreAdj(s.Pid, lowScore)
}
// setOOMScoreAdj sets oom_score_adj to the given value for the given PID.
// /proc must be available and mounted read-write. scoreAdj should be between
-// -1000 and 1000.
+// -1000 and 1000. It's a noop if the process has already exited.
func setOOMScoreAdj(pid int, scoreAdj int) error {
f, err := os.OpenFile(fmt.Sprintf("/proc/%d/oom_score_adj", pid), os.O_WRONLY, 0644)
if err != nil {
+ // Ignore NotExist errors because it can race with process exit.
+ if os.IsNotExist(err) {
+ log.Warningf("Process (%d) not found setting oom_score_adj", pid)
+ return nil
+ }
return err
}
defer f.Close()
if _, err := f.WriteString(strconv.Itoa(scoreAdj)); err != nil {
- return err
+ if errors.Is(err, syscall.ESRCH) {
+ log.Warningf("Process (%d) exited while setting oom_score_adj", pid)
+ return nil
+ }
+ return fmt.Errorf("setting oom_score_adj to %q: %v", scoreAdj, err)
}
return nil
}
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index bdd65b498..442e80ac0 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -124,23 +124,6 @@ func procListsEqual(got, want []*control.Process) (bool, error) {
return true, nil
}
-// getAndCheckProcLists is similar to waitForProcessList, but does not wait and retry the
-// test for equality. This is because we already confirmed that exec occurred.
-func getAndCheckProcLists(cont *Container, want []*control.Process) error {
- got, err := cont.Processes()
- if err != nil {
- return fmt.Errorf("error getting process data from container: %v", err)
- }
- equal, err := procListsEqual(got, want)
- if err != nil {
- return err
- }
- if equal {
- return nil
- }
- return fmt.Errorf("container got process list: %s, want: %s", procListToString(got), procListToString(want))
-}
-
func procListToString(pl []*control.Process) string {
strs := make([]string, 0, len(pl))
for _, p := range pl {
@@ -2092,7 +2075,7 @@ func TestOverlayfsStaleRead(t *testing.T) {
defer out.Close()
const want = "foobar"
- cmd := fmt.Sprintf("cat %q && echo %q> %q && cp %q %q", in.Name(), want, in.Name(), in.Name(), out.Name())
+ cmd := fmt.Sprintf("cat %q >&2 && echo %q> %q && cp %q %q", in.Name(), want, in.Name(), in.Name(), out.Name())
spec := testutil.NewSpecWithArgs("/bin/bash", "-c", cmd)
if err := run(spec, conf); err != nil {
t.Fatalf("Error running container: %v", err)
diff --git a/runsc/main.go b/runsc/main.go
index af73bed97..62e184ec9 100644
--- a/runsc/main.go
+++ b/runsc/main.go
@@ -54,9 +54,11 @@ var (
// Debugging flags.
debugLog = flag.String("debug-log", "", "additional location for logs. If it ends with '/', log files are created inside the directory with default names. The following variables are available: %TIMESTAMP%, %COMMAND%.")
+ panicLog = flag.String("panic-log", "", "file path were panic reports and other Go's runtime messages are written.")
logPackets = flag.Bool("log-packets", false, "enable network packet logging.")
logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.")
debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.")
+ panicLogFD = flag.Int("panic-log-fd", -1, "file descriptor to write Go's runtime messages.")
debugLogFormat = flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s.")
alsoLogToStderr = flag.Bool("alsologtostderr", false, "send log messages to stderr.")
@@ -206,6 +208,7 @@ func main() {
LogFilename: *logFilename,
LogFormat: *logFormat,
DebugLog: *debugLog,
+ PanicLog: *panicLog,
DebugLogFormat: *debugLogFormat,
FileAccess: fsAccess,
FSGoferHostUDS: *fsGoferHostUDS,
@@ -258,20 +261,6 @@ func main() {
if *debugLogFD > -1 {
f := os.NewFile(uintptr(*debugLogFD), "debug log file")
- // Quick sanity check to make sure no other commands get passed
- // a log fd (they should use log dir instead).
- if subcommand != "boot" && subcommand != "gofer" {
- cmd.Fatalf("flag --debug-log-fd should only be passed to 'boot' and 'gofer' command, but was passed to %q", subcommand)
- }
-
- // If we are the boot process, then we own our stdio FDs and can do what we
- // want with them. Since Docker and Containerd both eat boot's stderr, we
- // dup our stderr to the provided log FD so that panics will appear in the
- // logs, rather than just disappear.
- if err := syscall.Dup3(int(f.Fd()), int(os.Stderr.Fd()), 0); err != nil {
- cmd.Fatalf("error dup'ing fd %d to stderr: %v", f.Fd(), err)
- }
-
e = newEmitter(*debugLogFormat, f)
} else if *debugLog != "" {
@@ -287,6 +276,26 @@ func main() {
e = newEmitter("text", ioutil.Discard)
}
+ if *panicLogFD > -1 || *debugLogFD > -1 {
+ fd := *panicLogFD
+ if fd < 0 {
+ fd = *debugLogFD
+ }
+ // Quick sanity check to make sure no other commands get passed
+ // a log fd (they should use log dir instead).
+ if subcommand != "boot" && subcommand != "gofer" {
+ cmd.Fatalf("flags --debug-log-fd and --panic-log-fd should only be passed to 'boot' and 'gofer' command, but was passed to %q", subcommand)
+ }
+
+ // If we are the boot process, then we own our stdio FDs and can do what we
+ // want with them. Since Docker and Containerd both eat boot's stderr, we
+ // dup our stderr to the provided log FD so that panics will appear in the
+ // logs, rather than just disappear.
+ if err := syscall.Dup3(fd, int(os.Stderr.Fd()), 0); err != nil {
+ cmd.Fatalf("error dup'ing fd %d to stderr: %v", fd, err)
+ }
+ }
+
if *alsoLogToStderr {
e = &log.MultiEmitter{e, newEmitter(*debugLogFormat, os.Stderr)}
}
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index ec72bdbfd..8de75ae57 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -369,6 +369,24 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
cmd.Args = append(cmd.Args, "--debug-log-fd="+strconv.Itoa(nextFD))
nextFD++
}
+ if conf.PanicLog != "" {
+ test := ""
+ if len(conf.TestOnlyTestNameEnv) != 0 {
+ // Fetch test name if one is provided and the test only flag was set.
+ if t, ok := specutils.EnvVar(args.Spec.Process.Env, conf.TestOnlyTestNameEnv); ok {
+ test = t
+ }
+ }
+
+ panicLogFile, err := specutils.DebugLogFile(conf.PanicLog, "panic", test)
+ if err != nil {
+ return fmt.Errorf("opening debug log file in %q: %v", conf.PanicLog, err)
+ }
+ defer panicLogFile.Close()
+ cmd.ExtraFiles = append(cmd.ExtraFiles, panicLogFile)
+ cmd.Args = append(cmd.Args, "--panic-log-fd="+strconv.Itoa(nextFD))
+ nextFD++
+ }
cmd.Args = append(cmd.Args, "--panic-signal="+strconv.Itoa(int(syscall.SIGTERM)))
@@ -426,6 +444,12 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
nextFD++
}
+ // TODO(b/151157106): syscall tests fail by timeout if asyncpreemptoff
+ // isn't set.
+ if conf.Platform == "kvm" {
+ cmd.Env = append(cmd.Env, "GODEBUG=asyncpreemptoff=1")
+ }
+
// The current process' stdio must be passed to the application via the
// --stdio-fds flag. The stdio of the sandbox process itself must not
// be connected to the same FDs, otherwise we risk leaking sandbox
@@ -677,6 +701,13 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
nextFD++
}
+ if args.Attached {
+ // Kill sandbox if parent process exits in attached mode.
+ cmd.SysProcAttr.Pdeathsig = syscall.SIGKILL
+ // Tells boot that any process it creates must have pdeathsig set.
+ cmd.Args = append(cmd.Args, "--attached")
+ }
+
// Add container as the last argument.
cmd.Args = append(cmd.Args, s.ID)
@@ -685,11 +716,6 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
log.Debugf("Donating FD %d: %q", i+3, f.Name())
}
- if args.Attached {
- // Kill sandbox if parent process exits in attached mode.
- cmd.SysProcAttr.Pdeathsig = syscall.SIGKILL
- }
-
log.Debugf("Starting sandbox: %s %v", binPath, cmd.Args)
log.Debugf("SysProcAttr: %+v", cmd.SysProcAttr)
if err := specutils.StartInNS(cmd, nss); err != nil {
@@ -972,6 +998,66 @@ func (s *Sandbox) StopCPUProfile() error {
return nil
}
+// GoroutineProfile writes a goroutine profile to the given file.
+func (s *Sandbox) GoroutineProfile(f *os.File) error {
+ log.Debugf("Goroutine profile %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ opts := control.ProfileOpts{
+ FilePayload: urpc.FilePayload{
+ Files: []*os.File{f},
+ },
+ }
+ if err := conn.Call(boot.GoroutineProfile, &opts, nil); err != nil {
+ return fmt.Errorf("getting sandbox %q goroutine profile: %v", s.ID, err)
+ }
+ return nil
+}
+
+// BlockProfile writes a block profile to the given file.
+func (s *Sandbox) BlockProfile(f *os.File) error {
+ log.Debugf("Block profile %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ opts := control.ProfileOpts{
+ FilePayload: urpc.FilePayload{
+ Files: []*os.File{f},
+ },
+ }
+ if err := conn.Call(boot.BlockProfile, &opts, nil); err != nil {
+ return fmt.Errorf("getting sandbox %q block profile: %v", s.ID, err)
+ }
+ return nil
+}
+
+// MutexProfile writes a mutex profile to the given file.
+func (s *Sandbox) MutexProfile(f *os.File) error {
+ log.Debugf("Mutex profile %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ opts := control.ProfileOpts{
+ FilePayload: urpc.FilePayload{
+ Files: []*os.File{f},
+ },
+ }
+ if err := conn.Call(boot.MutexProfile, &opts, nil); err != nil {
+ return fmt.Errorf("getting sandbox %q mutex profile: %v", s.ID, err)
+ }
+ return nil
+}
+
// StartTrace start trace writing to the given file.
func (s *Sandbox) StartTrace(f *os.File) error {
log.Debugf("Trace start %q", s.ID)
diff --git a/runsc/specutils/namespace.go b/runsc/specutils/namespace.go
index c7dd3051c..60bb7b7ee 100644
--- a/runsc/specutils/namespace.go
+++ b/runsc/specutils/namespace.go
@@ -252,6 +252,9 @@ func MaybeRunAsRoot() error {
},
Credential: &syscall.Credential{Uid: 0, Gid: 0},
GidMappingsEnableSetgroups: false,
+
+ // Make sure child is killed when the parent terminates.
+ Pdeathsig: syscall.SIGKILL,
}
cmd.Env = os.Environ()
diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go
index 92d677e71..51e487715 100644
--- a/runsc/testutil/testutil.go
+++ b/runsc/testutil/testutil.go
@@ -87,18 +87,19 @@ func TestConfig() *boot.Config {
logDir = dir + "/"
}
return &boot.Config{
- Debug: true,
- DebugLog: logDir,
- LogFormat: "text",
- DebugLogFormat: "text",
- AlsoLogToStderr: true,
- LogPackets: true,
- Network: boot.NetworkNone,
- Strace: true,
- Platform: "ptrace",
- FileAccess: boot.FileAccessExclusive,
+ Debug: true,
+ DebugLog: logDir,
+ LogFormat: "text",
+ DebugLogFormat: "text",
+ AlsoLogToStderr: true,
+ LogPackets: true,
+ Network: boot.NetworkNone,
+ Strace: true,
+ Platform: "ptrace",
+ FileAccess: boot.FileAccessExclusive,
+ NumNetworkChannels: 1,
+
TestOnlyAllowRunAsCurrentUserWithoutChroot: true,
- NumNetworkChannels: 1,
}
}
diff --git a/scripts/benchmark.sh b/scripts/benchmark.sh
index a0317db02..06d44f914 100644
--- a/scripts/benchmark.sh
+++ b/scripts/benchmark.sh
@@ -14,12 +14,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# Run in the root of the repo.
-cd "$(dirname "$0")"
+source $(dirname $0)/common.sh
-KEY_PATH=${KEY_PATH:-"${KOKORO_KEYSTORE_DIR}/${KOKORO_SERVICE_ACCOUNT}"}
+# Exporting for subprocesses as GCP APIs and tools check this environmental
+# variable for authentication.
+export GOOGLE_APPLICATION_CREDENTIALS="${KOKORO_KEYSTORE_DIR}/${GCLOUD_CREDENTIALS}"
-gcloud auth activate-service-account --key-file "${KEY_PATH}"
+gcloud auth activate-service-account \
+ --key-file "${GOOGLE_APPLICATION_CREDENTIALS}"
-gcloud compute instances list
+gcloud config set project ${PROJECT}
+gcloud config set compute/zone ${ZONE}
+bazel run //benchmarks:benchmarks -- \
+ --verbose \
+ run-gcp \
+ startup \
+ --runtime=runc \
+ --runtime=runsc \
+ --installers=head
diff --git a/scripts/build.sh b/scripts/build.sh
index 4c042af6c..7c9c99800 100755
--- a/scripts/build.sh
+++ b/scripts/build.sh
@@ -17,7 +17,7 @@
source $(dirname $0)/common.sh
# Install required packages for make_repository.sh et al.
-sudo apt-get update && sudo apt-get install -y dpkg-sig coreutils apt-utils xz-utils
+apt_install dpkg-sig coreutils apt-utils xz-utils
# Build runsc.
runsc=$(build -c opt //runsc)
diff --git a/scripts/common.sh b/scripts/common.sh
index 3ca699e4a..735a383de 100755
--- a/scripts/common.sh
+++ b/scripts/common.sh
@@ -84,3 +84,17 @@ function install_runsc() {
# Restart docker to pick up the new runtime configuration.
sudo systemctl restart docker
}
+
+# Installs the given packages. Note that the package names should be verified to
+# be correct, otherwise this may result in a loop that spins until time out.
+function apt_install() {
+ while true; do
+ if (sudo apt-get update && sudo apt-get install -y "$@"); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ return $result
+ fi
+ done
+}
diff --git a/scripts/dev.sh b/scripts/dev.sh
index 6238b4d0b..a9107f33e 100755
--- a/scripts/dev.sh
+++ b/scripts/dev.sh
@@ -66,6 +66,7 @@ if [[ ${REFRESH} -eq 0 ]]; then
else
mkdir -p "$(dirname ${RUNSC_BIN})"
cp -f ${OUTPUT} "${RUNSC_BIN}"
+ chmod a+rx "${RUNSC_BIN}"
echo
echo "Runtime ${RUNTIME} refreshed."
diff --git a/scripts/iptables_tests.sh b/scripts/iptables_tests.sh
index 3069d8628..b4a5211a5 100755
--- a/scripts/iptables_tests.sh
+++ b/scripts/iptables_tests.sh
@@ -19,9 +19,12 @@ source $(dirname $0)/common.sh
install_runsc_for_test iptables
# Build the docker image for the test.
-run //test/iptables/runner-image --norun
+run //test/iptables/runner:runner-image --norun
-# TODO(gvisor.dev/issue/170): Also test this on runsc once iptables are better
-# supported
-test //test/iptables:iptables_test "--test_arg=--runtime=runc" \
+test //test/iptables:iptables_test \
+ "--test_arg=--runtime=runc" \
+ "--test_arg=--image=bazel/test/iptables/runner:runner-image"
+
+test //test/iptables:iptables_test \
+ "--test_arg=--runtime=runsc" \
"--test_arg=--image=bazel/test/iptables/runner:runner-image"
diff --git a/scripts/packetimpact_tests.sh b/scripts/packetimpact_tests.sh
new file mode 100755
index 000000000..027d11e64
--- /dev/null
+++ b/scripts/packetimpact_tests.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+install_runsc_for_test runsc-d
+test_runsc $(bazel query "attr(tags, packetimpact, tests(//test/packetimpact/...))")
diff --git a/scripts/release.sh b/scripts/release.sh
index 091abf87f..e14ba04a7 100755
--- a/scripts/release.sh
+++ b/scripts/release.sh
@@ -25,6 +25,14 @@ if ! [[ -v KOKORO_RELEASE_TAG ]]; then
echo "No KOKORO_RELEASE_TAG provided." >&2
exit 1
fi
+if ! [[ -v KOKORO_RELNOTES ]]; then
+ echo "No KOKORO_RELNOTES provided." >&2
+ exit 1
+fi
+if ! [[ -r "${KOKORO_ARTIFACTS_DIR}/${KOKORO_RELNOTES}" ]]; then
+ echo "The file '${KOKORO_ARTIFACTS_DIR}/${KOKORO_RELNOTES}' is not readable." >&2
+ exit 1
+fi
# Unless an explicit releaser is provided, use the bot e-mail.
declare -r KOKORO_RELEASE_AUTHOR=${KOKORO_RELEASE_AUTHOR:-gvisor-bot}
@@ -46,4 +54,7 @@ EOF
fi
# Run the release tool, which pushes to the origin repository.
-tools/tag_release.sh "${KOKORO_RELEASE_COMMIT}" "${KOKORO_RELEASE_TAG}"
+tools/tag_release.sh \
+ "${KOKORO_RELEASE_COMMIT}" \
+ "${KOKORO_RELEASE_TAG}" \
+ "${KOKORO_ARTIFACTS_DIR}/${KOKORO_RELNOTES}"
diff --git a/test/e2e/exec_test.go b/test/e2e/exec_test.go
index 4074d2285..594c8e752 100644
--- a/test/e2e/exec_test.go
+++ b/test/e2e/exec_test.go
@@ -240,17 +240,7 @@ func TestExecEnvHasHome(t *testing.T) {
}
d := dockerutil.MakeDocker("exec-env-home-test")
- // We will check that HOME is set for root user, and also for a new
- // non-root user we will create.
- newUID := 1234
- newHome := "/foo/bar"
-
- // Create a new user with a home directory, and then sleep.
- script := fmt.Sprintf(`
- mkdir -p -m 777 %s && \
- adduser foo -D -u %d -h %s && \
- sleep 1000`, newHome, newUID, newHome)
- if err := d.Run("alpine", "/bin/sh", "-c", script); err != nil {
+ if err := d.Run("alpine", "sleep", "1000"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
defer d.CleanUp()
@@ -264,7 +254,15 @@ func TestExecEnvHasHome(t *testing.T) {
t.Errorf("wanted exec output to contain %q, got %q", want, got)
}
- // Execute the same as uid 123 and expect newHome.
+ // Create a new user with a home directory.
+ newUID := 1234
+ newHome := "/foo/bar"
+ cmd := fmt.Sprintf("mkdir -p -m 777 %q && adduser foo -D -u %d -h %q", newHome, newUID, newHome)
+ if _, err := d.Exec("/bin/sh", "-c", cmd); err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+
+ // Execute the same as the new user and expect newHome.
got, err = d.ExecAsUser(strconv.Itoa(newUID), "/bin/sh", "-c", "echo $HOME")
if err != nil {
t.Fatalf("docker exec failed: %v", err)
diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go
index 28064e557..cc4fbbaed 100644
--- a/test/e2e/integration_test.go
+++ b/test/e2e/integration_test.go
@@ -175,10 +175,8 @@ func TestCheckpointRestore(t *testing.T) {
t.Fatal(err)
}
- // TODO(b/143498576): Remove after github.com/moby/moby/issues/38963 is fixed.
- time.Sleep(1 * time.Second)
-
- if err := d.Restore("test"); err != nil {
+ // TODO(b/143498576): Remove Poll after github.com/moby/moby/issues/38963 is fixed.
+ if err := testutil.Poll(func() error { return d.Restore("test") }, 15*time.Second); err != nil {
t.Fatal("docker restore failed:", err)
}
diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go
index b2fb6401a..4ccd4cce7 100644
--- a/test/iptables/filter_input.go
+++ b/test/iptables/filter_input.go
@@ -47,6 +47,8 @@ func init() {
RegisterTestCase(FilterInputJumpReturnDrop{})
RegisterTestCase(FilterInputJumpBuiltin{})
RegisterTestCase(FilterInputJumpTwice{})
+ RegisterTestCase(FilterInputDestination{})
+ RegisterTestCase(FilterInputInvertDestination{})
}
// FilterInputDropUDP tests that we can drop UDP traffic.
@@ -106,7 +108,7 @@ func (FilterInputDropOnlyUDP) ContainerAction(ip net.IP) error {
func (FilterInputDropOnlyUDP) LocalAction(ip net.IP) error {
// Try to establish a TCP connection with the container, which should
// succeed.
- return connectTCP(ip, acceptPort, dropPort, sendloopDuration)
+ return connectTCP(ip, acceptPort, sendloopDuration)
}
// FilterInputDropUDPPort tests that we can drop UDP traffic by port.
@@ -192,8 +194,14 @@ func (FilterInputDropTCPDestPort) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputDropTCPDestPort) LocalAction(ip net.IP) error {
- if err := connectTCP(ip, dropPort, acceptPort, sendloopDuration); err == nil {
- return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort)
+ // After the container sets its DROP rule, we shouldn't be able to connect.
+ // However, we may succeed in connecting if this runs before the container
+ // sets the rule. To avoid this race, we retry connecting until
+ // sendloopDuration has elapsed, ignoring whether the connect succeeds. The
+ // test works becuase the container will error if a connection is
+ // established after the rule is set.
+ for start := time.Now(); time.Since(start) < sendloopDuration; {
+ connectTCP(ip, dropPort, sendloopDuration-time.Since(start))
}
return nil
@@ -209,13 +217,14 @@ func (FilterInputDropTCPSrcPort) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (FilterInputDropTCPSrcPort) ContainerAction(ip net.IP) error {
- if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--sport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
+ // Drop anything from an ephemeral port.
+ if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--sport", "1024:65535", "-j", "DROP"); err != nil {
return err
}
// Listen for TCP packets on accept port.
if err := listenTCP(acceptPort, sendloopDuration); err == nil {
- return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort)
+ return fmt.Errorf("connection destined to port %d should not be accepted, but was", dropPort)
}
return nil
@@ -223,8 +232,14 @@ func (FilterInputDropTCPSrcPort) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputDropTCPSrcPort) LocalAction(ip net.IP) error {
- if err := connectTCP(ip, acceptPort, dropPort, sendloopDuration); err == nil {
- return fmt.Errorf("connection on port %d should not be acceptedi, but got accepted", dropPort)
+ // After the container sets its DROP rule, we shouldn't be able to connect.
+ // However, we may succeed in connecting if this runs before the container
+ // sets the rule. To avoid this race, we retry connecting until
+ // sendloopDuration has elapsed, ignoring whether the connect succeeds. The
+ // test works becuase the container will error if a connection is
+ // established after the rule is set.
+ for start := time.Now(); time.Since(start) < sendloopDuration; {
+ connectTCP(ip, acceptPort, sendloopDuration-time.Since(start))
}
return nil
@@ -595,3 +610,66 @@ func (FilterInputJumpTwice) ContainerAction(ip net.IP) error {
func (FilterInputJumpTwice) LocalAction(ip net.IP) error {
return sendUDPLoop(ip, acceptPort, sendloopDuration)
}
+
+// FilterInputDestination verifies that we can filter packets via `-d
+// <ipaddr>`.
+type FilterInputDestination struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputDestination) Name() string {
+ return "FilterInputDestination"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDestination) ContainerAction(ip net.IP) error {
+ addrs, err := localAddrs()
+ if err != nil {
+ return err
+ }
+
+ // Make INPUT's default action DROP, then ACCEPT all packets bound for
+ // this machine.
+ rules := [][]string{{"-P", "INPUT", "DROP"}}
+ for _, addr := range addrs {
+ rules = append(rules, []string{"-A", "INPUT", "-d", addr, "-j", "ACCEPT"})
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDestination) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputInvertDestination verifies that we can filter packets via `! -d
+// <ipaddr>`.
+type FilterInputInvertDestination struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputInvertDestination) Name() string {
+ return "FilterInputInvertDestination"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputInvertDestination) ContainerAction(ip net.IP) error {
+ // Make INPUT's default action DROP, then ACCEPT all packets not bound
+ // for 127.0.0.1.
+ rules := [][]string{
+ {"-P", "INPUT", "DROP"},
+ {"-A", "INPUT", "!", "-d", localIP, "-j", "ACCEPT"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputInvertDestination) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go
index ee2c49f9a..4582d514c 100644
--- a/test/iptables/filter_output.go
+++ b/test/iptables/filter_output.go
@@ -22,9 +22,12 @@ import (
func init() {
RegisterTestCase(FilterOutputDropTCPDestPort{})
RegisterTestCase(FilterOutputDropTCPSrcPort{})
+ RegisterTestCase(FilterOutputDestination{})
+ RegisterTestCase(FilterOutputInvertDestination{})
}
-// FilterOutputDropTCPDestPort tests that connections are not accepted on specified source ports.
+// FilterOutputDropTCPDestPort tests that connections are not accepted on
+// specified source ports.
type FilterOutputDropTCPDestPort struct{}
// Name implements TestCase.Name.
@@ -48,14 +51,15 @@ func (FilterOutputDropTCPDestPort) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterOutputDropTCPDestPort) LocalAction(ip net.IP) error {
- if err := connectTCP(ip, acceptPort, dropPort, sendloopDuration); err == nil {
+ if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil {
return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort)
}
return nil
}
-// FilterOutputDropTCPSrcPort tests that connections are not accepted on specified source ports.
+// FilterOutputDropTCPSrcPort tests that connections are not accepted on
+// specified source ports.
type FilterOutputDropTCPSrcPort struct{}
// Name implements TestCase.Name.
@@ -79,9 +83,63 @@ func (FilterOutputDropTCPSrcPort) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterOutputDropTCPSrcPort) LocalAction(ip net.IP) error {
- if err := connectTCP(ip, dropPort, acceptPort, sendloopDuration); err == nil {
+ if err := connectTCP(ip, dropPort, sendloopDuration); err == nil {
return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort)
}
return nil
}
+
+// FilterOutputDestination tests that we can selectively allow packets to
+// certain destinations.
+type FilterOutputDestination struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputDestination) Name() string {
+ return "FilterOutputDestination"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDestination) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"},
+ {"-P", "OUTPUT", "DROP"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDestination) LocalAction(ip net.IP) error {
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// FilterOutputInvertDestination tests that we can selectively allow packets
+// not headed for a particular destination.
+type FilterOutputInvertDestination struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputInvertDestination) Name() string {
+ return "FilterOutputInvertDestination"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInvertDestination) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-A", "OUTPUT", "!", "-d", localIP, "-j", "ACCEPT"},
+ {"-P", "OUTPUT", "DROP"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInvertDestination) LocalAction(ip net.IP) error {
+ return listenUDP(acceptPort, sendloopDuration)
+}
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
index 0621861eb..7f1f70606 100644
--- a/test/iptables/iptables_test.go
+++ b/test/iptables/iptables_test.go
@@ -191,17 +191,37 @@ func TestFilterInputDropOnlyUDP(t *testing.T) {
}
func TestNATRedirectUDPPort(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
if err := singleTest(NATRedirectUDPPort{}); err != nil {
t.Fatal(err)
}
}
+func TestNATRedirectTCPPort(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
+ if err := singleTest(NATRedirectTCPPort{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
func TestNATDropUDP(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
if err := singleTest(NATDropUDP{}); err != nil {
t.Fatal(err)
}
}
+func TestNATAcceptAll(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
+ if err := singleTest(NATAcceptAll{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
func TestFilterInputDropTCPDestPort(t *testing.T) {
if err := singleTest(FilterInputDropTCPDestPort{}); err != nil {
t.Fatal(err)
@@ -239,12 +259,16 @@ func TestFilterInputReturnUnderflow(t *testing.T) {
}
func TestFilterOutputDropTCPDestPort(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("filter OUTPUT isn't supported yet (gvisor.dev/issue/170).")
if err := singleTest(FilterOutputDropTCPDestPort{}); err != nil {
t.Fatal(err)
}
}
func TestFilterOutputDropTCPSrcPort(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("filter OUTPUT isn't supported yet (gvisor.dev/issue/170).")
if err := singleTest(FilterOutputDropTCPSrcPort{}); err != nil {
t.Fatal(err)
}
@@ -285,3 +309,83 @@ func TestJumpTwice(t *testing.T) {
t.Fatal(err)
}
}
+
+func TestInputDestination(t *testing.T) {
+ if err := singleTest(FilterInputDestination{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestInputInvertDestination(t *testing.T) {
+ if err := singleTest(FilterInputInvertDestination{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestOutputDestination(t *testing.T) {
+ if err := singleTest(FilterOutputDestination{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestOutputInvertDestination(t *testing.T) {
+ if err := singleTest(FilterOutputInvertDestination{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestNATOutRedirectIP(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
+ if err := singleTest(NATOutRedirectIP{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestNATOutDontRedirectIP(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
+ if err := singleTest(NATOutDontRedirectIP{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestNATOutRedirectInvert(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
+ if err := singleTest(NATOutRedirectInvert{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestNATPreRedirectIP(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
+ if err := singleTest(NATPreRedirectIP{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestNATPreDontRedirectIP(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
+ if err := singleTest(NATPreDontRedirectIP{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestNATPreRedirectInvert(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
+ if err := singleTest(NATPreRedirectInvert{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestNATRedirectRequiresProtocol(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
+ if err := singleTest(NATRedirectRequiresProtocol{}); err != nil {
+ t.Fatal(err)
+ }
+}
diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go
index 32cf5a417..134391e8d 100644
--- a/test/iptables/iptables_util.go
+++ b/test/iptables/iptables_util.go
@@ -24,6 +24,7 @@ import (
)
const iptablesBinary = "iptables"
+const localIP = "127.0.0.1"
// filterTable calls `iptables -t filter` with the given args.
func filterTable(args ...string) error {
@@ -46,8 +47,17 @@ func tableCmd(table string, args []string) error {
// filterTableRules is like filterTable, but runs multiple iptables commands.
func filterTableRules(argsList [][]string) error {
+ return tableRules("filter", argsList)
+}
+
+// natTableRules is like natTable, but runs multiple iptables commands.
+func natTableRules(argsList [][]string) error {
+ return tableRules("nat", argsList)
+}
+
+func tableRules(table string, argsList [][]string) error {
for _, args := range argsList {
- if err := filterTable(args...); err != nil {
+ if err := tableCmd(table, args); err != nil {
return err
}
}
@@ -125,27 +135,37 @@ func listenTCP(port int, timeout time.Duration) error {
return nil
}
-// connectTCP connects the TCP server over specified local port, server IP and remote/server port.
-func connectTCP(ip net.IP, remotePort, localPort int, timeout time.Duration) error {
+// connectTCP connects to the given IP and port from an ephemeral local address.
+func connectTCP(ip net.IP, port int, timeout time.Duration) error {
contAddr := net.TCPAddr{
IP: ip,
- Port: remotePort,
+ Port: port,
}
// The container may not be listening when we first connect, so retry
// upon error.
callback := func() error {
- localAddr := net.TCPAddr{
- Port: localPort,
- }
- conn, err := net.DialTCP("tcp4", &localAddr, &contAddr)
+ conn, err := net.DialTimeout("tcp", contAddr.String(), timeout)
if conn != nil {
conn.Close()
}
return err
}
if err := testutil.Poll(callback, timeout); err != nil {
- return fmt.Errorf("timed out waiting to send IP, most recent error: %v", err)
+ return fmt.Errorf("timed out waiting to connect IP, most recent error: %v", err)
}
return nil
}
+
+// localAddrs returns a list of local network interface addresses.
+func localAddrs() ([]string, error) {
+ addrs, err := net.InterfaceAddrs()
+ if err != nil {
+ return nil, err
+ }
+ addrStrs := make([]string, 0, len(addrs))
+ for _, addr := range addrs {
+ addrStrs = append(addrStrs, addr.String())
+ }
+ return addrStrs, nil
+}
diff --git a/test/iptables/nat.go b/test/iptables/nat.go
index a01117ec8..40096901c 100644
--- a/test/iptables/nat.go
+++ b/test/iptables/nat.go
@@ -15,8 +15,10 @@
package iptables
import (
+ "errors"
"fmt"
"net"
+ "time"
)
const (
@@ -25,7 +27,16 @@ const (
func init() {
RegisterTestCase(NATRedirectUDPPort{})
+ RegisterTestCase(NATRedirectTCPPort{})
RegisterTestCase(NATDropUDP{})
+ RegisterTestCase(NATAcceptAll{})
+ RegisterTestCase(NATPreRedirectIP{})
+ RegisterTestCase(NATPreDontRedirectIP{})
+ RegisterTestCase(NATPreRedirectInvert{})
+ RegisterTestCase(NATOutRedirectIP{})
+ RegisterTestCase(NATOutDontRedirectIP{})
+ RegisterTestCase(NATOutRedirectInvert{})
+ RegisterTestCase(NATRedirectRequiresProtocol{})
}
// NATRedirectUDPPort tests that packets are redirected to different port.
@@ -45,6 +56,7 @@ func (NATRedirectUDPPort) ContainerAction(ip net.IP) error {
if err := listenUDP(redirectPort, sendloopDuration); err != nil {
return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", redirectPort, err)
}
+
return nil
}
@@ -53,7 +65,31 @@ func (NATRedirectUDPPort) LocalAction(ip net.IP) error {
return sendUDPLoop(ip, acceptPort, sendloopDuration)
}
-// NATDropUDP tests that packets are not received in ports other than redirect port.
+// NATRedirectTCPPort tests that connections are redirected on specified ports.
+type NATRedirectTCPPort struct{}
+
+// Name implements TestCase.Name.
+func (NATRedirectTCPPort) Name() string {
+ return "NATRedirectTCPPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATRedirectTCPPort) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on redirect port.
+ return listenTCP(redirectPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATRedirectTCPPort) LocalAction(ip net.IP) error {
+ return connectTCP(ip, dropPort, sendloopDuration)
+}
+
+// NATDropUDP tests that packets are not received in ports other than redirect
+// port.
type NATDropUDP struct{}
// Name implements TestCase.Name.
@@ -78,3 +114,218 @@ func (NATDropUDP) ContainerAction(ip net.IP) error {
func (NATDropUDP) LocalAction(ip net.IP) error {
return sendUDPLoop(ip, acceptPort, sendloopDuration)
}
+
+// NATAcceptAll tests that all UDP packets are accepted.
+type NATAcceptAll struct{}
+
+// Name implements TestCase.Name.
+func (NATAcceptAll) Name() string {
+ return "NATAcceptAll"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATAcceptAll) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ if err := listenUDP(acceptPort, sendloopDuration); err != nil {
+ return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATAcceptAll) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// NATOutRedirectIP uses iptables to select packets based on destination IP and
+// redirects them.
+type NATOutRedirectIP struct{}
+
+// Name implements TestCase.Name.
+func (NATOutRedirectIP) Name() string {
+ return "NATOutRedirectIP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutRedirectIP) ContainerAction(ip net.IP) error {
+ // Redirect OUTPUT packets to a listening localhost port.
+ dest := net.IP([]byte{200, 0, 0, 2})
+ return loopbackTest(dest, "-A", "OUTPUT", "-d", dest.String(), "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort))
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutRedirectIP) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// NATOutDontRedirectIP tests that iptables matching with "-d" does not match
+// packets it shouldn't.
+type NATOutDontRedirectIP struct{}
+
+// Name implements TestCase.Name.
+func (NATOutDontRedirectIP) Name() string {
+ return "NATOutDontRedirectIP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutDontRedirectIP) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "OUTPUT", "-d", localIP, "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil {
+ return err
+ }
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutDontRedirectIP) LocalAction(ip net.IP) error {
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// NATOutRedirectInvert tests that iptables can match with "! -d".
+type NATOutRedirectInvert struct{}
+
+// Name implements TestCase.Name.
+func (NATOutRedirectInvert) Name() string {
+ return "NATOutRedirectInvert"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutRedirectInvert) ContainerAction(ip net.IP) error {
+ // Redirect OUTPUT packets to a listening localhost port.
+ dest := []byte{200, 0, 0, 3}
+ destStr := "200.0.0.2"
+ return loopbackTest(dest, "-A", "OUTPUT", "!", "-d", destStr, "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort))
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutRedirectInvert) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// NATPreRedirectIP tests that we can use iptables to select packets based on
+// destination IP and redirect them.
+type NATPreRedirectIP struct{}
+
+// Name implements TestCase.Name.
+func (NATPreRedirectIP) Name() string {
+ return "NATPreRedirectIP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreRedirectIP) ContainerAction(ip net.IP) error {
+ addrs, err := localAddrs()
+ if err != nil {
+ return err
+ }
+
+ var rules [][]string
+ for _, addr := range addrs {
+ rules = append(rules, []string{"-A", "PREROUTING", "-p", "udp", "-d", addr, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)})
+ }
+ if err := natTableRules(rules); err != nil {
+ return err
+ }
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreRedirectIP) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, dropPort, sendloopDuration)
+}
+
+// NATPreDontRedirectIP tests that iptables matching with "-d" does not match
+// packets it shouldn't.
+type NATPreDontRedirectIP struct{}
+
+// Name implements TestCase.Name.
+func (NATPreDontRedirectIP) Name() string {
+ return "NATPreDontRedirectIP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreDontRedirectIP) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "PREROUTING", "-p", "udp", "-d", localIP, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil {
+ return err
+ }
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreDontRedirectIP) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// NATPreRedirectInvert tests that iptables can match with "! -d".
+type NATPreRedirectInvert struct{}
+
+// Name implements TestCase.Name.
+func (NATPreRedirectInvert) Name() string {
+ return "NATPreRedirectInvert"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreRedirectInvert) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "PREROUTING", "-p", "udp", "!", "-d", localIP, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil {
+ return err
+ }
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreRedirectInvert) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, dropPort, sendloopDuration)
+}
+
+// NATRedirectRequiresProtocol tests that use of the --to-ports flag requires a
+// protocol to be specified with -p.
+type NATRedirectRequiresProtocol struct{}
+
+// Name implements TestCase.Name.
+func (NATRedirectRequiresProtocol) Name() string {
+ return "NATRedirectRequiresProtocol"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATRedirectRequiresProtocol) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "PREROUTING", "-d", localIP, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err == nil {
+ return errors.New("expected an error using REDIRECT --to-ports without a protocol")
+ }
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATRedirectRequiresProtocol) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// loopbackTests runs an iptables rule and ensures that packets sent to
+// dest:dropPort are received by localhost:acceptPort.
+func loopbackTest(dest net.IP, args ...string) error {
+ if err := natTable(args...); err != nil {
+ return err
+ }
+ sendCh := make(chan error)
+ listenCh := make(chan error)
+ go func() {
+ sendCh <- sendUDPLoop(dest, dropPort, sendloopDuration)
+ }()
+ go func() {
+ listenCh <- listenUDP(acceptPort, sendloopDuration)
+ }()
+ select {
+ case err := <-listenCh:
+ if err != nil {
+ return err
+ }
+ case <-time.After(sendloopDuration):
+ return errors.New("timed out")
+ }
+ // sendCh will always take the full sendloop time.
+ return <-sendCh
+}
diff --git a/test/packetdrill/BUILD b/test/packetdrill/BUILD
index d113555b1..fb0b2db41 100644
--- a/test/packetdrill/BUILD
+++ b/test/packetdrill/BUILD
@@ -1,8 +1,48 @@
-load("defs.bzl", "packetdrill_test")
+load("defs.bzl", "packetdrill_linux_test", "packetdrill_netstack_test", "packetdrill_test")
package(licenses = ["notice"])
packetdrill_test(
- name = "fin_wait2_timeout",
+ name = "packetdrill_sanity_test",
+ scripts = ["sanity_test.pkt"],
+)
+
+packetdrill_test(
+ name = "accept_ack_drop_test",
+ scripts = ["accept_ack_drop.pkt"],
+)
+
+packetdrill_test(
+ name = "fin_wait2_timeout_test",
scripts = ["fin_wait2_timeout.pkt"],
)
+
+packetdrill_linux_test(
+ name = "tcp_user_timeout_test_linux_test",
+ scripts = ["linux/tcp_user_timeout.pkt"],
+)
+
+packetdrill_netstack_test(
+ name = "tcp_user_timeout_test_netstack_test",
+ scripts = ["netstack/tcp_user_timeout.pkt"],
+)
+
+packetdrill_test(
+ name = "listen_close_before_handshake_complete_test",
+ scripts = ["listen_close_before_handshake_complete.pkt"],
+)
+
+packetdrill_test(
+ name = "no_rst_to_rst_test",
+ scripts = ["no_rst_to_rst.pkt"],
+)
+
+packetdrill_test(
+ name = "tcp_defer_accept_test",
+ scripts = ["tcp_defer_accept.pkt"],
+)
+
+packetdrill_test(
+ name = "tcp_defer_accept_timeout_test",
+ scripts = ["tcp_defer_accept_timeout.pkt"],
+)
diff --git a/test/packetdrill/Dockerfile b/test/packetdrill/Dockerfile
index bd4451355..4b75e9527 100644
--- a/test/packetdrill/Dockerfile
+++ b/test/packetdrill/Dockerfile
@@ -1,9 +1,9 @@
FROM ubuntu:bionic
-RUN apt-get update
-RUN apt-get install -y net-tools git iptables iputils-ping netcat tcpdump jq tar
+RUN apt-get update && apt-get install -y net-tools git iptables iputils-ping \
+ netcat tcpdump jq tar bison flex make
RUN hash -r
RUN git clone --branch packetdrill-v2.0 \
https://github.com/google/packetdrill.git
-RUN cd packetdrill/gtests/net/packetdrill && ./configure && \
- apt-get install -y bison flex make && make
+RUN cd packetdrill/gtests/net/packetdrill && ./configure && make
+CMD /bin/bash
diff --git a/test/packetdrill/accept_ack_drop.pkt b/test/packetdrill/accept_ack_drop.pkt
new file mode 100644
index 000000000..76e638fd4
--- /dev/null
+++ b/test/packetdrill/accept_ack_drop.pkt
@@ -0,0 +1,27 @@
+// Test that the accept works if the final ACK is dropped and an ack with data
+// follows the dropped ack.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
+// Set backlog to 1 so that we can easily test.
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0.0 > S. 0:0(0) ack 1 <...>
+
++0.0 < . 1:5(4) ack 1 win 257
++0.0 > . 1:1(0) ack 5 <...>
+
+// This should cause connection to transition to connected state.
++0.000 accept(3, ..., ...) = 4
++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0
+
+// Now read the data and we should get 4 bytes.
++0.000 read(4,..., 4) = 4
++0.000 close(4) = 0
+
++0.0 > F. 1:1(0) ack 5 <...>
++0.0 < F. 5:5(0) ack 2 win 257
++0.01 > . 2:2(0) ack 6 <...> \ No newline at end of file
diff --git a/test/packetdrill/defs.bzl b/test/packetdrill/defs.bzl
index 8623ce7b1..f499c177b 100644
--- a/test/packetdrill/defs.bzl
+++ b/test/packetdrill/defs.bzl
@@ -66,7 +66,7 @@ def packetdrill_linux_test(name, **kwargs):
if "tags" not in kwargs:
kwargs["tags"] = _PACKETDRILL_TAGS
_packetdrill_test(
- name = name + "_linux_test",
+ name = name,
flags = ["--dut_platform", "linux"],
**kwargs
)
@@ -75,13 +75,13 @@ def packetdrill_netstack_test(name, **kwargs):
if "tags" not in kwargs:
kwargs["tags"] = _PACKETDRILL_TAGS
_packetdrill_test(
- name = name + "_netstack_test",
+ name = name,
# This is the default runtime unless
# "--test_arg=--runtime=OTHER_RUNTIME" is used to override the value.
flags = ["--dut_platform", "netstack", "--runtime", "runsc-d"],
**kwargs
)
-def packetdrill_test(**kwargs):
- packetdrill_linux_test(**kwargs)
- packetdrill_netstack_test(**kwargs)
+def packetdrill_test(name, **kwargs):
+ packetdrill_linux_test(name + "_linux_test", **kwargs)
+ packetdrill_netstack_test(name + "_netstack_test", **kwargs)
diff --git a/test/packetdrill/fin_wait2_timeout.pkt b/test/packetdrill/fin_wait2_timeout.pkt
index 613f0bec9..93ab08575 100644
--- a/test/packetdrill/fin_wait2_timeout.pkt
+++ b/test/packetdrill/fin_wait2_timeout.pkt
@@ -19,5 +19,5 @@
+0 > F. 1:1(0) ack 1 <...>
+0 < . 1:1(0) ack 2 win 257
-+1.1 < . 1:1(0) ack 2 win 257
++2 < . 1:1(0) ack 2 win 257
+0 > R 2:2(0) win 0
diff --git a/test/packetdrill/linux/tcp_user_timeout.pkt b/test/packetdrill/linux/tcp_user_timeout.pkt
new file mode 100644
index 000000000..38018cb42
--- /dev/null
+++ b/test/packetdrill/linux/tcp_user_timeout.pkt
@@ -0,0 +1,39 @@
+// Test that a socket w/ TCP_USER_TIMEOUT set aborts the connection
+// if there is pending unacked data after the user specified timeout.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0 > S. 0:0(0) ack 1 <...>
++0.1 < . 1:1(0) ack 1 win 32792
+
++0.100 accept(3, ..., ...) = 4
+
+// Okay, we received nothing, and decide to close this idle socket.
+// We set TCP_USER_TIMEOUT to 3 seconds because really it is not worth
+// trying hard to cleanly close this flow, at the price of keeping
+// a TCP structure in kernel for about 1 minute!
++2 setsockopt(4, SOL_TCP, TCP_USER_TIMEOUT, [3000], 4) = 0
+
+// The write/ack is required mainly for netstack as netstack does
+// not update its RTO during the handshake.
++0 write(4, ..., 100) = 100
++0 > P. 1:101(100) ack 1 <...>
++0 < . 1:1(0) ack 101 win 32792
+
++0 close(4) = 0
+
++0 > F. 101:101(0) ack 1 <...>
++.3~+.400 > F. 101:101(0) ack 1 <...>
++.3~+.400 > F. 101:101(0) ack 1 <...>
++.6~+.800 > F. 101:101(0) ack 1 <...>
++1.2~+1.300 > F. 101:101(0) ack 1 <...>
+
+// We finally receive something from the peer, but it is way too late
+// Our socket vanished because TCP_USER_TIMEOUT was really small.
++.1 < . 1:2(1) ack 102 win 32792
++0 > R 102:102(0) win 0
diff --git a/test/packetdrill/listen_close_before_handshake_complete.pkt b/test/packetdrill/listen_close_before_handshake_complete.pkt
new file mode 100644
index 000000000..51c3f1a32
--- /dev/null
+++ b/test/packetdrill/listen_close_before_handshake_complete.pkt
@@ -0,0 +1,31 @@
+// Test that closing a listening socket closes any connections in SYN-RCVD
+// state and any packets bound for these connections generate a RESET.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
+// Set backlog to 1 so that we can easily test.
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0 > S. 0:0(0) ack 1 <...>
+
++0.100 close(3) = 0
++0.1 < P. 1:1(0) ack 1 win 257
+
+// Linux generates a reset with no ack number/bit set. This is contradictory to
+// what is specified in Rule 1 under Reset Generation in
+// https://tools.ietf.org/html/rfc793#section-3.4.
+// "1. If the connection does not exist (CLOSED) then a reset is sent
+// in response to any incoming segment except another reset. In
+// particular, SYNs addressed to a non-existent connection are rejected
+// by this means.
+//
+// If the incoming segment has an ACK field, the reset takes its
+// sequence number from the ACK field of the segment, otherwise the
+// reset has sequence number zero and the ACK field is set to the sum
+// of the sequence number and segment length of the incoming segment.
+// The connection remains in the CLOSED state."
+
++0.0 > R 1:1(0) win 0 \ No newline at end of file
diff --git a/test/packetdrill/netstack/tcp_user_timeout.pkt b/test/packetdrill/netstack/tcp_user_timeout.pkt
new file mode 100644
index 000000000..60103adba
--- /dev/null
+++ b/test/packetdrill/netstack/tcp_user_timeout.pkt
@@ -0,0 +1,38 @@
+// Test that a socket w/ TCP_USER_TIMEOUT set aborts the connection
+// if there is pending unacked data after the user specified timeout.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0 > S. 0:0(0) ack 1 <...>
++0.1 < . 1:1(0) ack 1 win 32792
+
++0.100 accept(3, ..., ...) = 4
+
+// Okay, we received nothing, and decide to close this idle socket.
+// We set TCP_USER_TIMEOUT to 3 seconds because really it is not worth
+// trying hard to cleanly close this flow, at the price of keeping
+// a TCP structure in kernel for about 1 minute!
++2 setsockopt(4, SOL_TCP, TCP_USER_TIMEOUT, [3000], 4) = 0
+
+// The write/ack is required mainly for netstack as netstack does
+// not update its RTO during the handshake.
++0 write(4, ..., 100) = 100
++0 > P. 1:101(100) ack 1 <...>
++0 < . 1:1(0) ack 101 win 32792
+
++0 close(4) = 0
+
++0 > F. 101:101(0) ack 1 <...>
++.2~+.300 > F. 101:101(0) ack 1 <...>
++.4~+.500 > F. 101:101(0) ack 1 <...>
++.8~+.900 > F. 101:101(0) ack 1 <...>
+
+// We finally receive something from the peer, but it is way too late
+// Our socket vanished because TCP_USER_TIMEOUT was really small.
++1.61 < . 1:2(1) ack 102 win 32792
++0 > R 102:102(0) win 0
diff --git a/test/packetdrill/no_rst_to_rst.pkt b/test/packetdrill/no_rst_to_rst.pkt
new file mode 100644
index 000000000..612747827
--- /dev/null
+++ b/test/packetdrill/no_rst_to_rst.pkt
@@ -0,0 +1,36 @@
+// Test a RST is not generated in response to a RST and a RST is correctly
+// generated when an accepted endpoint is RST due to an incoming RST.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0 > S. 0:0(0) ack 1 <...>
++0 < P. 1:1(0) ack 1 win 257
+
++0.100 accept(3, ..., ...) = 4
+
++0.200 < R 1:1(0) win 0
+
++0.300 read(4,..., 4) = -1 ECONNRESET (Connection Reset by Peer)
+
++0.00 < . 1:1(0) ack 1 win 257
+
+// Linux generates a reset with no ack number/bit set. This is contradictory to
+// what is specified in Rule 1 under Reset Generation in
+// https://tools.ietf.org/html/rfc793#section-3.4.
+// "1. If the connection does not exist (CLOSED) then a reset is sent
+// in response to any incoming segment except another reset. In
+// particular, SYNs addressed to a non-existent connection are rejected
+// by this means.
+//
+// If the incoming segment has an ACK field, the reset takes its
+// sequence number from the ACK field of the segment, otherwise the
+// reset has sequence number zero and the ACK field is set to the sum
+// of the sequence number and segment length of the incoming segment.
+// The connection remains in the CLOSED state."
+
++0.00 > R 1:1(0) win 0 \ No newline at end of file
diff --git a/test/packetdrill/packetdrill_test.sh b/test/packetdrill/packetdrill_test.sh
index 0b22dfd5c..c8268170f 100755
--- a/test/packetdrill/packetdrill_test.sh
+++ b/test/packetdrill/packetdrill_test.sh
@@ -91,8 +91,8 @@ fi
# Variables specific to the test runner start with TEST_RUNNER_.
declare -r PACKETDRILL="/packetdrill/gtests/net/packetdrill/packetdrill"
# Use random numbers so that test networks don't collide.
-declare -r CTRL_NET="ctrl_net-${RANDOM}${RANDOM}"
-declare -r TEST_NET="test_net-${RANDOM}${RANDOM}"
+declare -r CTRL_NET="ctrl_net-$(shuf -i 0-99999999 -n 1)"
+declare -r TEST_NET="test_net-$(shuf -i 0-99999999 -n 1)"
declare -r tolerance_usecs=100000
# On both DUT and test runner, testing packets are on the eth2 interface.
declare -r TEST_DEVICE="eth2"
diff --git a/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt b/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt
new file mode 100644
index 000000000..a86b90ce6
--- /dev/null
+++ b/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt
@@ -0,0 +1,9 @@
+// Test that a listening socket generates a RST when it receives an
+// ACK and syn cookies are not in use.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
++0 listen(3, 1) = 0
++0.1 < . 1:1(0) ack 1 win 32792
++0 > R 1:1(0) ack 0 win 0 \ No newline at end of file
diff --git a/test/packetdrill/sanity_test.pkt b/test/packetdrill/sanity_test.pkt
new file mode 100644
index 000000000..b3b58c366
--- /dev/null
+++ b/test/packetdrill/sanity_test.pkt
@@ -0,0 +1,7 @@
+// Basic sanity test. One system call.
+//
+// All of the plumbing has to be working however, and the packetdrill wire
+// client needs to be able to connect to the wire server and send the script,
+// probe local interfaces, run through the test w/ timings, etc.
+
+0.000 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
diff --git a/test/packetdrill/tcp_defer_accept.pkt b/test/packetdrill/tcp_defer_accept.pkt
new file mode 100644
index 000000000..a17f946db
--- /dev/null
+++ b/test/packetdrill/tcp_defer_accept.pkt
@@ -0,0 +1,48 @@
+// Test that a bare ACK does not complete a connection when TCP_DEFER_ACCEPT
+// timeout is not hit but an ACK w/ data does complete and deliver the
+// connection to the accept queue.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 setsockopt(3, SOL_TCP, TCP_DEFER_ACCEPT, [5], 4) = 0
++0.000 fcntl(3, F_SETFL, O_RDWR|O_NONBLOCK) = 0
++0 bind(3, ..., ...) = 0
+
+// Set backlog to 1 so that we can easily test.
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0.0 > S. 0:0(0) ack 1 <...>
+
+// Send a bare ACK this should not complete the connection as we
+// set the TCP_DEFER_ACCEPT above.
++0.0 < . 1:1(0) ack 1 win 257
+
+// The bare ACK should be dropped and no connection should be delivered
+// to the accept queue.
++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block)
+
+// Send another bare ACK and it should still fail we set TCP_DEFER_ACCEPT
+// to 5 seconds above.
++2.5 < . 1:1(0) ack 1 win 257
++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block)
+
+// set accept socket back to blocking.
++0.000 fcntl(3, F_SETFL, O_RDWR) = 0
+
+// Now send an ACK w/ data. This should complete the connection
+// and deliver the socket to the accept queue.
++0.1 < . 1:5(4) ack 1 win 257
++0.0 > . 1:1(0) ack 5 <...>
+
+// This should cause connection to transition to connected state.
++0.000 accept(3, ..., ...) = 4
++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0
+
+// Now read the data and we should get 4 bytes.
++0.000 read(4,..., 4) = 4
++0.000 close(4) = 0
+
++0.0 > F. 1:1(0) ack 5 <...>
++0.0 < F. 5:5(0) ack 2 win 257
++0.01 > . 2:2(0) ack 6 <...> \ No newline at end of file
diff --git a/test/packetdrill/tcp_defer_accept_timeout.pkt b/test/packetdrill/tcp_defer_accept_timeout.pkt
new file mode 100644
index 000000000..201fdeb14
--- /dev/null
+++ b/test/packetdrill/tcp_defer_accept_timeout.pkt
@@ -0,0 +1,48 @@
+// Test that a bare ACK is accepted after TCP_DEFER_ACCEPT timeout
+// is hit and a connection is delivered.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 setsockopt(3, SOL_TCP, TCP_DEFER_ACCEPT, [3], 4) = 0
++0.000 fcntl(3, F_SETFL, O_RDWR|O_NONBLOCK) = 0
++0 bind(3, ..., ...) = 0
+
+// Set backlog to 1 so that we can easily test.
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0.0 > S. 0:0(0) ack 1 <...>
+
+// Send a bare ACK this should not complete the connection as we
+// set the TCP_DEFER_ACCEPT above.
++0.0 < . 1:1(0) ack 1 win 257
+
+// The bare ACK should be dropped and no connection should be delivered
+// to the accept queue.
++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block)
+
+// Send another bare ACK and it should still fail we set TCP_DEFER_ACCEPT
+// to 5 seconds above.
++2.5 < . 1:1(0) ack 1 win 257
++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block)
+
+// set accept socket back to blocking.
++0.000 fcntl(3, F_SETFL, O_RDWR) = 0
+
+// We should see one more retransmit of the SYN-ACK as a last ditch
+// attempt when TCP_DEFER_ACCEPT timeout is hit to trigger another
+// ACK or a packet with data.
++.35~+2.35 > S. 0:0(0) ack 1 <...>
+
+// Now send another bare ACK after TCP_DEFER_ACCEPT time has been passed.
++0.0 < . 1:1(0) ack 1 win 257
+
+// The ACK above should cause connection to transition to connected state.
++0.000 accept(3, ..., ...) = 4
++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0
+
++0.000 close(4) = 0
+
++0.0 > F. 1:1(0) ack 1 <...>
++0.0 < F. 1:1(0) ack 2 win 257
++0.01 > . 2:2(0) ack 2 <...>
diff --git a/test/packetimpact/README.md b/test/packetimpact/README.md
new file mode 100644
index 000000000..ece4dedc6
--- /dev/null
+++ b/test/packetimpact/README.md
@@ -0,0 +1,531 @@
+# Packetimpact
+
+## What is packetimpact?
+
+Packetimpact is a tool for platform-independent network testing. It is heavily
+inspired by [packetdrill](https://github.com/google/packetdrill). It creates two
+docker containers connected by a network. One is for the test bench, which
+operates the test. The other is for the device-under-test (DUT), which is the
+software being tested. The test bench communicates over the network with the DUT
+to check correctness of the network.
+
+### Goals
+
+Packetimpact aims to provide:
+
+* A **multi-platform** solution that can test both Linux and gVisor.
+* **Conciseness** on par with packetdrill scripts.
+* **Control-flow** like for loops, conditionals, and variables.
+* **Flexibilty** to specify every byte in a packet or use multiple sockets.
+
+## When to use packetimpact?
+
+There are a few ways to write networking tests for gVisor currently:
+
+* [Go unit tests](https://github.com/google/gvisor/tree/master/pkg/tcpip)
+* [syscall tests](https://github.com/google/gvisor/tree/master/test/syscalls/linux)
+* [packetdrill tests](https://github.com/google/gvisor/tree/master/test/packetdrill)
+* packetimpact tests
+
+The right choice depends on the needs of the test.
+
+Feature | Go unit test | syscall test | packetdrill | packetimpact
+------------- | ------------ | ------------ | ----------- | ------------
+Multiplatform | no | **YES** | **YES** | **YES**
+Concise | no | somewhat | somewhat | **VERY**
+Control-flow | **YES** | **YES** | no | **YES**
+Flexible | **VERY** | no | somewhat | **VERY**
+
+### Go unit tests
+
+If the test depends on the internals of gVisor and doesn't need to run on Linux
+or other platforms for comparison purposes, a Go unit test can be appropriate.
+They can observe internals of gVisor networking. The downside is that they are
+**not concise** and **not multiplatform**. If you require insight on gVisor
+internals, this is the right choice.
+
+### Syscall tests
+
+Syscall tests are **multiplatform** but cannot examine the internals of gVisor
+networking. They are **concise**. They can use **control-flow** structures like
+conditionals, for loops, and variables. However, they are limited to only what
+the POSIX interface provides so they are **not flexible**. For example, you
+would have difficulty writing a syscall test that intentionally sends a bad IP
+checksum. Or if you did write that test with raw sockets, it would be very
+**verbose** to write a test that intentionally send wrong checksums, wrong
+protocols, wrong sequence numbers, etc.
+
+### Packetdrill tests
+
+Packetdrill tests are **multiplatform** and can run against both Linux and
+gVisor. They are **concise** and use a special packetdrill scripting language.
+They are **more flexible** than a syscall test in that they can send packets
+that a syscall test would have difficulty sending, like a packet with a
+calcuated ACK number. But they are also somewhat limimted in flexibiilty in that
+they can't do tests with multiple sockets. They have **no control-flow** ability
+like variables or conditionals. For example, it isn't possible to send a packet
+that depends on the window size of a previous packet because the packetdrill
+language can't express that. Nor could you branch based on whether or not the
+other side supports window scaling, for example.
+
+### Packetimpact tests
+
+Packetimpact tests are similar to Packetdrill tests except that they are written
+in Go instead of the packetdrill scripting language. That gives them all the
+**control-flow** abilities of Go (loops, functions, variables, etc). They are
+**multiplatform** in the same way as packetdrill tests but even more
+**flexible** because Go is more expressive than the scripting language of
+packetdrill. However, Go is **not as concise** as the packetdrill language. Many
+design decisions below are made to mitigate that.
+
+## How it works
+
+```
+ +--------------+ +--------------+
+ | | TEST NET | |
+ | | <===========> | Device |
+ | Test | | Under |
+ | Bench | | Test |
+ | | <===========> | (DUT) |
+ | | CONTROL NET | |
+ +--------------+ +--------------+
+```
+
+Two docker containers are created by a script, one for the test bench and the
+other for the device under test (DUT). The script connects the two containers
+with a control network and test network. It also does some other tasks like
+waiting until the DUT is ready before starting the test and disabling Linux
+networking that would interfere with the test bench.
+
+### DUT
+
+The DUT container runs a program called the "posix_server". The posix_server is
+written in c++ for maximum portability. It is compiled on the host. The script
+that starts the containers copies it into the DUT's container and runs it. It's
+job is to receive directions from the test bench on what actions to take. For
+this, the posix_server does three steps in a loop:
+
+1. Listen for a request from the test bench.
+2. Execute a command.
+3. Send the response back to the test bench.
+
+The requests and responses are
+[protobufs](https://developers.google.com/protocol-buffers) and the
+communication is done with [gRPC](https://grpc.io/). The commands run are
+[POSIX socket commands](https://en.wikipedia.org/wiki/Berkeley_sockets#Socket_API_functions),
+with the inputs and outputs converted into protobuf requests and responses. All
+communication is on the control network, so that the test network is unaffected
+by extra packets.
+
+For example, this is the request and response pair to call
+[`socket()`](http://man7.org/linux/man-pages/man2/socket.2.html):
+
+```protocol-buffer
+message SocketRequest {
+ int32 domain = 1;
+ int32 type = 2;
+ int32 protocol = 3;
+}
+
+message SocketResponse {
+ int32 fd = 1;
+ int32 errno_ = 2;
+}
+```
+
+##### Alternatives considered
+
+* We could have use JSON for communication instead. It would have been a
+ lighter-touch than protobuf but protobuf handles all the data type and has
+ strict typing to prevent a class of errors. The test bench could be written
+ in other languages, too.
+* Instead of mimicking the POSIX interfaces, arguments could have had a more
+ natural form, like the `bind()` getting a string IP address instead of bytes
+ in a `sockaddr_t`. However, conforming to the existing structures keeps more
+ of the complexity in Go and keeps the posix_server simpler and thus more
+ likely to compile everywhere.
+
+### Test Bench
+
+The test bench does most of the work in a test. It is a Go program that compiles
+on the host and is copied by the script into test bench's container. It is a
+regular [go unit test](https://golang.org/pkg/testing/) that imports the test
+bench framework. The test bench framwork is based on three basic utilities:
+
+* Commanding the DUT to run POSIX commands and return responses.
+* Sending raw packets to the DUT on the test network.
+* Listening for raw packets from the DUT on the test network.
+
+#### DUT commands
+
+To keep the interface to the DUT consistent and easy-to-use, each POSIX command
+supported by the posix_server is wrapped in functions with signatures similar to
+the ones in the [Go unix package](https://godoc.org/golang.org/x/sys/unix). This
+way all the details of endianess and (un)marshalling of go structs such as
+[unix.Timeval](https://godoc.org/golang.org/x/sys/unix#Timeval) is handled in
+one place. This also makes it straight-forward to convert tests that use `unix.`
+or `syscall.` calls to `dut.` calls.
+
+For example, creating a connection to the DUT and commanding it to make a socket
+looks like this:
+
+```go
+dut := testbench.NewDut(t)
+fd, err := dut.SocketWithErrno(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_IP)
+if fd < 0 {
+ t.Fatalf(...)
+}
+```
+
+Because the usual case is to fail the test when the DUT fails to create a
+socket, there is a concise version of each of the `...WithErrno` functions that
+does that:
+
+```go
+dut := testbench.NewDut(t)
+fd := dut.Socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_IP)
+```
+
+The DUT and other structs in the code store a `*testing.T` so that they can
+provide versions of functions that call `t.Fatalf(...)`. This helps keep tests
+concise.
+
+##### Alternatives considered
+
+* Instead of mimicking the `unix.` go interface, we could have invented a more
+ natural one, like using `float64` instead of `Timeval`. However, using the
+ same function signatures that `unix.` has makes it easier to convert code to
+ `dut.`. Also, using an existing interface ensures that we don't invent an
+ interface that isn't extensible. For example, if we invented a function for
+ `bind()` that didn't support IPv6 and later we had to add a second `bind6()`
+ function.
+
+#### Sending/Receiving Raw Packets
+
+The framework wraps POSIX sockets for sending and receiving raw frames. Both
+send and receive are synchronous commands.
+[SO_RCVTIMEO](http://man7.org/linux/man-pages/man7/socket.7.html) is used to set
+a timeout on the receive commands. For ease of use, these are wrapped in an
+`Injector` and a `Sniffer`. They have functions:
+
+```go
+func (s *Sniffer) Recv(timeout time.Duration) []byte {...}
+func (i *Injector) Send(b []byte) {...}
+```
+
+##### Alternatives considered
+
+* [gopacket](https://github.com/google/gopacket) pcap has raw socket support
+ but requires cgo. cgo is not guaranteed to be portable from the host to the
+ container and in practice, the container doesn't recognize binaries built on
+ the host if they use cgo.
+* Both gVisor and gopacket have the ability to read and write pcap files
+ without cgo but that is insufficient here.
+* The sniffer and injector can't share a socket because they need to be bound
+ differently.
+* Sniffing could have been done asynchronously with channels, obviating the
+ need for `SO_RCVTIMEO`. But that would introduce asynchronous complication.
+ `SO_RCVTIMEO` is well supported on the test bench.
+
+#### `Layer` struct
+
+A large part of packetimpact tests is creating packets to send and comparing
+received packets against expectations. To keep tests concise, it is useful to be
+able to specify just the important parts of packets that need to be set. For
+example, sending a packet with default values except for TCP Flags. And for
+packets received, it's useful to be able to compare just the necessary parts of
+received packets and ignore the rest.
+
+To aid in both of those, Go structs with optional fields are created for each
+encapsulation type, such as IPv4, TCP, and Ethernet. This is inspired by
+[scapy](https://scapy.readthedocs.io/en/latest/). For example, here is the
+struct for Ethernet:
+
+```go
+type Ether struct {
+ LayerBase
+ SrcAddr *tcpip.LinkAddress
+ DstAddr *tcpip.LinkAddress
+ Type *tcpip.NetworkProtocolNumber
+}
+```
+
+Each struct has the same fields as those in the
+[gVisor headers](https://github.com/google/gvisor/tree/master/pkg/tcpip/header)
+but with a pointer for each field that may be `nil`.
+
+##### Alternatives considered
+
+* Just use []byte like gVisor headers do. The drawback is that it makes the
+ tests more verbose.
+ * For example, there would be no way to call `Send(myBytes)` concisely and
+ indicate if the checksum should be calculated automatically versus
+ overridden. The only way would be to add lines to the test to calculate
+ it before each Send, which is wordy. Or make multiple versions of Send:
+ one that checksums IP, one that doesn't, one that checksums TCP, one
+ that does both, etc. That would be many combinations.
+ * Filtering inputs would become verbose. Either:
+ * large conditionals that need to be repeated many places:
+ `h[FlagOffset] == SYN && h[LengthOffset:LengthOffset+2] == ...` or
+ * Many functions, one per field, like: `filterByFlag(myBytes, SYN)`,
+ `filterByLength(myBytes, 20)`, `filterByNextProto(myBytes, 0x8000)`,
+ etc.
+ * Using pointers allows us to combine `Layer`s with a one-line call to
+ `mergo.Merge(...)`. So the default `Layers` can be overridden by a
+ `Layers` with just the TCP conection's src/dst which can be overridden
+ by one with just a test specific TCP window size. Each override is
+ specified as just one call to `mergo.Merge`.
+ * It's a proven way to separate the details of a packet from the byte
+ format as shown by scapy's success.
+* Use packetgo. It's more general than parsing packets with gVisor. However:
+ * packetgo doesn't have optional fields so many of the above problems
+ still apply.
+ * It would be yet another dependency.
+ * It's not as well known to engineers that are already writing gVisor
+ code.
+ * It might be a good candidate for replacing the parsing of packets into
+ `Layer`s if all that parsing turns out to be more work than parsing by
+ packetgo and converting *that* to `Layer`. packetgo has easier to use
+ getters for the layers. This could be done later in a way that doesn't
+ break tests.
+
+#### `Layer` methods
+
+The `Layer` structs provide a way to partially specify an encapsulation. They
+also need methods for using those partially specified encapsulation, for example
+to marshal them to bytes or compare them. For those, each encapsulation
+implements the `Layer` interface:
+
+```go
+// Layer is the interface that all encapsulations must implement.
+//
+// A Layer is an encapsulation in a packet, such as TCP, IPv4, IPv6, etc. A
+// Layer contains all the fields of the encapsulation. Each field is a pointer
+// and may be nil.
+type Layer interface {
+ // toBytes converts the Layer into bytes. In places where the Layer's field
+ // isn't nil, the value that is pointed to is used. When the field is nil, a
+ // reasonable default for the Layer is used. For example, "64" for IPv4 TTL
+ // and a calculated checksum for TCP or IP. Some layers require information
+ // from the previous or next layers in order to compute a default, such as
+ // TCP's checksum or Ethernet's type, so each Layer has a doubly-linked list
+ // to the layer's neighbors.
+ toBytes() ([]byte, error)
+
+ // match checks if the current Layer matches the provided Layer. If either
+ // Layer has a nil in a given field, that field is considered matching.
+ // Otherwise, the values pointed to by the fields must match.
+ match(Layer) bool
+
+ // length in bytes of the current encapsulation
+ length() int
+
+ // next gets a pointer to the encapsulated Layer.
+ next() Layer
+
+ // prev gets a pointer to the Layer encapsulating this one.
+ prev() Layer
+
+ // setNext sets the pointer to the encapsulated Layer.
+ setNext(Layer)
+
+ // setPrev sets the pointer to the Layer encapsulating this one.
+ setPrev(Layer)
+}
+```
+
+For each `Layer` there is also a parsing function. For example, this one is for
+Ethernet:
+
+```
+func ParseEther(b []byte) (Layers, error)
+```
+
+The parsing function converts bytes received on the wire into a `Layer`
+(actually `Layers`, see below) which has no `nil`s in it. By using
+`match(Layer)` to compare against another `Layer` that *does* have `nil`s in it,
+the received bytes can be partially compared. The `nil`s behave as
+"don't-cares".
+
+##### Alternatives considered
+
+* Matching against `[]byte` instead of converting to `Layer` first.
+ * The downside is that it precludes the use of a `cmp.Equal` one-liner to
+ do comparisons.
+ * It creates confusion in the code to deal with both representations at
+ different times. For example, is the checksum calculated on `[]byte` or
+ `Layer` when sending? What about when checking received packets?
+
+#### `Layers`
+
+```
+type Layers []Layer
+
+func (ls *Layers) match(other Layers) bool {...}
+func (ls *Layers) toBytes() ([]byte, error) {...}
+```
+
+`Layers` is an array of `Layer`. It represents a stack of encapsulations, such
+as `Layers{Ether{},IPv4{},TCP{},Payload{}}`. It also has `toBytes()` and
+`match(Layers)`, like `Layer`. The parse functions above actually return
+`Layers` and not `Layer` because they know about the headers below and
+sequentially call each parser on the remaining, encapsulated bytes.
+
+All this leads to the ability to write concise packet processing. For example:
+
+```go
+etherType := 0x8000
+flags = uint8(header.TCPFlagSyn|header.TCPFlagAck)
+toMatch := Layers{Ether{Type: &etherType}, IPv4{}, TCP{Flags: &flags}}
+for {
+ recvBytes := sniffer.Recv(time.Second)
+ if recvBytes == nil {
+ println("Got no packet for 1 second")
+ }
+ gotPacket, err := ParseEther(recvBytes)
+ if err == nil && toMatch.match(gotPacket) {
+ println("Got a TCP/IPv4/Eth packet with SYNACK")
+ }
+}
+```
+
+##### Alternatives considered
+
+* Don't use previous and next pointers.
+ * Each layer may need to be able to interrogate the layers aroung it, like
+ for computing the next protocol number or total length. So *some*
+ mechanism is needed for a `Layer` to see neighboring layers.
+ * We could pass the entire array `Layers` to the `toBytes()` function.
+ Passing an array to a method that includes in the array the function
+ receiver itself seems wrong.
+
+#### Connections
+
+Using `Layers` above, we can create connection structures to maintain state
+about connections. For example, here is the `TCPIPv4` struct:
+
+```
+type TCPIPv4 struct {
+ outgoing Layers
+ incoming Layers
+ localSeqNum uint32
+ remoteSeqNum uint32
+ sniffer Sniffer
+ injector Injector
+ t *testing.T
+}
+```
+
+`TCPIPv4` contains an `outgoing Layers` which holds the defaults for the
+connection, such as the source and destination MACs, IPs, and ports. When
+`outgoing.toBytes()` is called a valid packet for this TCPIPv4 flow is built.
+
+It also contains `incoming Layers` which holds filter for incoming packets that
+belong to this flow. `incoming.match(Layers)` is used on received bytes to check
+if they are part of the flow.
+
+The `sniffer` and `injector` are for receiving and sending raw packet bytes. The
+`localSeqNum` and `remoteSeqNum` are updated by `Send` and `Recv` so that
+outgoing packets will have, by default, the correct sequence number and ack
+number.
+
+TCPIPv4 provides some functions:
+
+```
+func (conn *TCPIPv4) Send(tcp TCP) {...}
+func (conn *TCPIPv4) Recv(timeout time.Duration) *TCP {...}
+```
+
+`Send(tcp TCP)` uses [mergo](https://github.com/imdario/mergo) to merge the
+provided `TCP` (a `Layer`) into `outgoing`. This way the user can specify
+concisely just which fields of `outgoing` to modify. The packet is sent using
+the `injector`.
+
+`Recv(timeout time.Duration)` reads packets from the sniffer until either the
+timeout has elapsed or a packet that matches `incoming` arrives.
+
+Using those, we can perform a TCP 3-way handshake without too much code:
+
+```go
+func (conn *TCPIPv4) Handshake() {
+ syn := uint8(header.TCPFlagSyn)
+ synack := uint8(header.TCPFlagSyn)
+ ack := uint8(header.TCPFlagAck)
+ conn.Send(TCP{Flags: &syn}) // Send a packet with all defaults but set TCP-SYN.
+
+ // Wait for the SYN-ACK response.
+ for {
+ newTCP := conn.Recv(time.Second) // This already filters by MAC, IP, and ports.
+ if TCP{Flags: &synack}.match(newTCP) {
+ break // Only if it's a SYN-ACK proceed.
+ }
+ }
+
+ conn.Send(TCP{Flags: &ack}) // Send an ACK. The seq and ack numbers are set correctly.
+}
+```
+
+The handshake code is part of the testbench utilities so tests can share this
+common sequence, making tests even more concise.
+
+##### Alternatives considered
+
+* Instead of storing `outgoing` and `incoming`, store values.
+ * There would be many more things to store instead, like `localMac`,
+ `remoteMac`, `localIP`, `remoteIP`, `localPort`, and `remotePort`.
+ * Construction of a packet would be many lines to copy each of these
+ values into a `[]byte`. And there would be slight variations needed for
+ each encapsulation stack, like TCPIPv6 and ARP.
+ * Filtering incoming packets would be a long sequence:
+ * Compare the MACs, then
+ * Parse the next header, then
+ * Compare the IPs, then
+ * Parse the next header, then
+ * Compare the TCP ports. Instead it's all just one call to
+ `cmp.Equal(...)`, for all sequences.
+ * A TCPIPv6 connection could share most of the code. Only the type of the
+ IP addresses are different. The types of `outgoing` and `incoming` would
+ be remain `Layers`.
+ * An ARP connection could share all the Ethernet parts. The IP `Layer`
+ could be factored out of `outgoing`. After that, the IPv4 and IPv6
+ connections could implement one interface and a single TCP struct could
+ have either network protocol through composition.
+
+## Putting it all together
+
+Here's what te start of a packetimpact unit test looks like. This test creates a
+TCP connection with the DUT. There are added comments for explanation in this
+document but a real test might not include them in order to stay even more
+concise.
+
+```go
+func TestMyTcpTest(t *testing.T) {
+ // Prepare a DUT for communication.
+ dut := testbench.NewDUT(t)
+
+ // This does:
+ // dut.Socket()
+ // dut.Bind()
+ // dut.Getsockname() to learn the new port number
+ // dut.Listen()
+ listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFD) // Tell the DUT to close the socket at the end of the test.
+
+ // Monitor a new TCP connection with sniffer, injector, sequence number tracking,
+ // and reasonable outgoing and incoming packet field default IPs, MACs, and port numbers.
+ conn := testbench.NewTCPIPv4(t, dut, remotePort)
+
+ // Perform a 3-way handshake: send SYN, expect SYNACK, send ACK.
+ conn.Handshake()
+
+ // Tell the DUT to accept the new connection.
+ acceptFD := dut.Accept(acceptFd)
+}
+```
+
+## Other notes
+
+* The time between receiving a SYN-ACK and replying with an ACK in `Handshake`
+ is about 3ms. This is much slower than the native unix response, which is
+ about 0.3ms. Packetdrill gets closer to 0.3ms. For tests where timing is
+ crucial, packetdrill is faster and more precise.
diff --git a/test/packetimpact/dut/BUILD b/test/packetimpact/dut/BUILD
new file mode 100644
index 000000000..3ce63c2c6
--- /dev/null
+++ b/test/packetimpact/dut/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "cc_binary", "grpcpp")
+
+package(
+ default_visibility = ["//test/packetimpact:__subpackages__"],
+ licenses = ["notice"],
+)
+
+cc_binary(
+ name = "posix_server",
+ srcs = ["posix_server.cc"],
+ linkstatic = 1,
+ static = True, # This is needed for running in a docker container.
+ deps = [
+ grpcpp,
+ "//test/packetimpact/proto:posix_server_cc_grpc_proto",
+ "//test/packetimpact/proto:posix_server_cc_proto",
+ ],
+)
diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc
new file mode 100644
index 000000000..2f10dda40
--- /dev/null
+++ b/test/packetimpact/dut/posix_server.cc
@@ -0,0 +1,229 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at //
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <getopt.h>
+#include <netdb.h>
+#include <netinet/in.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <iostream>
+#include <unordered_map>
+
+#include "arpa/inet.h"
+#include "include/grpcpp/security/server_credentials.h"
+#include "include/grpcpp/server_builder.h"
+#include "test/packetimpact/proto/posix_server.grpc.pb.h"
+#include "test/packetimpact/proto/posix_server.pb.h"
+
+// Converts a sockaddr_storage to a Sockaddr message.
+::grpc::Status sockaddr_to_proto(const sockaddr_storage &addr,
+ socklen_t addrlen,
+ posix_server::Sockaddr *sockaddr_proto) {
+ switch (addr.ss_family) {
+ case AF_INET: {
+ auto addr_in = reinterpret_cast<const sockaddr_in *>(&addr);
+ auto response_in = sockaddr_proto->mutable_in();
+ response_in->set_family(addr_in->sin_family);
+ response_in->set_port(ntohs(addr_in->sin_port));
+ response_in->mutable_addr()->assign(
+ reinterpret_cast<const char *>(&addr_in->sin_addr.s_addr), 4);
+ return ::grpc::Status::OK;
+ }
+ case AF_INET6: {
+ auto addr_in6 = reinterpret_cast<const sockaddr_in6 *>(&addr);
+ auto response_in6 = sockaddr_proto->mutable_in6();
+ response_in6->set_family(addr_in6->sin6_family);
+ response_in6->set_port(ntohs(addr_in6->sin6_port));
+ response_in6->set_flowinfo(ntohl(addr_in6->sin6_flowinfo));
+ response_in6->mutable_addr()->assign(
+ reinterpret_cast<const char *>(&addr_in6->sin6_addr.s6_addr), 16);
+ response_in6->set_scope_id(ntohl(addr_in6->sin6_scope_id));
+ return ::grpc::Status::OK;
+ }
+ }
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Unknown Sockaddr");
+}
+
+class PosixImpl final : public posix_server::Posix::Service {
+ ::grpc::Status Socket(grpc_impl::ServerContext *context,
+ const ::posix_server::SocketRequest *request,
+ ::posix_server::SocketResponse *response) override {
+ response->set_fd(
+ socket(request->domain(), request->type(), request->protocol()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Bind(grpc_impl::ServerContext *context,
+ const ::posix_server::BindRequest *request,
+ ::posix_server::BindResponse *response) override {
+ if (!request->has_addr()) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Missing address");
+ }
+ sockaddr_storage addr;
+
+ switch (request->addr().sockaddr_case()) {
+ case posix_server::Sockaddr::SockaddrCase::kIn: {
+ auto request_in = request->addr().in();
+ if (request_in.addr().size() != 4) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "IPv4 address must be 4 bytes");
+ }
+ auto addr_in = reinterpret_cast<sockaddr_in *>(&addr);
+ addr_in->sin_family = request_in.family();
+ addr_in->sin_port = htons(request_in.port());
+ request_in.addr().copy(
+ reinterpret_cast<char *>(&addr_in->sin_addr.s_addr), 4);
+ break;
+ }
+ case posix_server::Sockaddr::SockaddrCase::kIn6: {
+ auto request_in6 = request->addr().in6();
+ if (request_in6.addr().size() != 16) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "IPv6 address must be 16 bytes");
+ }
+ auto addr_in6 = reinterpret_cast<sockaddr_in6 *>(&addr);
+ addr_in6->sin6_family = request_in6.family();
+ addr_in6->sin6_port = htons(request_in6.port());
+ addr_in6->sin6_flowinfo = htonl(request_in6.flowinfo());
+ request_in6.addr().copy(
+ reinterpret_cast<char *>(&addr_in6->sin6_addr.s6_addr), 16);
+ addr_in6->sin6_scope_id = htonl(request_in6.scope_id());
+ break;
+ }
+ case posix_server::Sockaddr::SockaddrCase::SOCKADDR_NOT_SET:
+ default:
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Unknown Sockaddr");
+ }
+ response->set_ret(bind(request->sockfd(),
+ reinterpret_cast<sockaddr *>(&addr), sizeof(addr)));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status GetSockName(
+ grpc_impl::ServerContext *context,
+ const ::posix_server::GetSockNameRequest *request,
+ ::posix_server::GetSockNameResponse *response) override {
+ sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ response->set_ret(getsockname(
+ request->sockfd(), reinterpret_cast<sockaddr *>(&addr), &addrlen));
+ response->set_errno_(errno);
+ return sockaddr_to_proto(addr, addrlen, response->mutable_addr());
+ }
+
+ ::grpc::Status Listen(grpc_impl::ServerContext *context,
+ const ::posix_server::ListenRequest *request,
+ ::posix_server::ListenResponse *response) override {
+ response->set_ret(listen(request->sockfd(), request->backlog()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Accept(grpc_impl::ServerContext *context,
+ const ::posix_server::AcceptRequest *request,
+ ::posix_server::AcceptResponse *response) override {
+ sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ response->set_fd(accept(request->sockfd(),
+ reinterpret_cast<sockaddr *>(&addr), &addrlen));
+ response->set_errno_(errno);
+ return sockaddr_to_proto(addr, addrlen, response->mutable_addr());
+ }
+
+ ::grpc::Status SetSockOpt(
+ grpc_impl::ServerContext *context,
+ const ::posix_server::SetSockOptRequest *request,
+ ::posix_server::SetSockOptResponse *response) override {
+ response->set_ret(setsockopt(request->sockfd(), request->level(),
+ request->optname(), request->optval().c_str(),
+ request->optval().size()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status SetSockOptTimeval(
+ ::grpc::ServerContext *context,
+ const ::posix_server::SetSockOptTimevalRequest *request,
+ ::posix_server::SetSockOptTimevalResponse *response) override {
+ timeval tv = {.tv_sec = static_cast<__time_t>(request->timeval().seconds()),
+ .tv_usec = static_cast<__suseconds_t>(
+ request->timeval().microseconds())};
+ response->set_ret(setsockopt(request->sockfd(), request->level(),
+ request->optname(), &tv, sizeof(tv)));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Close(grpc_impl::ServerContext *context,
+ const ::posix_server::CloseRequest *request,
+ ::posix_server::CloseResponse *response) override {
+ response->set_ret(close(request->fd()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+};
+
+// Parse command line options. Returns a pointer to the first argument beyond
+// the options.
+void parse_command_line_options(int argc, char *argv[], std::string *ip,
+ int *port) {
+ static struct option options[] = {{"ip", required_argument, NULL, 1},
+ {"port", required_argument, NULL, 2},
+ {0, 0, 0, 0}};
+
+ // Parse the arguments.
+ int c;
+ while ((c = getopt_long(argc, argv, "", options, NULL)) > 0) {
+ if (c == 1) {
+ *ip = optarg;
+ } else if (c == 2) {
+ *port = std::stoi(std::string(optarg));
+ }
+ }
+}
+
+void run_server(const std::string &ip, int port) {
+ PosixImpl posix_service;
+ grpc::ServerBuilder builder;
+ std::string server_address = ip + ":" + std::to_string(port);
+ // Set the authentication mechanism.
+ std::shared_ptr<grpc::ServerCredentials> creds =
+ grpc::InsecureServerCredentials();
+ builder.AddListeningPort(server_address, creds);
+ builder.RegisterService(&posix_service);
+
+ std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
+ std::cerr << "Server listening on " << server_address << std::endl;
+ server->Wait();
+ std::cerr << "posix_server is finished." << std::endl;
+}
+
+int main(int argc, char *argv[]) {
+ std::cerr << "posix_server is starting." << std::endl;
+ std::string ip;
+ int port;
+ parse_command_line_options(argc, argv, &ip, &port);
+
+ std::cerr << "Got IP " << ip << " and port " << port << "." << std::endl;
+ run_server(ip, port);
+}
diff --git a/test/packetimpact/proto/BUILD b/test/packetimpact/proto/BUILD
new file mode 100644
index 000000000..4a4370f42
--- /dev/null
+++ b/test/packetimpact/proto/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "proto_library")
+
+package(
+ default_visibility = ["//test/packetimpact:__subpackages__"],
+ licenses = ["notice"],
+)
+
+proto_library(
+ name = "posix_server",
+ srcs = ["posix_server.proto"],
+ has_services = 1,
+)
diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto
new file mode 100644
index 000000000..026876fc2
--- /dev/null
+++ b/test/packetimpact/proto/posix_server.proto
@@ -0,0 +1,150 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package posix_server;
+
+message SocketRequest {
+ int32 domain = 1;
+ int32 type = 2;
+ int32 protocol = 3;
+}
+
+message SocketResponse {
+ int32 fd = 1;
+ int32 errno_ = 2;
+}
+
+message SockaddrIn {
+ int32 family = 1;
+ uint32 port = 2;
+ bytes addr = 3;
+}
+
+message SockaddrIn6 {
+ uint32 family = 1;
+ uint32 port = 2;
+ uint32 flowinfo = 3;
+ bytes addr = 4;
+ uint32 scope_id = 5;
+}
+
+message Sockaddr {
+ oneof sockaddr {
+ SockaddrIn in = 1;
+ SockaddrIn6 in6 = 2;
+ }
+}
+
+message BindRequest {
+ int32 sockfd = 1;
+ Sockaddr addr = 2;
+}
+
+message BindResponse {
+ int32 ret = 1;
+ int32 errno_ = 2;
+}
+
+message GetSockNameRequest {
+ int32 sockfd = 1;
+}
+
+message GetSockNameResponse {
+ int32 ret = 1;
+ int32 errno_ = 2;
+ Sockaddr addr = 3;
+}
+
+message ListenRequest {
+ int32 sockfd = 1;
+ int32 backlog = 2;
+}
+
+message ListenResponse {
+ int32 ret = 1;
+ int32 errno_ = 2;
+}
+
+message AcceptRequest {
+ int32 sockfd = 1;
+}
+
+message AcceptResponse {
+ int32 fd = 1;
+ int32 errno_ = 2;
+ Sockaddr addr = 3;
+}
+
+message SetSockOptRequest {
+ int32 sockfd = 1;
+ int32 level = 2;
+ int32 optname = 3;
+ bytes optval = 4;
+}
+
+message SetSockOptResponse {
+ int32 ret = 1;
+ int32 errno_ = 2;
+}
+
+message Timeval {
+ int64 seconds = 1;
+ int64 microseconds = 2;
+}
+
+message SetSockOptTimevalRequest {
+ int32 sockfd = 1;
+ int32 level = 2;
+ int32 optname = 3;
+ Timeval timeval = 4;
+}
+
+message SetSockOptTimevalResponse {
+ int32 ret = 1;
+ int32 errno_ = 2;
+}
+
+message CloseRequest {
+ int32 fd = 1;
+}
+
+message CloseResponse {
+ int32 ret = 1;
+ int32 errno_ = 2;
+}
+
+service Posix {
+ // Call socket() on the DUT.
+ rpc Socket(SocketRequest) returns (SocketResponse);
+ // Call bind() on the DUT.
+ rpc Bind(BindRequest) returns (BindResponse);
+ // Call getsockname() on the DUT.
+ rpc GetSockName(GetSockNameRequest) returns (GetSockNameResponse);
+ // Call listen() on the DUT.
+ rpc Listen(ListenRequest) returns (ListenResponse);
+ // Call accept() on the DUT.
+ rpc Accept(AcceptRequest) returns (AcceptResponse);
+ // Call setsockopt() on the DUT. You should prefer one of the other
+ // SetSockOpt* functions with a more structured optval or else you may get the
+ // encoding wrong, such as making a bad assumption about the server's word
+ // sizes or endianness.
+ rpc SetSockOpt(SetSockOptRequest) returns (SetSockOptResponse);
+ // Call setsockopt() on the DUT with a Timeval optval.
+ rpc SetSockOptTimeval(SetSockOptTimevalRequest)
+ returns (SetSockOptTimevalResponse);
+ // Call close() on the DUT.
+ rpc Close(CloseRequest) returns (CloseResponse);
+}
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD
new file mode 100644
index 000000000..a34c81fcc
--- /dev/null
+++ b/test/packetimpact/testbench/BUILD
@@ -0,0 +1,31 @@
+load("//tools:defs.bzl", "go_library")
+
+package(
+ default_visibility = ["//test/packetimpact:__subpackages__"],
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "testbench",
+ srcs = [
+ "connections.go",
+ "dut.go",
+ "dut_client.go",
+ "layers.go",
+ "rawsockets.go",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ "//pkg/usermem",
+ "//test/packetimpact/proto:posix_server_go_proto",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
+ "@com_github_imdario_mergo//:go_default_library",
+ "@com_github_mohae_deepcopy//:go_default_library",
+ "@org_golang_google_grpc//:go_default_library",
+ "@org_golang_google_grpc//keepalive:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
new file mode 100644
index 000000000..b7aa63934
--- /dev/null
+++ b/test/packetimpact/testbench/connections.go
@@ -0,0 +1,245 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testbench has utilities to send and receive packets and also command
+// the DUT to run POSIX functions.
+package testbench
+
+import (
+ "flag"
+ "fmt"
+ "math/rand"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/mohae/deepcopy"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+var localIPv4 = flag.String("local_ipv4", "", "local IPv4 address for test packets")
+var remoteIPv4 = flag.String("remote_ipv4", "", "remote IPv4 address for test packets")
+var localMAC = flag.String("local_mac", "", "local mac address for test packets")
+var remoteMAC = flag.String("remote_mac", "", "remote mac address for test packets")
+
+// TCPIPv4 maintains state about a TCP/IPv4 connection.
+type TCPIPv4 struct {
+ outgoing Layers
+ incoming Layers
+ LocalSeqNum seqnum.Value
+ RemoteSeqNum seqnum.Value
+ SynAck *TCP
+ sniffer Sniffer
+ injector Injector
+ portPickerFD int
+ t *testing.T
+}
+
+// pickPort makes a new socket and returns the socket FD and port. The caller
+// must close the FD when done with the port if there is no error.
+func pickPort() (int, uint16, error) {
+ fd, err := unix.Socket(unix.AF_INET, unix.SOCK_STREAM, 0)
+ if err != nil {
+ return -1, 0, err
+ }
+ var sa unix.SockaddrInet4
+ copy(sa.Addr[0:4], net.ParseIP(*localIPv4).To4())
+ if err := unix.Bind(fd, &sa); err != nil {
+ unix.Close(fd)
+ return -1, 0, err
+ }
+ newSockAddr, err := unix.Getsockname(fd)
+ if err != nil {
+ unix.Close(fd)
+ return -1, 0, err
+ }
+ newSockAddrInet4, ok := newSockAddr.(*unix.SockaddrInet4)
+ if !ok {
+ unix.Close(fd)
+ return -1, 0, fmt.Errorf("can't cast Getsockname result to SockaddrInet4")
+ }
+ return fd, uint16(newSockAddrInet4.Port), nil
+}
+
+// tcpLayerIndex is the position of the TCP layer in the TCPIPv4 connection. It
+// is the third, after Ethernet and IPv4.
+const tcpLayerIndex int = 2
+
+// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults.
+func NewTCPIPv4(t *testing.T, dut DUT, outgoingTCP, incomingTCP TCP) TCPIPv4 {
+ lMAC, err := tcpip.ParseMACAddress(*localMAC)
+ if err != nil {
+ t.Fatalf("can't parse localMAC %q: %s", *localMAC, err)
+ }
+
+ rMAC, err := tcpip.ParseMACAddress(*remoteMAC)
+ if err != nil {
+ t.Fatalf("can't parse remoteMAC %q: %s", *remoteMAC, err)
+ }
+
+ portPickerFD, localPort, err := pickPort()
+ if err != nil {
+ t.Fatalf("can't pick a port: %s", err)
+ }
+ lIP := tcpip.Address(net.ParseIP(*localIPv4).To4())
+ rIP := tcpip.Address(net.ParseIP(*remoteIPv4).To4())
+
+ sniffer, err := NewSniffer(t)
+ if err != nil {
+ t.Fatalf("can't make new sniffer: %s", err)
+ }
+
+ injector, err := NewInjector(t)
+ if err != nil {
+ t.Fatalf("can't make new injector: %s", err)
+ }
+
+ newOutgoingTCP := &TCP{
+ DataOffset: Uint8(header.TCPMinimumSize),
+ WindowSize: Uint16(32768),
+ SrcPort: &localPort,
+ }
+ if err := newOutgoingTCP.merge(outgoingTCP); err != nil {
+ t.Fatalf("can't merge %v into %v: %s", outgoingTCP, newOutgoingTCP, err)
+ }
+ newIncomingTCP := &TCP{
+ DstPort: &localPort,
+ }
+ if err := newIncomingTCP.merge(incomingTCP); err != nil {
+ t.Fatalf("can't merge %v into %v: %s", incomingTCP, newIncomingTCP, err)
+ }
+ return TCPIPv4{
+ outgoing: Layers{
+ &Ether{SrcAddr: &lMAC, DstAddr: &rMAC},
+ &IPv4{SrcAddr: &lIP, DstAddr: &rIP},
+ newOutgoingTCP},
+ incoming: Layers{
+ &Ether{SrcAddr: &rMAC, DstAddr: &lMAC},
+ &IPv4{SrcAddr: &rIP, DstAddr: &lIP},
+ newIncomingTCP},
+ sniffer: sniffer,
+ injector: injector,
+ portPickerFD: portPickerFD,
+ t: t,
+ LocalSeqNum: seqnum.Value(rand.Uint32()),
+ }
+}
+
+// Close the injector and sniffer associated with this connection.
+func (conn *TCPIPv4) Close() {
+ conn.sniffer.Close()
+ conn.injector.Close()
+ if err := unix.Close(conn.portPickerFD); err != nil {
+ conn.t.Fatalf("can't close portPickerFD: %s", err)
+ }
+ conn.portPickerFD = -1
+}
+
+// Send a packet with reasonable defaults and override some fields by tcp.
+func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) {
+ if tcp.SeqNum == nil {
+ tcp.SeqNum = Uint32(uint32(conn.LocalSeqNum))
+ }
+ if tcp.AckNum == nil {
+ tcp.AckNum = Uint32(uint32(conn.RemoteSeqNum))
+ }
+ layersToSend := deepcopy.Copy(conn.outgoing).(Layers)
+ if err := layersToSend[tcpLayerIndex].(*TCP).merge(tcp); err != nil {
+ conn.t.Fatalf("can't merge %v into %v: %s", tcp, layersToSend[tcpLayerIndex], err)
+ }
+ layersToSend = append(layersToSend, additionalLayers...)
+ outBytes, err := layersToSend.toBytes()
+ if err != nil {
+ conn.t.Fatalf("can't build outgoing TCP packet: %s", err)
+ }
+ conn.injector.Send(outBytes)
+
+ // Compute the next TCP sequence number.
+ for i := tcpLayerIndex + 1; i < len(layersToSend); i++ {
+ conn.LocalSeqNum.UpdateForward(seqnum.Size(layersToSend[i].length()))
+ }
+ if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
+ conn.LocalSeqNum.UpdateForward(1)
+ }
+}
+
+// Recv gets a packet from the sniffer within the timeout provided. If no packet
+// arrives before the timeout, it returns nil.
+func (conn *TCPIPv4) Recv(timeout time.Duration) *TCP {
+ deadline := time.Now().Add(timeout)
+ for {
+ timeout = deadline.Sub(time.Now())
+ if timeout <= 0 {
+ break
+ }
+ b := conn.sniffer.Recv(timeout)
+ if b == nil {
+ break
+ }
+ layers, err := ParseEther(b)
+ if err != nil {
+ continue // Ignore packets that can't be parsed.
+ }
+ if !conn.incoming.match(layers) {
+ continue // Ignore packets that don't match the expected incoming.
+ }
+ tcpHeader := (layers[tcpLayerIndex]).(*TCP)
+ conn.RemoteSeqNum = seqnum.Value(*tcpHeader.SeqNum)
+ if *tcpHeader.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
+ conn.RemoteSeqNum.UpdateForward(1)
+ }
+ for i := tcpLayerIndex + 1; i < len(layers); i++ {
+ conn.RemoteSeqNum.UpdateForward(seqnum.Size(layers[i].length()))
+ }
+ return tcpHeader
+ }
+ return nil
+}
+
+// Expect a packet that matches the provided tcp within the timeout specified.
+// If it doesn't arrive in time, the test fails.
+func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) *TCP {
+ deadline := time.Now().Add(timeout)
+ for {
+ timeout = deadline.Sub(time.Now())
+ if timeout <= 0 {
+ return nil
+ }
+ gotTCP := conn.Recv(timeout)
+ if gotTCP == nil {
+ return nil
+ }
+ if tcp.match(gotTCP) {
+ return gotTCP
+ }
+ }
+}
+
+// Handshake performs a TCP 3-way handshake.
+func (conn *TCPIPv4) Handshake() {
+ // Send the SYN.
+ conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)})
+
+ // Wait for the SYN-ACK.
+ conn.SynAck = conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ if conn.SynAck == nil {
+ conn.t.Fatalf("didn't get synack during handshake")
+ }
+
+ // Send an ACK.
+ conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)})
+}
diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go
new file mode 100644
index 000000000..8ea1706d3
--- /dev/null
+++ b/test/packetimpact/testbench/dut.go
@@ -0,0 +1,363 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "context"
+ "flag"
+ "net"
+ "strconv"
+ "syscall"
+ "testing"
+ "time"
+
+ pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto"
+
+ "golang.org/x/sys/unix"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/keepalive"
+)
+
+var (
+ posixServerIP = flag.String("posix_server_ip", "", "ip address to listen to for UDP commands")
+ posixServerPort = flag.Int("posix_server_port", 40000, "port to listen to for UDP commands")
+ rpcTimeout = flag.Duration("rpc_timeout", 100*time.Millisecond, "gRPC timeout")
+ rpcKeepalive = flag.Duration("rpc_keepalive", 10*time.Second, "gRPC keepalive")
+)
+
+// DUT communicates with the DUT to force it to make POSIX calls.
+type DUT struct {
+ t *testing.T
+ conn *grpc.ClientConn
+ posixServer PosixClient
+}
+
+// NewDUT creates a new connection with the DUT over gRPC.
+func NewDUT(t *testing.T) DUT {
+ flag.Parse()
+ posixServerAddress := *posixServerIP + ":" + strconv.Itoa(*posixServerPort)
+ conn, err := grpc.Dial(posixServerAddress, grpc.WithInsecure(), grpc.WithKeepaliveParams(keepalive.ClientParameters{Timeout: *rpcKeepalive}))
+ if err != nil {
+ t.Fatalf("failed to grpc.Dial(%s): %s", posixServerAddress, err)
+ }
+ posixServer := NewPosixClient(conn)
+ return DUT{
+ t: t,
+ conn: conn,
+ posixServer: posixServer,
+ }
+}
+
+// TearDown closes the underlying connection.
+func (dut *DUT) TearDown() {
+ dut.conn.Close()
+}
+
+// SocketWithErrno calls socket on the DUT and returns the fd and errno.
+func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) {
+ dut.t.Helper()
+ req := pb.SocketRequest{
+ Domain: domain,
+ Type: typ,
+ Protocol: proto,
+ }
+ ctx := context.Background()
+ resp, err := dut.posixServer.Socket(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Socket: %s", err)
+ }
+ return resp.GetFd(), syscall.Errno(resp.GetErrno_())
+}
+
+// Socket calls socket on the DUT and returns the file descriptor. If socket
+// fails on the DUT, the test ends.
+func (dut *DUT) Socket(domain, typ, proto int32) int32 {
+ dut.t.Helper()
+ fd, err := dut.SocketWithErrno(domain, typ, proto)
+ if fd < 0 {
+ dut.t.Fatalf("failed to create socket: %s", err)
+ }
+ return fd
+}
+
+func (dut *DUT) sockaddrToProto(sa unix.Sockaddr) *pb.Sockaddr {
+ dut.t.Helper()
+ switch s := sa.(type) {
+ case *unix.SockaddrInet4:
+ return &pb.Sockaddr{
+ Sockaddr: &pb.Sockaddr_In{
+ In: &pb.SockaddrIn{
+ Family: unix.AF_INET,
+ Port: uint32(s.Port),
+ Addr: s.Addr[:],
+ },
+ },
+ }
+ case *unix.SockaddrInet6:
+ return &pb.Sockaddr{
+ Sockaddr: &pb.Sockaddr_In6{
+ In6: &pb.SockaddrIn6{
+ Family: unix.AF_INET6,
+ Port: uint32(s.Port),
+ Flowinfo: 0,
+ ScopeId: s.ZoneId,
+ Addr: s.Addr[:],
+ },
+ },
+ }
+ }
+ dut.t.Fatalf("can't parse Sockaddr: %+v", sa)
+ return nil
+}
+
+func (dut *DUT) protoToSockaddr(sa *pb.Sockaddr) unix.Sockaddr {
+ dut.t.Helper()
+ switch s := sa.Sockaddr.(type) {
+ case *pb.Sockaddr_In:
+ ret := unix.SockaddrInet4{
+ Port: int(s.In.GetPort()),
+ }
+ copy(ret.Addr[:], s.In.GetAddr())
+ return &ret
+ case *pb.Sockaddr_In6:
+ ret := unix.SockaddrInet6{
+ Port: int(s.In6.GetPort()),
+ ZoneId: s.In6.GetScopeId(),
+ }
+ copy(ret.Addr[:], s.In6.GetAddr())
+ }
+ dut.t.Fatalf("can't parse Sockaddr: %+v", sa)
+ return nil
+}
+
+// BindWithErrno calls bind on the DUT.
+func (dut *DUT) BindWithErrno(fd int32, sa unix.Sockaddr) (int32, error) {
+ dut.t.Helper()
+ req := pb.BindRequest{
+ Sockfd: fd,
+ Addr: dut.sockaddrToProto(sa),
+ }
+ ctx := context.Background()
+ resp, err := dut.posixServer.Bind(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Bind: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// Bind calls bind on the DUT and causes a fatal test failure if it doesn't
+// succeed.
+func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) {
+ dut.t.Helper()
+ ret, err := dut.BindWithErrno(fd, sa)
+ if ret != 0 {
+ dut.t.Fatalf("failed to bind socket: %s", err)
+ }
+}
+
+// GetSockNameWithErrno calls getsockname on the DUT.
+func (dut *DUT) GetSockNameWithErrno(sockfd int32) (int32, unix.Sockaddr, error) {
+ dut.t.Helper()
+ req := pb.GetSockNameRequest{
+ Sockfd: sockfd,
+ }
+ ctx := context.Background()
+ resp, err := dut.posixServer.GetSockName(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Bind: %s", err)
+ }
+ return resp.GetRet(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+}
+
+// GetSockName calls getsockname on the DUT and causes a fatal test failure if
+// it doens't succeed.
+func (dut *DUT) GetSockName(sockfd int32) unix.Sockaddr {
+ dut.t.Helper()
+ ret, sa, err := dut.GetSockNameWithErrno(sockfd)
+ if ret != 0 {
+ dut.t.Fatalf("failed to getsockname: %s", err)
+ }
+ return sa
+}
+
+// ListenWithErrno calls listen on the DUT.
+func (dut *DUT) ListenWithErrno(sockfd, backlog int32) (int32, error) {
+ dut.t.Helper()
+ req := pb.ListenRequest{
+ Sockfd: sockfd,
+ Backlog: backlog,
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ resp, err := dut.posixServer.Listen(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Listen: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// Listen calls listen on the DUT and causes a fatal test failure if it doesn't
+// succeed.
+func (dut *DUT) Listen(sockfd, backlog int32) {
+ dut.t.Helper()
+ ret, err := dut.ListenWithErrno(sockfd, backlog)
+ if ret != 0 {
+ dut.t.Fatalf("failed to listen: %s", err)
+ }
+}
+
+// AcceptWithErrno calls accept on the DUT.
+func (dut *DUT) AcceptWithErrno(sockfd int32) (int32, unix.Sockaddr, error) {
+ dut.t.Helper()
+ req := pb.AcceptRequest{
+ Sockfd: sockfd,
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ resp, err := dut.posixServer.Accept(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Accept: %s", err)
+ }
+ return resp.GetFd(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+}
+
+// Accept calls accept on the DUT and causes a fatal test failure if it doesn't
+// succeed.
+func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) {
+ dut.t.Helper()
+ fd, sa, err := dut.AcceptWithErrno(sockfd)
+ if fd < 0 {
+ dut.t.Fatalf("failed to accept: %s", err)
+ }
+ return fd, sa
+}
+
+// SetSockOptWithErrno calls setsockopt on the DUT.
+func (dut *DUT) SetSockOptWithErrno(sockfd, level, optname int32, optval []byte) (int32, error) {
+ dut.t.Helper()
+ req := pb.SetSockOptRequest{
+ Sockfd: sockfd,
+ Level: level,
+ Optname: optname,
+ Optval: optval,
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ resp, err := dut.posixServer.SetSockOpt(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call SetSockOpt: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it
+// doesn't succeed.
+func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) {
+ dut.t.Helper()
+ ret, err := dut.SetSockOptWithErrno(sockfd, level, optname, optval)
+ if ret != 0 {
+ dut.t.Fatalf("failed to SetSockOpt: %s", err)
+ }
+}
+
+// SetSockOptTimevalWithErrno calls setsockopt with the timeval converted to
+// bytes.
+func (dut *DUT) SetSockOptTimevalWithErrno(sockfd, level, optname int32, tv *unix.Timeval) (int32, error) {
+ dut.t.Helper()
+ timeval := pb.Timeval{
+ Seconds: int64(tv.Sec),
+ Microseconds: int64(tv.Usec),
+ }
+ req := pb.SetSockOptTimevalRequest{
+ Sockfd: sockfd,
+ Level: level,
+ Optname: optname,
+ Timeval: &timeval,
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ resp, err := dut.posixServer.SetSockOptTimeval(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call SetSockOptTimeval: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// SetSockOptTimeval calls setsockopt on the DUT and causes a fatal test failure
+// if it doesn't succeed.
+func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval) {
+ dut.t.Helper()
+ ret, err := dut.SetSockOptTimevalWithErrno(sockfd, level, optname, tv)
+ if ret != 0 {
+ dut.t.Fatalf("failed to SetSockOptTimeval: %s", err)
+ }
+}
+
+// CloseWithErrno calls close on the DUT.
+func (dut *DUT) CloseWithErrno(fd int32) (int32, error) {
+ dut.t.Helper()
+ req := pb.CloseRequest{
+ Fd: fd,
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ resp, err := dut.posixServer.Close(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Close: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// Close calls close on the DUT and causes a fatal test failure if it doesn't
+// succeed.
+func (dut *DUT) Close(fd int32) {
+ dut.t.Helper()
+ ret, err := dut.CloseWithErrno(fd)
+ if ret != 0 {
+ dut.t.Fatalf("failed to close: %s", err)
+ }
+}
+
+// CreateListener makes a new TCP connection. If it fails, the test ends.
+func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) {
+ dut.t.Helper()
+ addr := net.ParseIP(*remoteIPv4)
+ var fd int32
+ if addr.To4() != nil {
+ fd = dut.Socket(unix.AF_INET, typ, proto)
+ sa := unix.SockaddrInet4{}
+ copy(sa.Addr[:], addr.To4())
+ dut.Bind(fd, &sa)
+ } else if addr.To16() != nil {
+ fd = dut.Socket(unix.AF_INET6, typ, proto)
+ sa := unix.SockaddrInet6{}
+ copy(sa.Addr[:], addr.To16())
+ dut.Bind(fd, &sa)
+ } else {
+ dut.t.Fatal("unknown ip addr type for remoteIP")
+ }
+ sa := dut.GetSockName(fd)
+ var port int
+ switch s := sa.(type) {
+ case *unix.SockaddrInet4:
+ port = s.Port
+ case *unix.SockaddrInet6:
+ port = s.Port
+ default:
+ dut.t.Fatalf("unknown sockaddr type from getsockname: %t", sa)
+ }
+ dut.Listen(fd, backlog)
+ return fd, uint16(port)
+}
diff --git a/test/packetimpact/testbench/dut_client.go b/test/packetimpact/testbench/dut_client.go
new file mode 100644
index 000000000..b130a33a2
--- /dev/null
+++ b/test/packetimpact/testbench/dut_client.go
@@ -0,0 +1,28 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "google.golang.org/grpc"
+ pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto"
+)
+
+// PosixClient is a gRPC client for the Posix service.
+type PosixClient pb.PosixClient
+
+// NewPosixClient makes a new gRPC client for the Posix service.
+func NewPosixClient(c grpc.ClientConnInterface) PosixClient {
+ return pb.NewPosixClient(c)
+}
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
new file mode 100644
index 000000000..35fa4dcb6
--- /dev/null
+++ b/test/packetimpact/testbench/layers.go
@@ -0,0 +1,507 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "fmt"
+ "reflect"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "github.com/imdario/mergo"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// Layer is the interface that all encapsulations must implement.
+//
+// A Layer is an encapsulation in a packet, such as TCP, IPv4, IPv6, etc. A
+// Layer contains all the fields of the encapsulation. Each field is a pointer
+// and may be nil.
+type Layer interface {
+ // toBytes converts the Layer into bytes. In places where the Layer's field
+ // isn't nil, the value that is pointed to is used. When the field is nil, a
+ // reasonable default for the Layer is used. For example, "64" for IPv4 TTL
+ // and a calculated checksum for TCP or IP. Some layers require information
+ // from the previous or next layers in order to compute a default, such as
+ // TCP's checksum or Ethernet's type, so each Layer has a doubly-linked list
+ // to the layer's neighbors.
+ toBytes() ([]byte, error)
+
+ // match checks if the current Layer matches the provided Layer. If either
+ // Layer has a nil in a given field, that field is considered matching.
+ // Otherwise, the values pointed to by the fields must match.
+ match(Layer) bool
+
+ // length in bytes of the current encapsulation
+ length() int
+
+ // next gets a pointer to the encapsulated Layer.
+ next() Layer
+
+ // prev gets a pointer to the Layer encapsulating this one.
+ prev() Layer
+
+ // setNext sets the pointer to the encapsulated Layer.
+ setNext(Layer)
+
+ // setPrev sets the pointer to the Layer encapsulating this one.
+ setPrev(Layer)
+}
+
+// LayerBase is the common elements of all layers.
+type LayerBase struct {
+ nextLayer Layer
+ prevLayer Layer
+}
+
+func (lb *LayerBase) next() Layer {
+ return lb.nextLayer
+}
+
+func (lb *LayerBase) prev() Layer {
+ return lb.prevLayer
+}
+
+func (lb *LayerBase) setNext(l Layer) {
+ lb.nextLayer = l
+}
+
+func (lb *LayerBase) setPrev(l Layer) {
+ lb.prevLayer = l
+}
+
+func equalLayer(x, y Layer) bool {
+ opt := cmp.FilterValues(func(x, y interface{}) bool {
+ if reflect.ValueOf(x).Kind() == reflect.Ptr && reflect.ValueOf(x).IsNil() {
+ return true
+ }
+ if reflect.ValueOf(y).Kind() == reflect.Ptr && reflect.ValueOf(y).IsNil() {
+ return true
+ }
+ return false
+
+ }, cmp.Ignore())
+ return cmp.Equal(x, y, opt, cmpopts.IgnoreUnexported(LayerBase{}))
+}
+
+// Ether can construct and match the ethernet encapsulation.
+type Ether struct {
+ LayerBase
+ SrcAddr *tcpip.LinkAddress
+ DstAddr *tcpip.LinkAddress
+ Type *tcpip.NetworkProtocolNumber
+}
+
+func (l *Ether) toBytes() ([]byte, error) {
+ b := make([]byte, header.EthernetMinimumSize)
+ h := header.Ethernet(b)
+ fields := &header.EthernetFields{}
+ if l.SrcAddr != nil {
+ fields.SrcAddr = *l.SrcAddr
+ }
+ if l.DstAddr != nil {
+ fields.DstAddr = *l.DstAddr
+ }
+ if l.Type != nil {
+ fields.Type = *l.Type
+ } else {
+ switch n := l.next().(type) {
+ case *IPv4:
+ fields.Type = header.IPv4ProtocolNumber
+ default:
+ // TODO(b/150301488): Support more protocols, like IPv6.
+ return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", n)
+ }
+ }
+ h.Encode(fields)
+ return h, nil
+}
+
+// LinkAddress is a helper routine that allocates a new tcpip.LinkAddress value
+// to store v and returns a pointer to it.
+func LinkAddress(v tcpip.LinkAddress) *tcpip.LinkAddress {
+ return &v
+}
+
+// NetworkProtocolNumber is a helper routine that allocates a new
+// tcpip.NetworkProtocolNumber value to store v and returns a pointer to it.
+func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocolNumber {
+ return &v
+}
+
+// ParseEther parses the bytes assuming that they start with an ethernet header
+// and continues parsing further encapsulations.
+func ParseEther(b []byte) (Layers, error) {
+ h := header.Ethernet(b)
+ ether := Ether{
+ SrcAddr: LinkAddress(h.SourceAddress()),
+ DstAddr: LinkAddress(h.DestinationAddress()),
+ Type: NetworkProtocolNumber(h.Type()),
+ }
+ layers := Layers{&ether}
+ switch h.Type() {
+ case header.IPv4ProtocolNumber:
+ moreLayers, err := ParseIPv4(b[ether.length():])
+ if err != nil {
+ return nil, err
+ }
+ return append(layers, moreLayers...), nil
+ default:
+ // TODO(b/150301488): Support more protocols, like IPv6.
+ return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %v", b)
+ }
+}
+
+func (l *Ether) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *Ether) length() int {
+ return header.EthernetMinimumSize
+}
+
+// IPv4 can construct and match the ethernet excapulation.
+type IPv4 struct {
+ LayerBase
+ IHL *uint8
+ TOS *uint8
+ TotalLength *uint16
+ ID *uint16
+ Flags *uint8
+ FragmentOffset *uint16
+ TTL *uint8
+ Protocol *uint8
+ Checksum *uint16
+ SrcAddr *tcpip.Address
+ DstAddr *tcpip.Address
+}
+
+func (l *IPv4) toBytes() ([]byte, error) {
+ b := make([]byte, header.IPv4MinimumSize)
+ h := header.IPv4(b)
+ fields := &header.IPv4Fields{
+ IHL: 20,
+ TOS: 0,
+ TotalLength: 0,
+ ID: 0,
+ Flags: 0,
+ FragmentOffset: 0,
+ TTL: 64,
+ Protocol: 0,
+ Checksum: 0,
+ SrcAddr: tcpip.Address(""),
+ DstAddr: tcpip.Address(""),
+ }
+ if l.TOS != nil {
+ fields.TOS = *l.TOS
+ }
+ if l.TotalLength != nil {
+ fields.TotalLength = *l.TotalLength
+ } else {
+ fields.TotalLength = uint16(l.length())
+ current := l.next()
+ for current != nil {
+ fields.TotalLength += uint16(current.length())
+ current = current.next()
+ }
+ }
+ if l.ID != nil {
+ fields.ID = *l.ID
+ }
+ if l.Flags != nil {
+ fields.Flags = *l.Flags
+ }
+ if l.FragmentOffset != nil {
+ fields.FragmentOffset = *l.FragmentOffset
+ }
+ if l.TTL != nil {
+ fields.TTL = *l.TTL
+ }
+ if l.Protocol != nil {
+ fields.Protocol = *l.Protocol
+ } else {
+ switch n := l.next().(type) {
+ case *TCP:
+ fields.Protocol = uint8(header.TCPProtocolNumber)
+ default:
+ // TODO(b/150301488): Support more protocols, like UDP.
+ return nil, fmt.Errorf("can't deduce the ip header's next protocol: %+v", n)
+ }
+ }
+ if l.SrcAddr != nil {
+ fields.SrcAddr = *l.SrcAddr
+ }
+ if l.DstAddr != nil {
+ fields.DstAddr = *l.DstAddr
+ }
+ if l.Checksum != nil {
+ fields.Checksum = *l.Checksum
+ }
+ h.Encode(fields)
+ if l.Checksum == nil {
+ h.SetChecksum(^h.CalculateChecksum())
+ }
+ return h, nil
+}
+
+// Uint16 is a helper routine that allocates a new
+// uint16 value to store v and returns a pointer to it.
+func Uint16(v uint16) *uint16 {
+ return &v
+}
+
+// Uint8 is a helper routine that allocates a new
+// uint8 value to store v and returns a pointer to it.
+func Uint8(v uint8) *uint8 {
+ return &v
+}
+
+// Address is a helper routine that allocates a new tcpip.Address value to store
+// v and returns a pointer to it.
+func Address(v tcpip.Address) *tcpip.Address {
+ return &v
+}
+
+// ParseIPv4 parses the bytes assuming that they start with an ipv4 header and
+// continues parsing further encapsulations.
+func ParseIPv4(b []byte) (Layers, error) {
+ h := header.IPv4(b)
+ tos, _ := h.TOS()
+ ipv4 := IPv4{
+ IHL: Uint8(h.HeaderLength()),
+ TOS: &tos,
+ TotalLength: Uint16(h.TotalLength()),
+ ID: Uint16(h.ID()),
+ Flags: Uint8(h.Flags()),
+ FragmentOffset: Uint16(h.FragmentOffset()),
+ TTL: Uint8(h.TTL()),
+ Protocol: Uint8(h.Protocol()),
+ Checksum: Uint16(h.Checksum()),
+ SrcAddr: Address(h.SourceAddress()),
+ DstAddr: Address(h.DestinationAddress()),
+ }
+ layers := Layers{&ipv4}
+ switch h.Protocol() {
+ case uint8(header.TCPProtocolNumber):
+ moreLayers, err := ParseTCP(b[ipv4.length():])
+ if err != nil {
+ return nil, err
+ }
+ return append(layers, moreLayers...), nil
+ }
+ return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", h.Protocol())
+}
+
+func (l *IPv4) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *IPv4) length() int {
+ if l.IHL == nil {
+ return header.IPv4MinimumSize
+ }
+ return int(*l.IHL)
+}
+
+// TCP can construct and match the TCP excapulation.
+type TCP struct {
+ LayerBase
+ SrcPort *uint16
+ DstPort *uint16
+ SeqNum *uint32
+ AckNum *uint32
+ DataOffset *uint8
+ Flags *uint8
+ WindowSize *uint16
+ Checksum *uint16
+ UrgentPointer *uint16
+}
+
+func (l *TCP) toBytes() ([]byte, error) {
+ b := make([]byte, header.TCPMinimumSize)
+ h := header.TCP(b)
+ if l.SrcPort != nil {
+ h.SetSourcePort(*l.SrcPort)
+ }
+ if l.DstPort != nil {
+ h.SetDestinationPort(*l.DstPort)
+ }
+ if l.SeqNum != nil {
+ h.SetSequenceNumber(*l.SeqNum)
+ }
+ if l.AckNum != nil {
+ h.SetAckNumber(*l.AckNum)
+ }
+ if l.DataOffset != nil {
+ h.SetDataOffset(*l.DataOffset)
+ }
+ if l.Flags != nil {
+ h.SetFlags(*l.Flags)
+ }
+ if l.WindowSize != nil {
+ h.SetWindowSize(*l.WindowSize)
+ }
+ if l.UrgentPointer != nil {
+ h.SetUrgentPoiner(*l.UrgentPointer)
+ }
+ if l.Checksum != nil {
+ h.SetChecksum(*l.Checksum)
+ return h, nil
+ }
+ if err := setChecksum(&h, l); err != nil {
+ return nil, err
+ }
+ return h, nil
+}
+
+// setChecksum calculates the checksum of the TCP header and sets it in h.
+func setChecksum(h *header.TCP, tcp *TCP) error {
+ h.SetChecksum(0)
+ tcpLength := uint16(tcp.length())
+ current := tcp.next()
+ for current != nil {
+ tcpLength += uint16(current.length())
+ current = current.next()
+ }
+
+ var xsum uint16
+ switch s := tcp.prev().(type) {
+ case *IPv4:
+ xsum = header.PseudoHeaderChecksum(header.TCPProtocolNumber, *s.SrcAddr, *s.DstAddr, tcpLength)
+ default:
+ // TODO(b/150301488): Support more protocols, like IPv6.
+ return fmt.Errorf("can't get src and dst addr from previous layer")
+ }
+ current = tcp.next()
+ for current != nil {
+ payload, err := current.toBytes()
+ if err != nil {
+ return fmt.Errorf("can't get bytes for next header: %s", payload)
+ }
+ xsum = header.Checksum(payload, xsum)
+ current = current.next()
+ }
+ h.SetChecksum(^h.CalculateChecksum(xsum))
+ return nil
+}
+
+// Uint32 is a helper routine that allocates a new
+// uint32 value to store v and returns a pointer to it.
+func Uint32(v uint32) *uint32 {
+ return &v
+}
+
+// ParseTCP parses the bytes assuming that they start with a tcp header and
+// continues parsing further encapsulations.
+func ParseTCP(b []byte) (Layers, error) {
+ h := header.TCP(b)
+ tcp := TCP{
+ SrcPort: Uint16(h.SourcePort()),
+ DstPort: Uint16(h.DestinationPort()),
+ SeqNum: Uint32(h.SequenceNumber()),
+ AckNum: Uint32(h.AckNumber()),
+ DataOffset: Uint8(h.DataOffset()),
+ Flags: Uint8(h.Flags()),
+ WindowSize: Uint16(h.WindowSize()),
+ Checksum: Uint16(h.Checksum()),
+ UrgentPointer: Uint16(h.UrgentPointer()),
+ }
+ layers := Layers{&tcp}
+ moreLayers, err := ParsePayload(b[tcp.length():])
+ if err != nil {
+ return nil, err
+ }
+ return append(layers, moreLayers...), nil
+}
+
+func (l *TCP) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *TCP) length() int {
+ if l.DataOffset == nil {
+ return header.TCPMinimumSize
+ }
+ return int(*l.DataOffset)
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *TCP) merge(other TCP) error {
+ return mergo.Merge(l, other, mergo.WithOverride)
+}
+
+// Payload has bytes beyond OSI layer 4.
+type Payload struct {
+ LayerBase
+ Bytes []byte
+}
+
+// ParsePayload parses the bytes assuming that they start with a payload and
+// continue to the end. There can be no further encapsulations.
+func ParsePayload(b []byte) (Layers, error) {
+ payload := Payload{
+ Bytes: b,
+ }
+ return Layers{&payload}, nil
+}
+
+func (l *Payload) toBytes() ([]byte, error) {
+ return l.Bytes, nil
+}
+
+func (l *Payload) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *Payload) length() int {
+ return len(l.Bytes)
+}
+
+// Layers is an array of Layer and supports similar functions to Layer.
+type Layers []Layer
+
+func (ls *Layers) toBytes() ([]byte, error) {
+ for i, l := range *ls {
+ if i > 0 {
+ l.setPrev((*ls)[i-1])
+ }
+ if i+1 < len(*ls) {
+ l.setNext((*ls)[i+1])
+ }
+ }
+ outBytes := []byte{}
+ for _, l := range *ls {
+ layerBytes, err := l.toBytes()
+ if err != nil {
+ return nil, err
+ }
+ outBytes = append(outBytes, layerBytes...)
+ }
+ return outBytes, nil
+}
+
+func (ls *Layers) match(other Layers) bool {
+ if len(*ls) > len(other) {
+ return false
+ }
+ for i := 0; i < len(*ls); i++ {
+ if !equalLayer((*ls)[i], other[i]) {
+ return false
+ }
+ }
+ return true
+}
diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go
new file mode 100644
index 000000000..0c7d0f979
--- /dev/null
+++ b/test/packetimpact/testbench/rawsockets.go
@@ -0,0 +1,151 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "encoding/binary"
+ "flag"
+ "math"
+ "net"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+var device = flag.String("device", "", "local device for test packets")
+
+// Sniffer can sniff raw packets on the wire.
+type Sniffer struct {
+ t *testing.T
+ fd int
+}
+
+func htons(x uint16) uint16 {
+ buf := [2]byte{}
+ binary.BigEndian.PutUint16(buf[:], x)
+ return usermem.ByteOrder.Uint16(buf[:])
+}
+
+// NewSniffer creates a Sniffer connected to *device.
+func NewSniffer(t *testing.T) (Sniffer, error) {
+ flag.Parse()
+ snifferFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL)))
+ if err != nil {
+ return Sniffer{}, err
+ }
+ return Sniffer{
+ t: t,
+ fd: snifferFd,
+ }, nil
+}
+
+// maxReadSize should be large enough for the maximum frame size in bytes. If a
+// packet too large for the buffer arrives, the test will get a fatal error.
+const maxReadSize int = 65536
+
+// Recv tries to read one frame until the timeout is up.
+func (s *Sniffer) Recv(timeout time.Duration) []byte {
+ deadline := time.Now().Add(timeout)
+ for {
+ timeout = deadline.Sub(time.Now())
+ if timeout <= 0 {
+ return nil
+ }
+ whole, frac := math.Modf(timeout.Seconds())
+ tv := unix.Timeval{
+ Sec: int64(whole),
+ Usec: int64(frac * float64(time.Microsecond/time.Second)),
+ }
+
+ if err := unix.SetsockoptTimeval(s.fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
+ s.t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err)
+ }
+
+ buf := make([]byte, maxReadSize)
+ nread, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC)
+ if err == unix.EINTR || err == unix.EAGAIN {
+ // There was a timeout.
+ continue
+ }
+ if err != nil {
+ s.t.Fatalf("can't read: %s", err)
+ }
+ if nread > maxReadSize {
+ s.t.Fatalf("received a truncated frame of %d bytes", nread)
+ }
+ return buf[:nread]
+ }
+}
+
+// Close the socket that Sniffer is using.
+func (s *Sniffer) Close() {
+ if err := unix.Close(s.fd); err != nil {
+ s.t.Fatalf("can't close sniffer socket: %s", err)
+ }
+ s.fd = -1
+}
+
+// Injector can inject raw frames.
+type Injector struct {
+ t *testing.T
+ fd int
+}
+
+// NewInjector creates a new injector on *device.
+func NewInjector(t *testing.T) (Injector, error) {
+ flag.Parse()
+ ifInfo, err := net.InterfaceByName(*device)
+ if err != nil {
+ return Injector{}, err
+ }
+
+ var haddr [8]byte
+ copy(haddr[:], ifInfo.HardwareAddr)
+ sa := unix.SockaddrLinklayer{
+ Protocol: unix.ETH_P_IP,
+ Ifindex: ifInfo.Index,
+ Halen: uint8(len(ifInfo.HardwareAddr)),
+ Addr: haddr,
+ }
+
+ injectFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL)))
+ if err != nil {
+ return Injector{}, err
+ }
+ if err := unix.Bind(injectFd, &sa); err != nil {
+ return Injector{}, err
+ }
+ return Injector{
+ t: t,
+ fd: injectFd,
+ }, nil
+}
+
+// Send a raw frame.
+func (i *Injector) Send(b []byte) {
+ if _, err := unix.Write(i.fd, b); err != nil {
+ i.t.Fatalf("can't write: %s", err)
+ }
+}
+
+// Close the underlying socket.
+func (i *Injector) Close() {
+ if err := unix.Close(i.fd); err != nil {
+ i.t.Fatalf("can't close sniffer socket: %s", err)
+ }
+ i.fd = -1
+}
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
new file mode 100644
index 000000000..1dff2a4d5
--- /dev/null
+++ b/test/packetimpact/tests/BUILD
@@ -0,0 +1,21 @@
+load("defs.bzl", "packetimpact_go_test")
+
+package(
+ default_visibility = ["//test/packetimpact:__subpackages__"],
+ licenses = ["notice"],
+)
+
+packetimpact_go_test(
+ name = "fin_wait2_timeout",
+ srcs = ["fin_wait2_timeout_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+sh_binary(
+ name = "test_runner",
+ srcs = ["test_runner.sh"],
+)
diff --git a/test/packetimpact/tests/Dockerfile b/test/packetimpact/tests/Dockerfile
new file mode 100644
index 000000000..507030cc7
--- /dev/null
+++ b/test/packetimpact/tests/Dockerfile
@@ -0,0 +1,5 @@
+FROM ubuntu:bionic
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y iptables netcat tcpdump iproute2 tshark
+RUN hash -r
+CMD /bin/bash
diff --git a/test/packetimpact/tests/defs.bzl b/test/packetimpact/tests/defs.bzl
new file mode 100644
index 000000000..1b4213d9b
--- /dev/null
+++ b/test/packetimpact/tests/defs.bzl
@@ -0,0 +1,106 @@
+"""Defines rules for packetimpact test targets."""
+
+load("//tools:defs.bzl", "go_test")
+
+def _packetimpact_test_impl(ctx):
+ test_runner = ctx.executable._test_runner
+ bench = ctx.actions.declare_file("%s-bench" % ctx.label.name)
+ bench_content = "\n".join([
+ "#!/bin/bash",
+ # This test will run part in a distinct user namespace. This can cause
+ # permission problems, because all runfiles may not be owned by the
+ # current user, and no other users will be mapped in that namespace.
+ # Make sure that everything is readable here.
+ "find . -type f -exec chmod a+rx {} \\;",
+ "find . -type d -exec chmod a+rx {} \\;",
+ "%s %s --posix_server_binary %s --testbench_binary %s $@\n" % (
+ test_runner.short_path,
+ " ".join(ctx.attr.flags),
+ ctx.files._posix_server_binary[0].short_path,
+ ctx.files.testbench_binary[0].short_path,
+ ),
+ ])
+ ctx.actions.write(bench, bench_content, is_executable = True)
+
+ transitive_files = depset()
+ if hasattr(ctx.attr._test_runner, "data_runfiles"):
+ transitive_files = depset(ctx.attr._test_runner.data_runfiles.files)
+ runfiles = ctx.runfiles(
+ files = [test_runner] + ctx.files.testbench_binary + ctx.files._posix_server_binary,
+ transitive_files = transitive_files,
+ collect_default = True,
+ collect_data = True,
+ )
+ return [DefaultInfo(executable = bench, runfiles = runfiles)]
+
+_packetimpact_test = rule(
+ attrs = {
+ "_test_runner": attr.label(
+ executable = True,
+ cfg = "target",
+ default = ":test_runner",
+ ),
+ "_posix_server_binary": attr.label(
+ cfg = "target",
+ default = "//test/packetimpact/dut:posix_server",
+ ),
+ "testbench_binary": attr.label(
+ cfg = "target",
+ mandatory = True,
+ ),
+ "flags": attr.string_list(
+ mandatory = False,
+ default = [],
+ ),
+ },
+ test = True,
+ implementation = _packetimpact_test_impl,
+)
+
+PACKETIMPACT_TAGS = ["local", "manual"]
+
+def packetimpact_linux_test(name, testbench_binary, **kwargs):
+ """Add a packetimpact test on linux.
+
+ Args:
+ name: name of the test
+ testbench_binary: the testbench binary
+ **kwargs: all the other args, forwarded to _packetimpact_test
+ """
+ _packetimpact_test(
+ name = name + "_linux_test",
+ testbench_binary = testbench_binary,
+ flags = ["--dut_platform", "linux"],
+ tags = PACKETIMPACT_TAGS + ["packetimpact"],
+ **kwargs
+ )
+
+def packetimpact_netstack_test(name, testbench_binary, **kwargs):
+ """Add a packetimpact test on netstack.
+
+ Args:
+ name: name of the test
+ testbench_binary: the testbench binary
+ **kwargs: all the other args, forwarded to _packetimpact_test
+ """
+ _packetimpact_test(
+ name = name + "_netstack_test",
+ testbench_binary = testbench_binary,
+ # This is the default runtime unless
+ # "--test_arg=--runtime=OTHER_RUNTIME" is used to override the value.
+ flags = ["--dut_platform", "netstack", "--runtime=runsc-d"],
+ tags = PACKETIMPACT_TAGS + ["packetimpact"],
+ **kwargs
+ )
+
+def packetimpact_go_test(name, size = "small", pure = True, **kwargs):
+ testbench_binary = name + "_test"
+ go_test(
+ name = testbench_binary,
+ size = size,
+ pure = pure,
+ tags = PACKETIMPACT_TAGS,
+ **kwargs
+ )
+ packetimpact_linux_test(name = name, testbench_binary = testbench_binary)
+ packetimpact_netstack_test(name = name, testbench_binary = testbench_binary)
diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go
new file mode 100644
index 000000000..5f54e67ed
--- /dev/null
+++ b/test/packetimpact/tests/fin_wait2_timeout_test.go
@@ -0,0 +1,68 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fin_wait2_timeout_test
+
+import (
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func TestFinWait2Timeout(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ linger2 bool
+ }{
+ {"WithLinger2", true},
+ {"WithoutLinger2", false},
+ } {
+ t.Run(tt.description, func(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFd)
+ conn := tb.NewTCPIPv4(t, dut, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+ conn.Handshake()
+
+ acceptFd, _ := dut.Accept(listenFd)
+ if tt.linger2 {
+ tv := unix.Timeval{Sec: 1, Usec: 0}
+ dut.SetSockOptTimeval(acceptFd, unix.SOL_TCP, unix.TCP_LINGER2, &tv)
+ }
+ dut.Close(acceptFd)
+
+ if gotOne := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); gotOne == nil {
+ t.Fatal("expected a FIN-ACK within 1 second but got none")
+ }
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+
+ time.Sleep(5 * time.Second)
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+ if tt.linger2 {
+ if gotOne := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); gotOne == nil {
+ t.Fatal("expected a RST packet within a second but got none")
+ }
+ } else {
+ if gotOne := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, 10*time.Second); gotOne != nil {
+ t.Fatal("expected no RST packets within ten seconds but got one")
+ }
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/test_runner.sh b/test/packetimpact/tests/test_runner.sh
new file mode 100755
index 000000000..5281cb53d
--- /dev/null
+++ b/test/packetimpact/tests/test_runner.sh
@@ -0,0 +1,246 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Run a packetimpact test. Two docker containers are made, one for the
+# Device-Under-Test (DUT) and one for the test bench. Each is attached with
+# two networks, one for control packets that aid the test and one for test
+# packets which are sent as part of the test and observed for correctness.
+
+set -euxo pipefail
+
+function failure() {
+ local lineno=$1
+ local msg=$2
+ local filename="$0"
+ echo "FAIL: $filename:$lineno: $msg"
+}
+trap 'failure ${LINENO} "$BASH_COMMAND"' ERR
+
+declare -r LONGOPTS="dut_platform:,posix_server_binary:,testbench_binary:,runtime:,tshark"
+
+# Don't use declare below so that the error from getopt will end the script.
+PARSED=$(getopt --options "" --longoptions=$LONGOPTS --name "$0" -- "$@")
+
+eval set -- "$PARSED"
+
+while true; do
+ case "$1" in
+ --dut_platform)
+ # Either "linux" or "netstack".
+ declare -r DUT_PLATFORM="$2"
+ shift 2
+ ;;
+ --posix_server_binary)
+ declare -r POSIX_SERVER_BINARY="$2"
+ shift 2
+ ;;
+ --testbench_binary)
+ declare -r TESTBENCH_BINARY="$2"
+ shift 2
+ ;;
+ --runtime)
+ # Not readonly because there might be multiple --runtime arguments and we
+ # want to use just the last one. Only used if --dut_platform is
+ # "netstack".
+ declare RUNTIME="$2"
+ shift 2
+ ;;
+ --tshark)
+ declare -r TSHARK="1"
+ shift 1
+ ;;
+ --)
+ shift
+ break
+ ;;
+ *)
+ echo "Programming error"
+ exit 3
+ esac
+done
+
+# All the other arguments are scripts.
+declare -r scripts="$@"
+
+# Check that the required flags are defined in a way that is safe for "set -u".
+if [[ "${DUT_PLATFORM-}" == "netstack" ]]; then
+ if [[ -z "${RUNTIME-}" ]]; then
+ echo "FAIL: Missing --runtime argument: ${RUNTIME-}"
+ exit 2
+ fi
+ declare -r RUNTIME_ARG="--runtime ${RUNTIME}"
+elif [[ "${DUT_PLATFORM-}" == "linux" ]]; then
+ declare -r RUNTIME_ARG=""
+else
+ echo "FAIL: Bad or missing --dut_platform argument: ${DUT_PLATFORM-}"
+ exit 2
+fi
+if [[ ! -f "${POSIX_SERVER_BINARY-}" ]]; then
+ echo "FAIL: Bad or missing --posix_server_binary: ${POSIX_SERVER-}"
+ exit 2
+fi
+if [[ ! -f "${TESTBENCH_BINARY-}" ]]; then
+ echo "FAIL: Bad or missing --testbench_binary: ${TESTBENCH_BINARY-}"
+ exit 2
+fi
+
+# Variables specific to the control network and interface start with CTRL_.
+# Variables specific to the test network and interface start with TEST_.
+# Variables specific to the DUT start with DUT_.
+# Variables specific to the test bench start with TESTBENCH_.
+# Use random numbers so that test networks don't collide.
+declare -r CTRL_NET="ctrl_net-${RANDOM}${RANDOM}"
+declare -r TEST_NET="test_net-${RANDOM}${RANDOM}"
+# On both DUT and test bench, testing packets are on the eth2 interface.
+declare -r TEST_DEVICE="eth2"
+# Number of bits in the *_NET_PREFIX variables.
+declare -r NET_MASK="24"
+function new_net_prefix() {
+ # Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24.
+ echo "$(shuf -i 192-223 -n 1).$(shuf -i 0-255 -n 1).$(shuf -i 0-255 -n 1)"
+}
+# Last bits of the DUT's IP address.
+declare -r DUT_NET_SUFFIX=".10"
+# Control port.
+declare -r CTRL_PORT="40000"
+# Last bits of the test bench's IP address.
+declare -r TESTBENCH_NET_SUFFIX=".20"
+declare -r TIMEOUT="60"
+declare -r IMAGE_TAG="gcr.io/gvisor-presubmit/packetimpact"
+# Make sure that docker is installed.
+docker --version
+
+function finish {
+ local cleanup_success=1
+ for net in "${CTRL_NET}" "${TEST_NET}"; do
+ # Kill all processes attached to ${net}.
+ for docker_command in "kill" "rm"; do
+ (docker network inspect "${net}" \
+ --format '{{range $key, $value := .Containers}}{{$key}} {{end}}' \
+ | xargs -r docker "${docker_command}") || \
+ cleanup_success=0
+ done
+ # Remove the network.
+ docker network rm "${net}" || \
+ cleanup_success=0
+ done
+
+ if ((!$cleanup_success)); then
+ echo "FAIL: Cleanup command failed"
+ exit 4
+ fi
+}
+trap finish EXIT
+
+# Subnet for control packets between test bench and DUT.
+declare CTRL_NET_PREFIX=$(new_net_prefix)
+while ! docker network create \
+ "--subnet=${CTRL_NET_PREFIX}.0/${NET_MASK}" "${CTRL_NET}"; do
+ sleep 0.1
+ declare CTRL_NET_PREFIX=$(new_net_prefix)
+done
+
+# Subnet for the packets that are part of the test.
+declare TEST_NET_PREFIX=$(new_net_prefix)
+while ! docker network create \
+ "--subnet=${TEST_NET_PREFIX}.0/${NET_MASK}" "${TEST_NET}"; do
+ sleep 0.1
+ declare TEST_NET_PREFIX=$(new_net_prefix)
+done
+
+docker pull "${IMAGE_TAG}"
+
+# Create the DUT container and connect to network.
+DUT=$(docker create ${RUNTIME_ARG} --privileged --rm \
+ --stop-timeout ${TIMEOUT} -it ${IMAGE_TAG})
+docker network connect "${CTRL_NET}" \
+ --ip "${CTRL_NET_PREFIX}${DUT_NET_SUFFIX}" "${DUT}" \
+ || (docker kill ${DUT}; docker rm ${DUT}; false)
+docker network connect "${TEST_NET}" \
+ --ip "${TEST_NET_PREFIX}${DUT_NET_SUFFIX}" "${DUT}" \
+ || (docker kill ${DUT}; docker rm ${DUT}; false)
+docker start "${DUT}"
+
+# Create the test bench container and connect to network.
+TESTBENCH=$(docker create --privileged --rm \
+ --stop-timeout ${TIMEOUT} -it ${IMAGE_TAG})
+docker network connect "${CTRL_NET}" \
+ --ip "${CTRL_NET_PREFIX}${TESTBENCH_NET_SUFFIX}" "${TESTBENCH}" \
+ || (docker kill ${TESTBENCH}; docker rm ${TESTBENCH}; false)
+docker network connect "${TEST_NET}" \
+ --ip "${TEST_NET_PREFIX}${TESTBENCH_NET_SUFFIX}" "${TESTBENCH}" \
+ || (docker kill ${TESTBENCH}; docker rm ${TESTBENCH}; false)
+docker start "${TESTBENCH}"
+
+# Start the posix_server in the DUT.
+declare -r DOCKER_POSIX_SERVER_BINARY="/$(basename ${POSIX_SERVER_BINARY})"
+docker cp -L ${POSIX_SERVER_BINARY} "${DUT}:${DOCKER_POSIX_SERVER_BINARY}"
+
+docker exec -t "${DUT}" \
+ /bin/bash -c "${DOCKER_POSIX_SERVER_BINARY} \
+ --ip ${CTRL_NET_PREFIX}${DUT_NET_SUFFIX} \
+ --port ${CTRL_PORT}" &
+
+# Because the Linux kernel receives the SYN-ACK but didn't send the SYN it will
+# issue a RST. To prevent this IPtables can be used to filter those out.
+docker exec "${TESTBENCH}" \
+ iptables -A INPUT -i ${TEST_DEVICE} -j DROP
+
+# Wait for the DUT server to come up. Attempt to connect to it from the test
+# bench every 100 milliseconds until success.
+while ! docker exec "${TESTBENCH}" \
+ nc -zv "${CTRL_NET_PREFIX}${DUT_NET_SUFFIX}" "${CTRL_PORT}"; do
+ sleep 0.1
+done
+
+declare -r REMOTE_MAC=$(docker exec -t "${DUT}" ip link show \
+ "${TEST_DEVICE}" | tail -1 | cut -d' ' -f6)
+declare -r LOCAL_MAC=$(docker exec -t "${TESTBENCH}" ip link show \
+ "${TEST_DEVICE}" | tail -1 | cut -d' ' -f6)
+
+declare -r DOCKER_TESTBENCH_BINARY="/$(basename ${TESTBENCH_BINARY})"
+docker cp -L "${TESTBENCH_BINARY}" "${TESTBENCH}:${DOCKER_TESTBENCH_BINARY}"
+
+if [[ -z "${TSHARK-}" ]]; then
+ # Run tcpdump in the test bench unbuffered, without dns resolution, just on
+ # the interface with the test packets.
+ docker exec -t "${TESTBENCH}" \
+ tcpdump -S -vvv -U -n -i "${TEST_DEVICE}" net "${TEST_NET_PREFIX}/24" &
+else
+ # Run tshark in the test bench unbuffered, without dns resolution, just on the
+ # interface with the test packets.
+ docker exec -t "${TESTBENCH}" \
+ tshark -V -l -n -i "${TEST_DEVICE}" \
+ host "${TEST_NET_PREFIX}${TESTBENCH_NET_SUFFIX}" &
+fi
+
+# tcpdump and tshark take time to startup
+sleep 3
+
+# Start a packetimpact test on the test bench. The packetimpact test sends and
+# receives packets and also sends POSIX socket commands to the posix_server to
+# be executed on the DUT.
+docker exec -t "${TESTBENCH}" \
+ /bin/bash -c "${DOCKER_TESTBENCH_BINARY} \
+ --posix_server_ip=${CTRL_NET_PREFIX}${DUT_NET_SUFFIX} \
+ --posix_server_port=${CTRL_PORT} \
+ --remote_ipv4=${TEST_NET_PREFIX}${DUT_NET_SUFFIX} \
+ --local_ipv4=${TEST_NET_PREFIX}${TESTBENCH_NET_SUFFIX} \
+ --remote_mac=${REMOTE_MAC} \
+ --local_mac=${LOCAL_MAC} \
+ --device=${TEST_DEVICE}"
+
+echo PASS: No errors.
diff --git a/test/perf/BUILD b/test/perf/BUILD
index 346a28e16..0a0def6a3 100644
--- a/test/perf/BUILD
+++ b/test/perf/BUILD
@@ -30,6 +30,7 @@ syscall_test(
syscall_test(
size = "enormous",
+ tags = ["nogotsan"],
test = "//test/perf/linux:getdents_benchmark",
)
@@ -40,6 +41,7 @@ syscall_test(
syscall_test(
size = "enormous",
+ tags = ["nogotsan"],
test = "//test/perf/linux:gettid_benchmark",
)
diff --git a/test/perf/linux/futex_benchmark.cc b/test/perf/linux/futex_benchmark.cc
index b349d50bf..241f39896 100644
--- a/test/perf/linux/futex_benchmark.cc
+++ b/test/perf/linux/futex_benchmark.cc
@@ -33,24 +33,24 @@ namespace testing {
namespace {
inline int FutexWait(std::atomic<int32_t>* v, int32_t val) {
- return syscall(SYS_futex, v, FUTEX_BITSET_MATCH_ANY, nullptr);
+ return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, val, nullptr);
}
-inline int FutexWaitRelativeTimeout(std::atomic<int32_t>* v, int32_t val,
- const struct timespec* reltime) {
- return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, reltime);
+inline int FutexWaitMonotonicTimeout(std::atomic<int32_t>* v, int32_t val,
+ const struct timespec* timeout) {
+ return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, val, timeout);
}
-inline int FutexWaitAbsoluteTimeout(std::atomic<int32_t>* v, int32_t val,
- const struct timespec* abstime) {
- return syscall(SYS_futex, v, FUTEX_BITSET_MATCH_ANY, abstime);
+inline int FutexWaitMonotonicDeadline(std::atomic<int32_t>* v, int32_t val,
+ const struct timespec* deadline) {
+ return syscall(SYS_futex, v, FUTEX_WAIT_BITSET_PRIVATE, val, deadline,
+ nullptr, FUTEX_BITSET_MATCH_ANY);
}
-inline int FutexWaitBitsetAbsoluteTimeout(std::atomic<int32_t>* v, int32_t val,
- int32_t bits,
- const struct timespec* abstime) {
+inline int FutexWaitRealtimeDeadline(std::atomic<int32_t>* v, int32_t val,
+ const struct timespec* deadline) {
return syscall(SYS_futex, v, FUTEX_WAIT_BITSET_PRIVATE | FUTEX_CLOCK_REALTIME,
- val, abstime, nullptr, bits);
+ val, deadline, nullptr, FUTEX_BITSET_MATCH_ANY);
}
inline int FutexWake(std::atomic<int32_t>* v, int32_t count) {
@@ -62,11 +62,11 @@ void BM_FutexWakeNop(benchmark::State& state) {
std::atomic<int32_t> v(0);
for (auto _ : state) {
- EXPECT_EQ(0, FutexWake(&v, 1));
+ TEST_PCHECK(FutexWake(&v, 1) == 0);
}
}
-BENCHMARK(BM_FutexWakeNop);
+BENCHMARK(BM_FutexWakeNop)->MinTime(5);
// This just uses FUTEX_WAIT on an address whose value has changed, i.e., the
// syscall won't wait.
@@ -74,43 +74,63 @@ void BM_FutexWaitNop(benchmark::State& state) {
std::atomic<int32_t> v(0);
for (auto _ : state) {
- EXPECT_EQ(-EAGAIN, FutexWait(&v, 1));
+ TEST_PCHECK(FutexWait(&v, 1) == -1 && errno == EAGAIN);
}
}
-BENCHMARK(BM_FutexWaitNop);
+BENCHMARK(BM_FutexWaitNop)->MinTime(5);
// This uses FUTEX_WAIT with a timeout on an address whose value never
// changes, such that it always times out. Timeout overhead can be estimated by
// timer overruns for short timeouts.
-void BM_FutexWaitTimeout(benchmark::State& state) {
+void BM_FutexWaitMonotonicTimeout(benchmark::State& state) {
const int timeout_ns = state.range(0);
std::atomic<int32_t> v(0);
auto ts = absl::ToTimespec(absl::Nanoseconds(timeout_ns));
for (auto _ : state) {
- EXPECT_EQ(-ETIMEDOUT, FutexWaitRelativeTimeout(&v, 0, &ts));
+ TEST_PCHECK(FutexWaitMonotonicTimeout(&v, 0, &ts) == -1 &&
+ errno == ETIMEDOUT);
}
}
-BENCHMARK(BM_FutexWaitTimeout)
+BENCHMARK(BM_FutexWaitMonotonicTimeout)
+ ->MinTime(5)
+ ->UseRealTime()
->Arg(1)
->Arg(10)
->Arg(100)
->Arg(1000)
->Arg(10000);
-// This calls FUTEX_WAIT_BITSET with CLOCK_REALTIME.
-void BM_FutexWaitBitset(benchmark::State& state) {
+// This uses FUTEX_WAIT_BITSET with a deadline that is in the past. This allows
+// estimation of the overhead of setting up a timer for a deadline (as opposed
+// to a timeout as specified for FUTEX_WAIT).
+void BM_FutexWaitMonotonicDeadline(benchmark::State& state) {
std::atomic<int32_t> v(0);
- int timeout_ns = state.range(0);
- auto ts = absl::ToTimespec(absl::Nanoseconds(timeout_ns));
+ struct timespec ts = {};
+
for (auto _ : state) {
- EXPECT_EQ(-ETIMEDOUT, FutexWaitBitsetAbsoluteTimeout(&v, 0, 1, &ts));
+ TEST_PCHECK(FutexWaitMonotonicDeadline(&v, 0, &ts) == -1 &&
+ errno == ETIMEDOUT);
}
}
-BENCHMARK(BM_FutexWaitBitset)->Range(0, 100000);
+BENCHMARK(BM_FutexWaitMonotonicDeadline)->MinTime(5);
+
+// This is equivalent to BM_FutexWaitMonotonicDeadline, but uses CLOCK_REALTIME
+// instead of CLOCK_MONOTONIC for the deadline.
+void BM_FutexWaitRealtimeDeadline(benchmark::State& state) {
+ std::atomic<int32_t> v(0);
+ struct timespec ts = {};
+
+ for (auto _ : state) {
+ TEST_PCHECK(FutexWaitRealtimeDeadline(&v, 0, &ts) == -1 &&
+ errno == ETIMEDOUT);
+ }
+}
+
+BENCHMARK(BM_FutexWaitRealtimeDeadline)->MinTime(5);
int64_t GetCurrentMonotonicTimeNanos() {
struct timespec ts;
@@ -130,11 +150,10 @@ void SpinNanos(int64_t delay_ns) {
// Each iteration of FutexRoundtripDelayed involves a thread sending a futex
// wakeup to another thread, which spins for delay_us and then sends a futex
-// wakeup back. The time per iteration is 2* (delay_us + kBeforeWakeDelayNs +
+// wakeup back. The time per iteration is 2 * (delay_us + kBeforeWakeDelayNs +
// futex/scheduling overhead).
void BM_FutexRoundtripDelayed(benchmark::State& state) {
const int delay_us = state.range(0);
-
const int64_t delay_ns = delay_us * 1000;
// Spin for an extra kBeforeWakeDelayNs before invoking FUTEX_WAKE to reduce
// the probability that the wakeup comes before the wait, preventing the wait
@@ -165,83 +184,14 @@ void BM_FutexRoundtripDelayed(benchmark::State& state) {
}
BENCHMARK(BM_FutexRoundtripDelayed)
+ ->MinTime(5)
+ ->UseRealTime()
->Arg(0)
->Arg(10)
->Arg(20)
->Arg(50)
->Arg(100);
-// FutexLock is a simple, dumb futex based lock implementation.
-// It will try to acquire the lock by atomically incrementing the
-// lock word. If it did not increment the lock from 0 to 1, someone
-// else has the lock, so it will FUTEX_WAIT until it is woken in
-// the unlock path.
-class FutexLock {
- public:
- FutexLock() : lock_word_(0) {}
-
- void lock(struct timespec* deadline) {
- int32_t val;
- while ((val = lock_word_.fetch_add(1, std::memory_order_acquire) + 1) !=
- 1) {
- // If we didn't get the lock by incrementing from 0 to 1,
- // do a FUTEX_WAIT with the desired current value set to
- // val. If val is no longer what the atomic increment returned,
- // someone might have set it to 0 so we can try to acquire
- // again.
- int ret = FutexWaitAbsoluteTimeout(&lock_word_, val, deadline);
- if (ret == 0 || ret == -EWOULDBLOCK || ret == -EINTR) {
- continue;
- } else {
- FAIL() << "unexpected FUTEX_WAIT return: " << ret;
- }
- }
- }
-
- void unlock() {
- // Store 0 into the lock word and wake one waiter. We intentionally
- // ignore the return value of the FUTEX_WAKE here, since there may be
- // no waiters to wake anyway.
- lock_word_.store(0, std::memory_order_release);
- (void)FutexWake(&lock_word_, 1);
- }
-
- private:
- std::atomic<int32_t> lock_word_;
-};
-
-FutexLock* test_lock; // Used below.
-
-void FutexContend(benchmark::State& state, int thread_index,
- struct timespec* deadline) {
- int counter = 0;
- if (thread_index == 0) {
- test_lock = new FutexLock();
- }
- for (auto _ : state) {
- test_lock->lock(deadline);
- counter++;
- test_lock->unlock();
- }
- if (thread_index == 0) {
- delete test_lock;
- }
- state.SetItemsProcessed(state.iterations());
-}
-
-void BM_FutexContend(benchmark::State& state) {
- FutexContend(state, state.thread_index, nullptr);
-}
-
-BENCHMARK(BM_FutexContend)->ThreadRange(1, 1024)->UseRealTime();
-
-void BM_FutexDeadlineContend(benchmark::State& state) {
- auto deadline = absl::ToTimespec(absl::Now() + absl::Minutes(10));
- FutexContend(state, state.thread_index, &deadline);
-}
-
-BENCHMARK(BM_FutexDeadlineContend)->ThreadRange(1, 1024)->UseRealTime();
-
} // namespace
} // namespace testing
diff --git a/test/perf/linux/signal_benchmark.cc b/test/perf/linux/signal_benchmark.cc
index a6928df58..cec679191 100644
--- a/test/perf/linux/signal_benchmark.cc
+++ b/test/perf/linux/signal_benchmark.cc
@@ -43,11 +43,13 @@ void BM_FaultSignalFixup(benchmark::State& state) {
// Fault, fault, fault.
for (auto _ : state) {
- register volatile unsigned int* ptr asm("rax");
-
// Trigger the segfault.
- ptr = nullptr;
- *ptr = 0;
+ asm volatile(
+ "movq $0, %%rax\n"
+ "movq $0x77777777, (%%rax)\n"
+ :
+ :
+ : "rax");
}
}
diff --git a/test/root/BUILD b/test/root/BUILD
index 23ce2a70f..ddc9b4955 100644
--- a/test/root/BUILD
+++ b/test/root/BUILD
@@ -16,6 +16,7 @@ go_test(
"crictl_test.go",
"main_test.go",
"oom_score_adj_test.go",
+ "runsc_test.go",
],
data = [
"//runsc",
@@ -37,7 +38,9 @@ go_test(
"//runsc/specutils",
"//runsc/testutil",
"//test/root/testdata",
+ "@com_github_cenkalti_backoff//:go_default_library",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
"@com_github_syndtr_gocapability//capability:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/test/root/runsc_test.go b/test/root/runsc_test.go
new file mode 100644
index 000000000..90373e2db
--- /dev/null
+++ b/test/root/runsc_test.go
@@ -0,0 +1,151 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package root
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/cenkalti/backoff"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/runsc/specutils"
+ "gvisor.dev/gvisor/runsc/testutil"
+)
+
+// TestDoKill checks that when "runsc do..." is killed, the sandbox process is
+// also terminated. This ensures that parent death signal is propagate to the
+// sandbox process correctly.
+func TestDoKill(t *testing.T) {
+ // Make the sandbox process be reparented here when it's killed, so we can
+ // wait for it.
+ if err := unix.Prctl(unix.PR_SET_CHILD_SUBREAPER, 1, 0, 0, 0); err != nil {
+ t.Fatalf("prctl(PR_SET_CHILD_SUBREAPER): %v", err)
+ }
+
+ cmd := exec.Command(specutils.ExePath, "do", "sleep", "10000")
+ buf := &bytes.Buffer{}
+ cmd.Stdout = buf
+ cmd.Stderr = buf
+ cmd.Start()
+
+ var pid int
+ findSandbox := func() error {
+ var err error
+ pid, err = sandboxPid(cmd.Process.Pid)
+ if err != nil {
+ return &backoff.PermanentError{Err: err}
+ }
+ if pid == 0 {
+ return fmt.Errorf("sandbox process not found")
+ }
+ return nil
+ }
+ if err := testutil.Poll(findSandbox, 10*time.Second); err != nil {
+ t.Fatalf("failed to find sandbox: %v", err)
+ }
+ t.Logf("Found sandbox, pid: %d", pid)
+
+ if err := cmd.Process.Kill(); err != nil {
+ t.Fatalf("failed to kill run process: %v", err)
+ }
+ cmd.Wait()
+ t.Logf("Parent process killed (%d). Output: %s", cmd.Process.Pid, buf.String())
+
+ ch := make(chan struct{})
+ go func() {
+ defer func() { ch <- struct{}{} }()
+ t.Logf("Waiting for sandbox process (%d) termination", pid)
+ if _, err := unix.Wait4(pid, nil, 0, nil); err != nil {
+ t.Errorf("error waiting for sandbox process (%d): %v", pid, err)
+ }
+ }()
+ select {
+ case <-ch:
+ // Done
+ case <-time.After(5 * time.Second):
+ t.Fatalf("timeout waiting for sandbox process (%d) to exit", pid)
+ }
+}
+
+// sandboxPid looks for the sandbox process inside the process tree starting
+// from "pid". It returns 0 and no error if no sandbox process is found. It
+// returns error if anything failed.
+func sandboxPid(pid int) (int, error) {
+ cmd := exec.Command("pgrep", "-P", strconv.Itoa(pid))
+ buf := &bytes.Buffer{}
+ cmd.Stdout = buf
+ if err := cmd.Start(); err != nil {
+ return 0, err
+ }
+ ps, err := cmd.Process.Wait()
+ if err != nil {
+ return 0, err
+ }
+ if ps.ExitCode() == 1 {
+ // pgrep returns 1 when no process is found.
+ return 0, nil
+ }
+
+ var children []int
+ for _, line := range strings.Split(buf.String(), "\n") {
+ if len(line) == 0 {
+ continue
+ }
+ child, err := strconv.Atoi(line)
+ if err != nil {
+ return 0, err
+ }
+
+ cmdline, err := ioutil.ReadFile(filepath.Join("/proc", line, "cmdline"))
+ if err != nil {
+ if os.IsNotExist(err) {
+ // Raced with process exit.
+ continue
+ }
+ return 0, err
+ }
+ args := strings.SplitN(string(cmdline), "\x00", 2)
+ if len(args) == 0 {
+ return 0, fmt.Errorf("malformed cmdline file: %q", cmdline)
+ }
+ // The sandbox process has the first argument set to "runsc-sandbox".
+ if args[0] == "runsc-sandbox" {
+ return child, nil
+ }
+
+ children = append(children, child)
+ }
+
+ // Sandbox process wasn't found, try another level down.
+ for _, pid := range children {
+ sand, err := sandboxPid(pid)
+ if err != nil {
+ return 0, err
+ }
+ if sand != 0 {
+ return sand, nil
+ }
+ // Not found, continue the search.
+ }
+ return 0, nil
+}
diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go
index f96e2415e..869169ad5 100644
--- a/test/runner/gtest/gtest.go
+++ b/test/runner/gtest/gtest.go
@@ -66,13 +66,12 @@ func (tc TestCase) Args() []string {
}
if tc.benchmark {
return []string{
- fmt.Sprintf("%s=^$", filterTestFlag),
fmt.Sprintf("%s=^%s$", filterBenchmarkFlag, tc.Name),
+ fmt.Sprintf("%s=", filterTestFlag),
}
}
return []string{
- fmt.Sprintf("%s=^%s$", filterTestFlag, tc.FullName()),
- fmt.Sprintf("%s=^$", filterBenchmarkFlag),
+ fmt.Sprintf("%s=%s", filterTestFlag, tc.FullName()),
}
}
@@ -147,6 +146,8 @@ func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]Tes
return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v\nstderr\n%s", err, exitErr.Stderr)
}
+ out = []byte(strings.Trim(string(out), "\n"))
+
// Parse benchmark output.
for _, line := range strings.Split(string(out), "\n") {
// Strip comments.
diff --git a/test/runner/runner.go b/test/runner/runner.go
index a78ef38e0..0d3742f71 100644
--- a/test/runner/runner.go
+++ b/test/runner/runner.go
@@ -300,6 +300,7 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
// Test spec comes with pre-defined mounts that we don't want. Reset it.
spec.Mounts = nil
+ testTmpDir := "/tmp"
if *useTmpfs {
// Forces '/tmp' to be mounted as tmpfs, otherwise test that rely on
// features only available in gVisor's internal tmpfs may fail.
@@ -325,11 +326,19 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
t.Fatalf("could not chmod temp dir: %v", err)
}
- spec.Mounts = append(spec.Mounts, specs.Mount{
- Type: "bind",
- Destination: "/tmp",
- Source: tmpDir,
- })
+ // "/tmp" is not replaced with a tmpfs mount inside the sandbox
+ // when it's not empty. This ensures that testTmpDir uses gofer
+ // in exclusive mode.
+ testTmpDir = tmpDir
+ if *fileAccess == "shared" {
+ // All external mounts except the root mount are shared.
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Type: "bind",
+ Destination: "/tmp",
+ Source: tmpDir,
+ })
+ testTmpDir = "/tmp"
+ }
}
// Set environment variables that indicate we are
@@ -349,12 +358,8 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
// Set TEST_TMPDIR to /tmp, as some of the syscall tests require it to
// be backed by tmpfs.
- for i, kv := range env {
- if strings.HasPrefix(kv, "TEST_TMPDIR=") {
- env[i] = "TEST_TMPDIR=/tmp"
- break
- }
- }
+ env = filterEnv(env, []string{"TEST_TMPDIR"})
+ env = append(env, fmt.Sprintf("TEST_TMPDIR=%s", testTmpDir))
spec.Process.Env = env
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index 3518e862d..9800a0cdf 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -318,10 +318,14 @@ syscall_test(
test = "//test/syscalls/linux:proc_test",
)
-syscall_test(test = "//test/syscalls/linux:proc_pid_uid_gid_map_test")
-
syscall_test(test = "//test/syscalls/linux:proc_net_test")
+syscall_test(test = "//test/syscalls/linux:proc_pid_oomscore_test")
+
+syscall_test(test = "//test/syscalls/linux:proc_pid_smaps_test")
+
+syscall_test(test = "//test/syscalls/linux:proc_pid_uid_gid_map_test")
+
syscall_test(
size = "medium",
test = "//test/syscalls/linux:pselect_test",
@@ -680,6 +684,11 @@ syscall_test(
syscall_test(test = "//test/syscalls/linux:tuntap_test")
+syscall_test(
+ add_hostinet = True,
+ test = "//test/syscalls/linux:tuntap_hostinet_test",
+)
+
syscall_test(test = "//test/syscalls/linux:udp_bind_test")
syscall_test(
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 704bae17b..636e5db12 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -167,11 +167,6 @@ cc_library(
)
cc_library(
- name = "temp_umask",
- hdrs = ["temp_umask.h"],
-)
-
-cc_library(
name = "unix_domain_socket_test_util",
testonly = 1,
srcs = ["unix_domain_socket_test_util.cc"],
@@ -608,7 +603,10 @@ cc_binary(
cc_binary(
name = "exceptions_test",
testonly = 1,
- srcs = ["exceptions.cc"],
+ srcs = select_arch(
+ amd64 = ["exceptions.cc"],
+ arm64 = [],
+ ),
linkstatic = 1,
deps = [
gtest,
@@ -1140,11 +1138,11 @@ cc_binary(
srcs = ["mkdir.cc"],
linkstatic = 1,
deps = [
- ":temp_umask",
"//test/util:capability_util",
"//test/util:fs_util",
gtest,
"//test/util:temp_path",
+ "//test/util:temp_umask",
"//test/util:test_main",
"//test/util:test_util",
],
@@ -1299,12 +1297,12 @@ cc_binary(
srcs = ["open_create.cc"],
linkstatic = 1,
deps = [
- ":temp_umask",
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
gtest,
"//test/util:temp_path",
+ "//test/util:temp_umask",
"//test/util:test_main",
"//test/util:test_util",
],
@@ -1475,7 +1473,10 @@ cc_binary(
cc_binary(
name = "arch_prctl_test",
testonly = 1,
- srcs = ["arch_prctl.cc"],
+ srcs = select_arch(
+ amd64 = ["arch_prctl.cc"],
+ arm64 = [],
+ ),
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
@@ -1631,6 +1632,19 @@ cc_binary(
)
cc_binary(
+ name = "proc_pid_oomscore_test",
+ testonly = 1,
+ srcs = ["proc_pid_oomscore.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:fs_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_binary(
name = "proc_pid_smaps_test",
testonly = 1,
srcs = ["proc_pid_smaps.cc"],
@@ -3322,7 +3336,10 @@ cc_binary(
cc_binary(
name = "sysret_test",
testonly = 1,
- srcs = ["sysret.cc"],
+ srcs = select_arch(
+ amd64 = ["sysret.cc"],
+ arm64 = [],
+ ),
linkstatic = 1,
deps = [
gtest,
@@ -3460,6 +3477,18 @@ cc_binary(
],
)
+cc_binary(
+ name = "tuntap_hostinet_test",
+ testonly = 1,
+ srcs = ["tuntap_hostinet.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
cc_library(
name = "udp_socket_test_cases",
testonly = 1,
@@ -3678,11 +3707,10 @@ cc_binary(
":socket_test_util",
gtest,
"//test/util:capability_util",
- "//test/util:memory_util",
+ "//test/util:posix_error",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/synchronization",
],
)
diff --git a/test/syscalls/linux/bad.cc b/test/syscalls/linux/bad.cc
index adfb149df..a26fc6af3 100644
--- a/test/syscalls/linux/bad.cc
+++ b/test/syscalls/linux/bad.cc
@@ -28,7 +28,7 @@ namespace {
constexpr uint32_t kNotImplementedSyscall = SYS_get_kernel_syms;
#elif __aarch64__
// Use the last of arch_specific_syscalls which are not implemented on arm64.
-constexpr uint32_t kNotImplementedSyscall = SYS_arch_specific_syscall + 15;
+constexpr uint32_t kNotImplementedSyscall = __NR_arch_specific_syscall + 15;
#endif
TEST(BadSyscallTest, NotImplemented) {
diff --git a/test/syscalls/linux/dev.cc b/test/syscalls/linux/dev.cc
index 4e473268c..4dd302eed 100644
--- a/test/syscalls/linux/dev.cc
+++ b/test/syscalls/linux/dev.cc
@@ -153,13 +153,6 @@ TEST(DevTest, TTYExists) {
EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666);
}
-TEST(DevTest, NetTunExists) {
- struct stat statbuf = {};
- ASSERT_THAT(stat("/dev/net/tun", &statbuf), SyscallSucceeds());
- // Check that it's a character device with rw-rw-rw- permissions.
- EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666);
-}
-
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/mkdir.cc b/test/syscalls/linux/mkdir.cc
index cf138d328..def4c50a4 100644
--- a/test/syscalls/linux/mkdir.cc
+++ b/test/syscalls/linux/mkdir.cc
@@ -18,10 +18,10 @@
#include <unistd.h>
#include "gtest/gtest.h"
-#include "test/syscalls/linux/temp_umask.h"
#include "test/util/capability_util.h"
#include "test/util/fs_util.h"
#include "test/util/temp_path.h"
+#include "test/util/temp_umask.h"
#include "test/util/test_util.h"
namespace gvisor {
diff --git a/test/syscalls/linux/network_namespace.cc b/test/syscalls/linux/network_namespace.cc
index 6ea48c263..133fdecf0 100644
--- a/test/syscalls/linux/network_namespace.cc
+++ b/test/syscalls/linux/network_namespace.cc
@@ -20,102 +20,33 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
-#include "absl/synchronization/notification.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/capability_util.h"
-#include "test/util/memory_util.h"
+#include "test/util/posix_error.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
namespace gvisor {
namespace testing {
-
namespace {
-using TestFunc = std::function<PosixError()>;
-using RunFunc = std::function<PosixError(TestFunc)>;
-
-struct NamespaceStrategy {
- RunFunc run;
-
- static NamespaceStrategy Of(RunFunc run) {
- NamespaceStrategy s;
- s.run = run;
- return s;
- }
-};
+TEST(NetworkNamespaceTest, LoopbackExists) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
-PosixError RunWithUnshare(TestFunc fn) {
- PosixError err = PosixError(-1, "function did not return a value");
ScopedThread t([&] {
- if (unshare(CLONE_NEWNET) != 0) {
- err = PosixError(errno);
- return;
- }
- err = fn();
- });
- t.Join();
- return err;
-}
+ ASSERT_THAT(unshare(CLONE_NEWNET), SyscallSucceedsWithValue(0));
-PosixError RunWithClone(TestFunc fn) {
- struct Args {
- absl::Notification n;
- TestFunc fn;
- PosixError err;
- };
- Args args;
- args.fn = fn;
- args.err = PosixError(-1, "function did not return a value");
-
- ASSIGN_OR_RETURN_ERRNO(
- Mapping child_stack,
- MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
- pid_t child = clone(
- +[](void *arg) {
- Args *args = reinterpret_cast<Args *>(arg);
- args->err = args->fn();
- args->n.Notify();
- syscall(SYS_exit, 0); // Exit manually. No return address on stack.
- return 0;
- },
- reinterpret_cast<void *>(child_stack.addr() + kPageSize),
- CLONE_NEWNET | CLONE_THREAD | CLONE_SIGHAND | CLONE_VM, &args);
- if (child < 0) {
- return PosixError(errno, "clone() failed");
- }
- args.n.WaitForNotification();
- return args.err;
-}
-
-class NetworkNamespaceTest
- : public ::testing::TestWithParam<NamespaceStrategy> {};
-
-TEST_P(NetworkNamespaceTest, LoopbackExists) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
-
- EXPECT_NO_ERRNO(GetParam().run([]() {
// TODO(gvisor.dev/issue/1833): Update this to test that only "lo" exists.
// Check loopback device exists.
int sock = socket(AF_INET, SOCK_DGRAM, 0);
- if (sock < 0) {
- return PosixError(errno, "socket() failed");
- }
+ ASSERT_THAT(sock, SyscallSucceeds());
struct ifreq ifr;
- snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
- if (ioctl(sock, SIOCGIFINDEX, &ifr) < 0) {
- return PosixError(errno, "ioctl() failed, lo cannot be found");
- }
- return NoError();
- }));
+ strncpy(ifr.ifr_name, "lo", IFNAMSIZ);
+ EXPECT_THAT(ioctl(sock, SIOCGIFINDEX, &ifr), SyscallSucceeds())
+ << "lo cannot be found";
+ });
}
-INSTANTIATE_TEST_SUITE_P(
- AllNetworkNamespaceTest, NetworkNamespaceTest,
- ::testing::Values(NamespaceStrategy::Of(RunWithUnshare),
- NamespaceStrategy::Of(RunWithClone)));
-
} // namespace
-
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc
index 902d0a0dc..51eacf3f2 100644
--- a/test/syscalls/linux/open_create.cc
+++ b/test/syscalls/linux/open_create.cc
@@ -19,11 +19,11 @@
#include <unistd.h>
#include "gtest/gtest.h"
-#include "test/syscalls/linux/temp_umask.h"
#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
#include "test/util/temp_path.h"
+#include "test/util/temp_umask.h"
#include "test/util/test_util.h"
namespace gvisor {
diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc
index 92ae55eec..5ac68feb4 100644
--- a/test/syscalls/linux/packet_socket.cc
+++ b/test/syscalls/linux/packet_socket.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <arpa/inet.h>
+#include <ifaddrs.h>
#include <linux/capability.h>
#include <linux/if_arp.h>
#include <linux/if_packet.h>
@@ -163,16 +164,11 @@ int CookedPacketTest::GetLoopbackIndex() {
return ifr.ifr_ifindex;
}
-// Receive via a packet socket.
-TEST_P(CookedPacketTest, Receive) {
- // Let's use a simple IP payload: a UDP datagram.
- FileDescriptor udp_sock =
- ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
- SendUDPMessage(udp_sock.get());
-
+// Receive and verify the message via packet socket on interface.
+void ReceiveMessage(int sock, int ifindex) {
// Wait for the socket to become readable.
struct pollfd pfd = {};
- pfd.fd = socket_;
+ pfd.fd = sock;
pfd.events = POLLIN;
EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1));
@@ -182,9 +178,10 @@ TEST_P(CookedPacketTest, Receive) {
char buf[64];
struct sockaddr_ll src = {};
socklen_t src_len = sizeof(src);
- ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0,
+ ASSERT_THAT(recvfrom(sock, buf, sizeof(buf), 0,
reinterpret_cast<struct sockaddr*>(&src), &src_len),
SyscallSucceedsWithValue(packet_size));
+
// sockaddr_ll ends with an 8 byte physical address field, but ethernet
// addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2
// here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns
@@ -194,7 +191,7 @@ TEST_P(CookedPacketTest, Receive) {
// TODO(b/129292371): Verify protocol once we return it.
// Verify the source address.
EXPECT_EQ(src.sll_family, AF_PACKET);
- EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex());
+ EXPECT_EQ(src.sll_ifindex, ifindex);
EXPECT_EQ(src.sll_halen, ETH_ALEN);
// This came from the loopback device, so the address is all 0s.
for (int i = 0; i < src.sll_halen; i++) {
@@ -222,6 +219,18 @@ TEST_P(CookedPacketTest, Receive) {
EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0);
}
+// Receive via a packet socket.
+TEST_P(CookedPacketTest, Receive) {
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ SendUDPMessage(udp_sock.get());
+
+ // Receive and verify the data.
+ int loopback_index = GetLoopbackIndex();
+ ReceiveMessage(socket_, loopback_index);
+}
+
// Send via a packet socket.
TEST_P(CookedPacketTest, Send) {
// TODO(b/129292371): Remove once we support packet socket writing.
@@ -313,6 +322,115 @@ TEST_P(CookedPacketTest, Send) {
EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK));
}
+// Bind and receive via packet socket.
+TEST_P(CookedPacketTest, BindReceive) {
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = GetLoopbackIndex();
+
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ SendUDPMessage(udp_sock.get());
+
+ // Receive and verify the data.
+ ReceiveMessage(socket_, bind_addr.sll_ifindex);
+}
+
+// Double Bind socket.
+TEST_P(CookedPacketTest, DoubleBind) {
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = GetLoopbackIndex();
+
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Binding socket again should fail.
+ ASSERT_THAT(
+ bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ // Linux 4.09 returns EINVAL here, but some time before 4.19 it switched
+ // to EADDRINUSE.
+ AnyOf(SyscallFailsWithErrno(EADDRINUSE), SyscallFailsWithErrno(EINVAL)));
+}
+
+// Bind and verify we do not receive data on interface which is not bound
+TEST_P(CookedPacketTest, BindDrop) {
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+
+ struct ifaddrs* if_addr_list = nullptr;
+ auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); });
+
+ ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds());
+
+ // Get interface other than loopback.
+ struct ifreq ifr = {};
+ for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) {
+ if (strcmp(i->ifa_name, "lo") != 0) {
+ strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name));
+ break;
+ }
+ }
+
+ // Skip if no interface is available other than loopback.
+ if (strlen(ifr.ifr_name) == 0) {
+ GTEST_SKIP();
+ }
+
+ // Get interface index.
+ EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds());
+ EXPECT_NE(ifr.ifr_ifindex, 0);
+
+ // Bind to packet socket requires only family, protocol and ifindex.
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = ifr.ifr_ifindex;
+
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Send to loopback interface.
+ struct sockaddr_in dest = {};
+ dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ dest.sin_family = AF_INET;
+ dest.sin_port = kPort;
+ EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0,
+ reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
+ SyscallSucceedsWithValue(sizeof(kMessage)));
+
+ // Wait and make sure the socket never receives any data.
+ struct pollfd pfd = {};
+ pfd.fd = socket_;
+ pfd.events = POLLIN;
+ EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0));
+}
+
+// Bind with invalid address.
+TEST_P(CookedPacketTest, BindFail) {
+ // Null address.
+ ASSERT_THAT(
+ bind(socket_, nullptr, sizeof(struct sockaddr)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallFailsWithErrno(EINVAL)));
+
+ // Address of size 1.
+ uint8_t addr = 0;
+ ASSERT_THAT(
+ bind(socket_, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, CookedPacketTest,
::testing::Values(ETH_P_IP, ETH_P_ALL));
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index f91187e75..5a70f6c3b 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -1431,6 +1431,12 @@ TEST(ProcPidFile, SubprocessRunning) {
EXPECT_THAT(ReadWhileRunning("uid_map", buf, sizeof(buf)),
SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileRunning("oom_score", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileRunning("oom_score_adj", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
}
// Test whether /proc/PID/ files can be read for a zombie process.
@@ -1466,6 +1472,12 @@ TEST(ProcPidFile, SubprocessZombie) {
EXPECT_THAT(ReadWhileZombied("uid_map", buf, sizeof(buf)),
SyscallSucceedsWithValue(sizeof(buf)));
+ EXPECT_THAT(ReadWhileZombied("oom_score", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileZombied("oom_score_adj", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
// FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux
// on proc files.
//
@@ -1527,6 +1539,15 @@ TEST(ProcPidFile, SubprocessExited) {
EXPECT_THAT(ReadWhileExited("uid_map", buf, sizeof(buf)),
SyscallSucceedsWithValue(sizeof(buf)));
+
+ if (!IsRunningOnGvisor()) {
+ // FIXME(gvisor.dev/issue/164): Succeeds on gVisor.
+ EXPECT_THAT(ReadWhileExited("oom_score", buf, sizeof(buf)),
+ SyscallFailsWithErrno(ESRCH));
+ }
+
+ EXPECT_THAT(ReadWhileExited("oom_score_adj", buf, sizeof(buf)),
+ SyscallFailsWithErrno(ESRCH));
}
PosixError DirContainsImpl(absl::string_view path,
diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc
index 3a611a86f..4e23d1e78 100644
--- a/test/syscalls/linux/proc_net.cc
+++ b/test/syscalls/linux/proc_net.cc
@@ -33,6 +33,31 @@ namespace gvisor {
namespace testing {
namespace {
+constexpr const char kProcNet[] = "/proc/net";
+
+TEST(ProcNetSymlinkTarget, FileMode) {
+ struct stat s;
+ ASSERT_THAT(stat(kProcNet, &s), SyscallSucceeds());
+ EXPECT_EQ(s.st_mode & S_IFMT, S_IFDIR);
+ EXPECT_EQ(s.st_mode & 0777, 0555);
+}
+
+TEST(ProcNetSymlink, FileMode) {
+ struct stat s;
+ ASSERT_THAT(lstat(kProcNet, &s), SyscallSucceeds());
+ EXPECT_EQ(s.st_mode & S_IFMT, S_IFLNK);
+ EXPECT_EQ(s.st_mode & 0777, 0777);
+}
+
+TEST(ProcNetSymlink, Contents) {
+ char buf[40] = {};
+ int n = readlink(kProcNet, buf, sizeof(buf));
+ ASSERT_THAT(n, SyscallSucceeds());
+
+ buf[n] = 0;
+ EXPECT_STREQ(buf, "self/net");
+}
+
TEST(ProcNetIfInet6, Format) {
auto ifinet6 = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/if_inet6"));
EXPECT_THAT(ifinet6,
@@ -67,6 +92,59 @@ TEST(ProcSysNetIpv4Sack, CanReadAndWrite) {
EXPECT_EQ(buf, to_write);
}
+// DeviceEntry is an entry in /proc/net/dev
+struct DeviceEntry {
+ std::string name;
+ uint64_t stats[16];
+};
+
+PosixErrorOr<std::vector<DeviceEntry>> GetDeviceMetricsFromProc(
+ const std::string dev) {
+ std::vector<std::string> lines = absl::StrSplit(dev, '\n');
+ std::vector<DeviceEntry> entries;
+
+ // /proc/net/dev prints 2 lines of headers followed by a line of metrics for
+ // each network interface.
+ for (unsigned i = 2; i < lines.size(); i++) {
+ // Ignore empty lines.
+ if (lines[i].empty()) {
+ continue;
+ }
+
+ std::vector<std::string> values =
+ absl::StrSplit(lines[i], ' ', absl::SkipWhitespace());
+
+ // Interface name + 16 values.
+ if (values.size() != 17) {
+ return PosixError(EINVAL, "invalid line: " + lines[i]);
+ }
+
+ DeviceEntry entry;
+ entry.name = values[0];
+ // Skip the interface name and read only the values.
+ for (unsigned j = 1; j < 17; j++) {
+ uint64_t num;
+ if (!absl::SimpleAtoi(values[j], &num)) {
+ return PosixError(EINVAL, "invalid value: " + values[j]);
+ }
+ entry.stats[j - 1] = num;
+ }
+
+ entries.push_back(entry);
+ }
+
+ return entries;
+}
+
+// TEST(ProcNetDev, Format) tests that /proc/net/dev is parsable and
+// contains at least one entry.
+TEST(ProcNetDev, Format) {
+ auto dev = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/dev"));
+ auto entries = ASSERT_NO_ERRNO_AND_VALUE(GetDeviceMetricsFromProc(dev));
+
+ EXPECT_GT(entries.size(), 0);
+}
+
PosixErrorOr<uint64_t> GetSNMPMetricFromProc(const std::string snmp,
const std::string& type,
const std::string& item) {
diff --git a/test/syscalls/linux/proc_pid_oomscore.cc b/test/syscalls/linux/proc_pid_oomscore.cc
new file mode 100644
index 000000000..707821a3f
--- /dev/null
+++ b/test/syscalls/linux/proc_pid_oomscore.cc
@@ -0,0 +1,72 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+
+#include <exception>
+#include <iostream>
+#include <string>
+
+#include "test/util/fs_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+PosixErrorOr<int> ReadProcNumber(std::string path) {
+ ASSIGN_OR_RETURN_ERRNO(std::string contents, GetContents(path));
+ EXPECT_EQ(contents[contents.length() - 1], '\n');
+
+ int num;
+ if (!absl::SimpleAtoi(contents, &num)) {
+ return PosixError(EINVAL, "invalid value: " + contents);
+ }
+
+ return num;
+}
+
+TEST(ProcPidOomscoreTest, BasicRead) {
+ auto const oom_score =
+ ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score"));
+ EXPECT_LE(oom_score, 1000);
+ EXPECT_GE(oom_score, -1000);
+}
+
+TEST(ProcPidOomscoreAdjTest, BasicRead) {
+ auto const oom_score =
+ ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score_adj"));
+
+ // oom_score_adj defaults to 0.
+ EXPECT_EQ(oom_score, 0);
+}
+
+TEST(ProcPidOomscoreAdjTest, BasicWrite) {
+ constexpr int test_value = 7;
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/oom_score_adj", O_WRONLY));
+ ASSERT_THAT(
+ RetryEINTR(write)(fd.get(), std::to_string(test_value).c_str(), 1),
+ SyscallSucceeds());
+
+ auto const oom_score =
+ ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score_adj"));
+ EXPECT_EQ(oom_score, test_value);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/seccomp.cc b/test/syscalls/linux/seccomp.cc
index cf6499f8b..ce88d90dd 100644
--- a/test/syscalls/linux/seccomp.cc
+++ b/test/syscalls/linux/seccomp.cc
@@ -53,7 +53,7 @@ namespace {
constexpr uint32_t kFilteredSyscall = SYS_vserver;
#elif __aarch64__
// Use the last of arch_specific_syscalls which are not implemented on arm64.
-constexpr uint32_t kFilteredSyscall = SYS_arch_specific_syscall + 15;
+constexpr uint32_t kFilteredSyscall = __NR_arch_specific_syscall + 15;
#endif
// Applies a seccomp-bpf filter that returns `filtered_result` for
@@ -70,20 +70,27 @@ void ApplySeccompFilter(uint32_t sysno, uint32_t filtered_result,
MaybeSave();
struct sock_filter filter[] = {
- // A = seccomp_data.arch
- BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 4),
- // if (A != AUDIT_ARCH_X86_64) goto kill
- BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, AUDIT_ARCH_X86_64, 0, 4),
- // A = seccomp_data.nr
- BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 0),
- // if (A != sysno) goto allow
- BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, sysno, 0, 1),
- // return filtered_result
- BPF_STMT(BPF_RET | BPF_K, filtered_result),
- // allow: return SECCOMP_RET_ALLOW
- BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_ALLOW),
- // kill: return SECCOMP_RET_KILL
- BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_KILL),
+ // A = seccomp_data.arch
+ BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 4),
+#if defined(__x86_64__)
+ // if (A != AUDIT_ARCH_X86_64) goto kill
+ BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, AUDIT_ARCH_X86_64, 0, 4),
+#elif defined(__aarch64__)
+ // if (A != AUDIT_ARCH_AARCH64) goto kill
+ BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, AUDIT_ARCH_AARCH64, 0, 4),
+#else
+#error "Unknown architecture"
+#endif
+ // A = seccomp_data.nr
+ BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 0),
+ // if (A != sysno) goto allow
+ BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, sysno, 0, 1),
+ // return filtered_result
+ BPF_STMT(BPF_RET | BPF_K, filtered_result),
+ // allow: return SECCOMP_RET_ALLOW
+ BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_ALLOW),
+ // kill: return SECCOMP_RET_KILL
+ BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_KILL),
};
struct sock_fprog prog;
prog.len = ABSL_ARRAYSIZE(filter);
@@ -179,9 +186,12 @@ TEST(SeccompTest, RetTrapCausesSIGSYS) {
TEST_CHECK(info->si_errno == kTrapValue);
TEST_CHECK(info->si_call_addr != nullptr);
TEST_CHECK(info->si_syscall == kFilteredSyscall);
-#ifdef __x86_64__
+#if defined(__x86_64__)
TEST_CHECK(info->si_arch == AUDIT_ARCH_X86_64);
TEST_CHECK(uc->uc_mcontext.gregs[REG_RAX] == kFilteredSyscall);
+#elif defined(__aarch64__)
+ TEST_CHECK(info->si_arch == AUDIT_ARCH_AARCH64);
+ TEST_CHECK(uc->uc_mcontext.regs[8] == kFilteredSyscall);
#endif // defined(__x86_64__)
_exit(0);
});
diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc
index e8f24a59e..f7d6139f1 100644
--- a/test/syscalls/linux/socket_generic.cc
+++ b/test/syscalls/linux/socket_generic.cc
@@ -447,6 +447,60 @@ TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgSucceeds) {
SyscallFailsWithErrno(EAGAIN));
}
+TEST_P(AllSocketPairTest, SendTimeoutDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ timeval actual_tv = {.tv_sec = -1, .tv_usec = -1};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv_usec, 0);
+}
+
+TEST_P(AllSocketPairTest, SetGetSendTimeout) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ timeval tv = {.tv_sec = 89, .tv_usec = 42000};
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ timeval actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 89);
+ EXPECT_EQ(actual_tv.tv_usec, 42000);
+}
+
+TEST_P(AllSocketPairTest, SetGetSendTimeoutLargerArg) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval_with_extra {
+ struct timeval tv;
+ int64_t extra_data;
+ } ABSL_ATTRIBUTE_PACKED;
+
+ timeval_with_extra tv_extra = {
+ .tv = {.tv_sec = 0, .tv_usec = 123000},
+ };
+
+ EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &tv_extra, sizeof(tv_extra)),
+ SyscallSucceeds());
+
+ timeval_with_extra actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv.tv_usec, 123000);
+}
+
TEST_P(AllSocketPairTest, SendTimeoutAllowsWrite) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
@@ -491,18 +545,36 @@ TEST_P(AllSocketPairTest, SendTimeoutAllowsSendmsg) {
ASSERT_NO_FATAL_FAILURE(SendNullCmsg(sockets->first_fd(), buf, sizeof(buf)));
}
-TEST_P(AllSocketPairTest, SoRcvTimeoIsSet) {
+TEST_P(AllSocketPairTest, RecvTimeoutDefault) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
- struct timeval tv {
- .tv_sec = 0, .tv_usec = 35
- };
+ timeval actual_tv = {.tv_sec = -1, .tv_usec = -1};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv_usec, 0);
+}
+
+TEST_P(AllSocketPairTest, SetGetRecvTimeout) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ timeval tv = {.tv_sec = 123, .tv_usec = 456000};
EXPECT_THAT(
setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
SyscallSucceeds());
+
+ timeval actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 123);
+ EXPECT_EQ(actual_tv.tv_usec, 456000);
}
-TEST_P(AllSocketPairTest, SoRcvTimeoIsSetLargerArg) {
+TEST_P(AllSocketPairTest, SetGetRecvTimeoutLargerArg) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
struct timeval_with_extra {
@@ -510,13 +582,21 @@ TEST_P(AllSocketPairTest, SoRcvTimeoIsSetLargerArg) {
int64_t extra_data;
} ABSL_ATTRIBUTE_PACKED;
- timeval_with_extra tv_extra;
- tv_extra.tv.tv_sec = 0;
- tv_extra.tv.tv_usec = 25;
+ timeval_with_extra tv_extra = {
+ .tv = {.tv_sec = 0, .tv_usec = 432000},
+ };
EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
&tv_extra, sizeof(tv_extra)),
SyscallSucceeds());
+
+ timeval_with_extra actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv.tv_usec, 432000);
}
TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgOneSecondSucceeds) {
diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc
index c951ac3b3..2503960f3 100644
--- a/test/syscalls/linux/stat.cc
+++ b/test/syscalls/linux/stat.cc
@@ -34,6 +34,13 @@
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
+#ifndef AT_STATX_FORCE_SYNC
+#define AT_STATX_FORCE_SYNC 0x2000
+#endif
+#ifndef AT_STATX_DONT_SYNC
+#define AT_STATX_DONT_SYNC 0x4000
+#endif
+
namespace gvisor {
namespace testing {
@@ -607,7 +614,7 @@ int statx(int dirfd, const char* pathname, int flags, unsigned int mask,
}
TEST_F(StatTest, StatxAbsPath) {
- SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 &&
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
errno == ENOSYS);
struct kernel_statx stx;
@@ -617,7 +624,7 @@ TEST_F(StatTest, StatxAbsPath) {
}
TEST_F(StatTest, StatxRelPathDirFD) {
- SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 &&
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
errno == ENOSYS);
struct kernel_statx stx;
@@ -631,7 +638,7 @@ TEST_F(StatTest, StatxRelPathDirFD) {
}
TEST_F(StatTest, StatxRelPathCwd) {
- SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 &&
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
errno == ENOSYS);
ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds());
@@ -643,7 +650,7 @@ TEST_F(StatTest, StatxRelPathCwd) {
}
TEST_F(StatTest, StatxEmptyPath) {
- SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 &&
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
errno == ENOSYS);
const auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY));
@@ -653,6 +660,60 @@ TEST_F(StatTest, StatxEmptyPath) {
EXPECT_TRUE(S_ISREG(stx.stx_mode));
}
+TEST_F(StatTest, StatxDoesNotRejectExtraneousMaskBits) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ struct kernel_statx stx;
+ // Set all mask bits except for STATX__RESERVED.
+ uint mask = 0xffffffff & ~0x80000000;
+ EXPECT_THAT(statx(-1, test_file_name_.c_str(), 0, mask, &stx),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISREG(stx.stx_mode));
+}
+
+TEST_F(StatTest, StatxRejectsReservedMaskBit) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ struct kernel_statx stx;
+ // Set STATX__RESERVED in the mask.
+ EXPECT_THAT(statx(-1, test_file_name_.c_str(), 0, 0x80000000, &stx),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(StatTest, StatxSymlink) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ std::string parent_dir = "/tmp";
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(parent_dir, test_file_name_));
+ std::string p = link.path();
+
+ struct kernel_statx stx;
+ EXPECT_THAT(statx(AT_FDCWD, p.c_str(), AT_SYMLINK_NOFOLLOW, STATX_ALL, &stx),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISLNK(stx.stx_mode));
+ EXPECT_THAT(statx(AT_FDCWD, p.c_str(), 0, STATX_ALL, &stx),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISREG(stx.stx_mode));
+}
+
+TEST_F(StatTest, StatxInvalidFlags) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ struct kernel_statx stx;
+ EXPECT_THAT(statx(AT_FDCWD, test_file_name_.c_str(), 12345, 0, &stx),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Sync flags are mutually exclusive.
+ EXPECT_THAT(statx(AT_FDCWD, test_file_name_.c_str(),
+ AT_STATX_FORCE_SYNC | AT_STATX_DONT_SYNC, 0, &stx),
+ SyscallFailsWithErrno(EINVAL));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/sticky.cc b/test/syscalls/linux/sticky.cc
index 7e73325bf..92eec0449 100644
--- a/test/syscalls/linux/sticky.cc
+++ b/test/syscalls/linux/sticky.cc
@@ -42,8 +42,9 @@ TEST(StickyTest, StickyBitPermDenied) {
auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
EXPECT_THAT(chmod(dir.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds());
- std::string path = JoinPath(dir.path(), "NewDir");
- ASSERT_THAT(mkdir(path.c_str(), 0755), SyscallSucceeds());
+ const FileDescriptor dirfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY));
+ ASSERT_THAT(mkdirat(dirfd.get(), "NewDir", 0755), SyscallSucceeds());
// Drop privileges and change IDs only in child thread, or else this parent
// thread won't be able to open some log files after the test ends.
@@ -61,7 +62,8 @@ TEST(StickyTest, StickyBitPermDenied) {
syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid), -1),
SyscallSucceeds());
- EXPECT_THAT(rmdir(path.c_str()), SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(unlinkat(dirfd.get(), "NewDir", AT_REMOVEDIR),
+ SyscallFailsWithErrno(EPERM));
});
}
@@ -96,8 +98,9 @@ TEST(StickyTest, StickyBitCapFOWNER) {
auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
EXPECT_THAT(chmod(dir.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds());
- std::string path = JoinPath(dir.path(), "NewDir");
- ASSERT_THAT(mkdir(path.c_str(), 0755), SyscallSucceeds());
+ const FileDescriptor dirfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY));
+ ASSERT_THAT(mkdirat(dirfd.get(), "NewDir", 0755), SyscallSucceeds());
// Drop privileges and change IDs only in child thread, or else this parent
// thread won't be able to open some log files after the test ends.
@@ -114,7 +117,8 @@ TEST(StickyTest, StickyBitCapFOWNER) {
SyscallSucceeds());
EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, true));
- EXPECT_THAT(rmdir(path.c_str()), SyscallSucceeds());
+ EXPECT_THAT(unlinkat(dirfd.get(), "NewDir", AT_REMOVEDIR),
+ SyscallSucceeds());
});
}
} // namespace
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index c4591a3b9..d9c1ac0e1 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -143,6 +143,20 @@ TEST_P(TcpSocketTest, ConnectOnEstablishedConnection) {
SyscallFailsWithErrno(EISCONN));
}
+TEST_P(TcpSocketTest, ShutdownWriteInTimeWait) {
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds());
+ EXPECT_THAT(shutdown(s_, SHUT_RDWR), SyscallSucceeds());
+ absl::SleepFor(absl::Seconds(1)); // Wait to enter TIME_WAIT.
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
+}
+
+TEST_P(TcpSocketTest, ShutdownWriteInFinWait1) {
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds());
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds());
+ absl::SleepFor(absl::Seconds(1)); // Wait to enter FIN-WAIT2.
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds());
+}
+
TEST_P(TcpSocketTest, DataCoalesced) {
char buf[10];
@@ -1349,6 +1363,21 @@ TEST_P(SimpleTcpSocketTest, RecvOnClosedSocket) {
SyscallFailsWithErrno(ENOTCONN));
}
+TEST_P(SimpleTcpSocketTest, TCPConnectSoRcvBufRace) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ RetryEINTR(connect)(s.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ addrlen);
+ int buf_sz = 1 << 18;
+ EXPECT_THAT(
+ setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)),
+ SyscallSucceedsWithValue(0));
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest,
::testing::Values(AF_INET, AF_INET6));
diff --git a/test/syscalls/linux/time.cc b/test/syscalls/linux/time.cc
index 1ccb95733..e75bba669 100644
--- a/test/syscalls/linux/time.cc
+++ b/test/syscalls/linux/time.cc
@@ -26,6 +26,7 @@ namespace {
constexpr long kFudgeSeconds = 5;
+#if defined(__x86_64__) || defined(__i386__)
// Mimics the time(2) wrapper from glibc prior to 2.15.
time_t vsyscall_time(time_t* t) {
constexpr uint64_t kVsyscallTimeEntry = 0xffffffffff600400;
@@ -98,6 +99,7 @@ TEST(TimeTest, VsyscallGettimeofday_InvalidAddressSIGSEGV) {
reinterpret_cast<struct timezone*>(0x1)),
::testing::KilledBySignal(SIGSEGV), "");
}
+#endif
} // namespace
diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc
index f6ac9d7b8..53ad2dda3 100644
--- a/test/syscalls/linux/tuntap.cc
+++ b/test/syscalls/linux/tuntap.cc
@@ -153,6 +153,13 @@ std::string CreateArpPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip,
} // namespace
+TEST(TuntapStaticTest, NetTunExists) {
+ struct stat statbuf;
+ ASSERT_THAT(stat(kDevNetTun, &statbuf), SyscallSucceeds());
+ // Check that it's a character device with rw-rw-rw- permissions.
+ EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666);
+}
+
class TuntapTest : public ::testing::Test {
protected:
void TearDown() override {
@@ -249,50 +256,59 @@ TEST_F(TuntapTest, WriteToDownDevice) {
EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EIO));
}
-// This test sets up a TAP device and pings kernel by sending ICMP echo request.
-//
-// It works as the following:
-// * Open /dev/net/tun, and create kTapName interface.
-// * Use rtnetlink to do initial setup of the interface:
-// * Assign IP address 10.0.0.1/24 to kernel.
-// * MAC address: kMacA
-// * Bring up the interface.
-// * Send an ICMP echo reqest (ping) packet from 10.0.0.2 (kMacB) to kernel.
-// * Loop to receive packets from TAP device/fd:
-// * If packet is an ICMP echo reply, it stops and passes the test.
-// * If packet is an ARP request, it responds with canned reply and resends
-// the
-// ICMP request packet.
-TEST_F(TuntapTest, PingKernel) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
-
+PosixErrorOr<FileDescriptor> OpenAndAttachTap(
+ const std::string& dev_name, const std::string& dev_ipv4_addr) {
// Interface creation.
- FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, Open(kDevNetTun, O_RDWR));
struct ifreq ifr_set = {};
ifr_set.ifr_flags = IFF_TAP;
- strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ);
- EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set),
- SyscallSucceedsWithValue(0));
+ strncpy(ifr_set.ifr_name, dev_name.c_str(), IFNAMSIZ);
+ if (ioctl(fd.get(), TUNSETIFF, &ifr_set) < 0) {
+ return PosixError(errno);
+ }
- absl::optional<Link> link =
- ASSERT_NO_ERRNO_AND_VALUE(GetLinkByName(kTapName));
- ASSERT_TRUE(link.has_value());
+ ASSIGN_OR_RETURN_ERRNO(absl::optional<Link> link, GetLinkByName(dev_name));
+ if (!link.has_value()) {
+ return PosixError(ENOENT, "no link");
+ }
// Interface setup.
struct in_addr addr;
- inet_pton(AF_INET, "10.0.0.1", &addr);
+ inet_pton(AF_INET, dev_ipv4_addr.c_str(), &addr);
EXPECT_NO_ERRNO(LinkAddLocalAddr(link->index, AF_INET, /*prefixlen=*/24,
&addr, sizeof(addr)));
if (!IsRunningOnGvisor()) {
// FIXME: gVisor doesn't support setting MAC address on interfaces yet.
- EXPECT_NO_ERRNO(LinkSetMacAddr(link->index, kMacA, sizeof(kMacA)));
+ RETURN_IF_ERRNO(LinkSetMacAddr(link->index, kMacA, sizeof(kMacA)));
// FIXME: gVisor always creates enabled/up'd interfaces.
- EXPECT_NO_ERRNO(LinkChangeFlags(link->index, IFF_UP, IFF_UP));
+ RETURN_IF_ERRNO(LinkChangeFlags(link->index, IFF_UP, IFF_UP));
}
+ return fd;
+}
+
+// This test sets up a TAP device and pings kernel by sending ICMP echo request.
+//
+// It works as the following:
+// * Open /dev/net/tun, and create kTapName interface.
+// * Use rtnetlink to do initial setup of the interface:
+// * Assign IP address 10.0.0.1/24 to kernel.
+// * MAC address: kMacA
+// * Bring up the interface.
+// * Send an ICMP echo reqest (ping) packet from 10.0.0.2 (kMacB) to kernel.
+// * Loop to receive packets from TAP device/fd:
+// * If packet is an ICMP echo reply, it stops and passes the test.
+// * If packet is an ARP request, it responds with canned reply and resends
+// the
+// ICMP request packet.
+TEST_F(TuntapTest, PingKernel) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenAndAttachTap(kTapName, "10.0.0.1"));
ping_pkt ping_req = CreatePingPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1");
std::string arp_rep = CreateArpPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1");
@@ -342,5 +358,47 @@ TEST_F(TuntapTest, PingKernel) {
}
}
+TEST_F(TuntapTest, SendUdpTriggersArpResolution) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenAndAttachTap(kTapName, "10.0.0.1"));
+
+ // Send a UDP packet to remote.
+ int sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_IP);
+ ASSERT_THAT(sock, SyscallSucceeds());
+
+ struct sockaddr_in remote = {};
+ remote.sin_family = AF_INET;
+ remote.sin_port = htons(42);
+ inet_pton(AF_INET, "10.0.0.2", &remote.sin_addr);
+ int ret = sendto(sock, "hello", 5, 0, reinterpret_cast<sockaddr*>(&remote),
+ sizeof(remote));
+ ASSERT_THAT(ret, ::testing::AnyOf(SyscallSucceeds(),
+ SyscallFailsWithErrno(EHOSTDOWN)));
+
+ struct inpkt {
+ union {
+ pihdr pi;
+ arp_pkt arp;
+ };
+ };
+ while (1) {
+ inpkt r = {};
+ int n = read(fd.get(), &r, sizeof(r));
+ EXPECT_THAT(n, SyscallSucceeds());
+
+ if (n < sizeof(pihdr)) {
+ std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol
+ << " len: " << n << std::endl;
+ continue;
+ }
+
+ if (n >= sizeof(arp_pkt) && r.pi.pi_protocol == htons(ETH_P_ARP)) {
+ break;
+ }
+ }
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/tuntap_hostinet.cc b/test/syscalls/linux/tuntap_hostinet.cc
new file mode 100644
index 000000000..1513fb9d5
--- /dev/null
+++ b/test/syscalls/linux/tuntap_hostinet.cc
@@ -0,0 +1,38 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(TuntapHostInetTest, NoNetTun) {
+ SKIP_IF(!IsRunningOnGvisor());
+ SKIP_IF(!IsRunningWithHostinet());
+
+ struct stat statbuf;
+ ASSERT_THAT(stat("/dev/net/tun", &statbuf), SyscallFailsWithErrno(ENOENT));
+}
+
+} // namespace
+} // namespace testing
+
+} // namespace gvisor
diff --git a/test/syscalls/linux/vsyscall.cc b/test/syscalls/linux/vsyscall.cc
index 2c2303358..ae4377108 100644
--- a/test/syscalls/linux/vsyscall.cc
+++ b/test/syscalls/linux/vsyscall.cc
@@ -24,6 +24,7 @@ namespace testing {
namespace {
+#if defined(__x86_64__) || defined(__i386__)
time_t vsyscall_time(time_t* t) {
constexpr uint64_t kVsyscallTimeEntry = 0xffffffffff600400;
return reinterpret_cast<time_t (*)(time_t*)>(kVsyscallTimeEntry)(t);
@@ -37,6 +38,7 @@ TEST(VsyscallTest, VsyscallAlwaysAvailableOnGvisor) {
time_t t;
EXPECT_THAT(vsyscall_time(&t), SyscallSucceeds());
}
+#endif
} // namespace
diff --git a/test/util/BUILD b/test/util/BUILD
index 8b5a0f25c..2a17c33ee 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -350,3 +350,9 @@ cc_library(
":save_util",
],
)
+
+cc_library(
+ name = "temp_umask",
+ testonly = 1,
+ hdrs = ["temp_umask.h"],
+)
diff --git a/test/syscalls/linux/temp_umask.h b/test/util/temp_umask.h
index 81a25440c..e7de84a54 100644
--- a/test/syscalls/linux/temp_umask.h
+++ b/test/util/temp_umask.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_
-#define GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_
+#ifndef GVISOR_TEST_UTIL_TEMP_UMASK_H_
+#define GVISOR_TEST_UTIL_TEMP_UMASK_H_
#include <sys/stat.h>
#include <sys/types.h>
@@ -36,4 +36,4 @@ class TempUmask {
} // namespace testing
} // namespace gvisor
-#endif // GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_
+#endif // GVISOR_TEST_UTIL_TEMP_UMASK_H_
diff --git a/tools/BUILD b/tools/BUILD
index e73a9c885..ba3506c04 100644
--- a/tools/BUILD
+++ b/tools/BUILD
@@ -1,3 +1,3 @@
package(licenses = ["notice"])
-exports_files(["nogo.js"])
+exports_files(["nogo.json"])
diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl
index 905b16d41..0a74370a6 100644
--- a/tools/bazeldefs/defs.bzl
+++ b/tools/bazeldefs/defs.bzl
@@ -2,15 +2,15 @@
load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", _cc_flags_supplier = "cc_flags_supplier")
load("@io_bazel_rules_go//go:def.bzl", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_library = "go_library", _go_test = "go_test", _go_tool_library = "go_tool_library")
-load("@io_bazel_rules_go//proto:def.bzl", _go_proto_library = "go_proto_library")
+load("@io_bazel_rules_go//proto:def.bzl", _go_grpc_library = "go_grpc_library", _go_proto_library = "go_proto_library")
load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test")
load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar")
load("@io_bazel_rules_docker//go:image.bzl", _go_image = "go_image")
load("@io_bazel_rules_docker//container:container.bzl", _container_image = "container_image")
load("@pydeps//:requirements.bzl", _py_requirement = "requirement")
+load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", _cc_grpc_library = "cc_grpc_library")
container_image = _container_image
-cc_binary = _cc_binary
cc_library = _cc_library
cc_flags_supplier = _cc_flags_supplier
cc_proto_library = _cc_proto_library
@@ -19,16 +19,77 @@ cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain"
go_image = _go_image
go_embed_data = _go_embed_data
gtest = "@com_google_googletest//:gtest"
+grpcpp = "@com_github_grpc_grpc//:grpc++"
gbenchmark = "@com_google_benchmark//:benchmark"
loopback = "//tools/bazeldefs:loopback"
-proto_library = native.proto_library
pkg_deb = _pkg_deb
pkg_tar = _pkg_tar
py_library = native.py_library
py_binary = native.py_binary
py_test = native.py_test
+def proto_library(name, has_services = None, **kwargs):
+ native.proto_library(
+ name = name,
+ **kwargs
+ )
+
+def cc_grpc_library(name, **kwargs):
+ _cc_grpc_library(name = name, grpc_only = True, **kwargs)
+
+def _go_proto_or_grpc_library(go_library_func, name, **kwargs):
+ deps = [
+ dep.replace("_proto", "_go_proto")
+ for dep in (kwargs.pop("deps", []) or [])
+ ]
+ go_library_func(
+ name = name + "_go_proto",
+ importpath = "gvisor.dev/gvisor/" + native.package_name() + "/" + name + "_go_proto",
+ proto = ":" + name + "_proto",
+ deps = deps,
+ **kwargs
+ )
+
+def go_proto_library(name, **kwargs):
+ _go_proto_or_grpc_library(_go_proto_library, name, **kwargs)
+
+def go_grpc_and_proto_libraries(name, **kwargs):
+ _go_proto_or_grpc_library(_go_grpc_library, name, **kwargs)
+
+def cc_binary(name, static = False, **kwargs):
+ """Run cc_binary.
+
+ Args:
+ name: name of the target.
+ static: make a static binary if True
+ **kwargs: the rest of the args.
+ """
+ if static:
+ # How to statically link a c++ program that uses threads, like for gRPC:
+ # https://gcc.gnu.org/legacy-ml/gcc-help/2010-05/msg00029.html
+ if "linkopts" not in kwargs:
+ kwargs["linkopts"] = []
+ kwargs["linkopts"] += [
+ "-static",
+ "-lstdc++",
+ "-Wl,--whole-archive",
+ "-lpthread",
+ "-Wl,--no-whole-archive",
+ ]
+ _cc_binary(
+ name = name,
+ **kwargs
+ )
+
def go_binary(name, static = False, pure = False, **kwargs):
+ """Build a go binary.
+
+ Args:
+ name: name of the target.
+ static: build a static binary.
+ pure: build without cgo.
+ **kwargs: rest of the arguments are passed to _go_binary.
+ """
if static:
kwargs["static"] = "on"
if pure:
@@ -52,18 +113,17 @@ def go_tool_library(name, **kwargs):
**kwargs
)
-def go_proto_library(name, proto, **kwargs):
- deps = kwargs.pop("deps", [])
- _go_proto_library(
- name = name,
- importpath = "gvisor.dev/gvisor/" + native.package_name() + "/" + name,
- proto = proto,
- deps = [dep.replace("_proto", "_go_proto") for dep in deps],
- **kwargs
- )
+def go_test(name, pure = False, library = None, **kwargs):
+ """Build a go test.
-def go_test(name, **kwargs):
- library = kwargs.pop("library", None)
+ Args:
+ name: name of the output binary.
+ pure: should it be built without cgo.
+ library: the library to embed.
+ **kwargs: rest of the arguments to pass to _go_test.
+ """
+ if pure:
+ kwargs["pure"] = "on"
if library:
kwargs["embed"] = [library]
_go_test(
diff --git a/tools/defs.bzl b/tools/defs.bzl
index 15a310403..91d689a82 100644
--- a/tools/defs.bzl
+++ b/tools/defs.bzl
@@ -7,36 +7,39 @@ change for Google-internal and bazel-compatible rules.
load("//tools/go_stateify:defs.bzl", "go_stateify")
load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps")
-load("//tools/bazeldefs:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system")
+load("//tools/bazeldefs:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _grpcpp = "grpcpp", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system")
load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms")
load("//tools/bazeldefs:tags.bzl", "go_suffixes")
# Delegate directly.
cc_binary = _cc_binary
+cc_flags_supplier = _cc_flags_supplier
+cc_grpc_library = _cc_grpc_library
cc_library = _cc_library
cc_test = _cc_test
cc_toolchain = _cc_toolchain
-cc_flags_supplier = _cc_flags_supplier
container_image = _container_image
+default_installer = _default_installer
+default_net_util = _default_net_util
+gbenchmark = _gbenchmark
go_embed_data = _go_embed_data
go_image = _go_image
go_test = _go_test
go_tool_library = _go_tool_library
gtest = _gtest
-gbenchmark = _gbenchmark
+grpcpp = _grpcpp
+loopback = _loopback
pkg_deb = _pkg_deb
pkg_tar = _pkg_tar
-py_library = _py_library
py_binary = _py_binary
-py_test = _py_test
+py_library = _py_library
py_requirement = _py_requirement
+py_test = _py_test
select_arch = _select_arch
select_system = _select_system
-loopback = _loopback
-default_installer = _default_installer
-default_net_util = _default_net_util
-platforms = _platforms
+
default_platform = _default_platform
+platforms = _platforms
def go_binary(name, **kwargs):
"""Wraps the standard go_binary.
@@ -190,33 +193,52 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F
**kwargs
)
-def proto_library(name, srcs, **kwargs):
+def proto_library(name, srcs, deps = None, has_services = 0, **kwargs):
"""Wraps the standard proto_library.
- Given a proto_library named "foo", this produces three different targets:
+ Given a proto_library named "foo", this produces up to five different
+ targets:
- foo_proto: proto_library rule.
- foo_go_proto: go_proto_library rule.
- foo_cc_proto: cc_proto_library rule.
+ - foo_go_grpc_proto: go_grpc_library rule.
+ - foo_cc_grpc_proto: cc_grpc_library rule.
Args:
+ name: the name to which _proto, _go_proto, etc, will be appended.
srcs: the proto sources.
+ deps: for the proto library and the go_proto_library.
+ has_services: 1 to build gRPC code, otherwise 0.
**kwargs: standard proto_library arguments.
"""
- deps = kwargs.pop("deps", [])
_proto_library(
name = name + "_proto",
srcs = srcs,
deps = deps,
+ has_services = has_services,
**kwargs
)
- _go_proto_library(
- name = name + "_go_proto",
- proto = ":" + name + "_proto",
- deps = deps,
- **kwargs
- )
+ if has_services:
+ _go_grpc_and_proto_libraries(
+ name = name,
+ deps = deps,
+ **kwargs
+ )
+ else:
+ _go_proto_library(
+ name = name,
+ deps = deps,
+ **kwargs
+ )
_cc_proto_library(
name = name + "_cc_proto",
deps = [":" + name + "_proto"],
**kwargs
)
+ if has_services:
+ _cc_grpc_library(
+ name = name + "_cc_grpc_proto",
+ srcs = [":" + name + "_proto"],
+ deps = [":" + name + "_cc_proto"],
+ **kwargs
+ )
diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD
index b5d5a4487..44cb33ae4 100644
--- a/tools/go_marshal/gomarshal/BUILD
+++ b/tools/go_marshal/gomarshal/BUILD
@@ -7,6 +7,9 @@ go_library(
srcs = [
"generator.go",
"generator_interfaces.go",
+ "generator_interfaces_array_newtype.go",
+ "generator_interfaces_primitive_newtype.go",
+ "generator_interfaces_struct.go",
"generator_tests.go",
"util.go",
],
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
index d365a1f3c..82983804c 100644
--- a/tools/go_marshal/gomarshal/generator.go
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -235,6 +235,10 @@ func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*a
debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name)
types = append(types, t)
continue
+ case *ast.ArrayType: // Newtype on array.
+ debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name)
+ types = append(types, t)
+ continue
}
// A user specifically requested marshalling on this type, but we
// don't support it.
@@ -281,17 +285,20 @@ func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interface
i := newInterfaceGenerator(t, fset)
switch ty := t.Type.(type) {
case *ast.StructType:
- i.validateStruct()
- i.emitMarshallableForStruct()
- return i
+ i.validateStruct(t, ty)
+ i.emitMarshallableForStruct(ty)
case *ast.Ident:
i.validatePrimitiveNewtype(ty)
- i.emitMarshallableForPrimitiveNewtype()
- return i
+ i.emitMarshallableForPrimitiveNewtype(ty)
+ case *ast.ArrayType:
+ i.validateArrayNewtype(t.Name, ty)
+ // After validate, we can safely call arrayLen.
+ i.emitMarshallableForArrayNewtype(t.Name, ty.Elt.(*ast.Ident), arrayLen(ty))
default:
// This should've been filtered out by collectMarshallabeTypes.
panic(fmt.Sprintf("Unexpected type %+v", ty))
}
+ return i
}
// generateOneTestSuite generates a test suite for the automatically generated
@@ -406,7 +413,7 @@ func (g *Generator) writeTests(ts []*testGenerator) error {
// empty example instead.
if len(ts) == 0 {
b.reset()
- b.emit("func ExampleEmptyTestSuite() {\n")
+ b.emit("func Example() {\n")
b.inIndent(func() {
b.emit("// This example is intentionally empty to ensure this file contains at least\n")
b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n")
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
index ea1af998e..8babf61d2 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -15,10 +15,8 @@
package gomarshal
import (
- "fmt"
"go/ast"
"go/token"
- "strings"
)
// interfaceGenerator generates marshalling interfaces for a single type.
@@ -81,18 +79,6 @@ func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) {
g.as[fieldName] = struct{}{}
}
-func (g *interfaceGenerator) forEachField(fn func(f *ast.Field)) {
- // This is guaranteed to succeed because g.t is always a struct.
- st := g.t.Type.(*ast.StructType)
- for _, field := range st.Fields.List {
- fn(field)
- }
-}
-
-func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string {
- return fmt.Sprintf("%s.%s", g.r, n.Name)
-}
-
// abortAt aborts the go_marshal tool with the given error message, with a
// reference position to the input source. Same as abortAt, but uses g to
// resolve p to position.
@@ -100,71 +86,6 @@ func (g *interfaceGenerator) abortAt(p token.Pos, msg string) {
abortAt(g.f.Position(p), msg)
}
-func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) {
- switch t.Name {
- case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
- // These are the only primitive types we're allow. Below, we provide
- // suggestions for some disallowed types and reject them, then attempt
- // to marshal any remaining types by invoking the marshal.Marshallable
- // interface on them. If these types don't actually implement
- // marshal.Marshallable, compilation of the generated code will fail
- // with an appropriate error message.
- return
- case "int":
- g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
- case "uint":
- g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
- case "string":
- g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
- default:
- debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
- }
-}
-
-// validateStruct ensures the type we're working with can be marshalled. These
-// checks are done ahead of time and in one place so we can make assumptions
-// later.
-func (g *interfaceGenerator) validateStruct() {
- g.forEachField(func(f *ast.Field) {
- if len(f.Names) == 0 {
- g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
- }
- })
-
- g.forEachField(func(f *ast.Field) {
- fieldDispatcher{
- primitive: func(_, t *ast.Ident) {
- g.validatePrimitiveNewtype(t)
- },
- selector: func(_, _, _ *ast.Ident) {
- // No validation to perform on selector fields. However this
- // callback must still be provided.
- },
- array: func(n, _ *ast.Ident, len int) {
- a := f.Type.(*ast.ArrayType)
- if a.Len == nil {
- g.abortAt(f.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
- }
-
- if _, ok := a.Len.(*ast.BasicLit); !ok {
- g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don's use consts or expressions"))
- }
-
- if _, ok := a.Elt.(*ast.Ident); !ok {
- g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
- }
-
- if len <= 0 {
- g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?"))
- }
- },
- unhandled: func(_ *ast.Ident) {
- g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
- },
- }.dispatch(f)
- })
-}
-
// scalarSize returns the size of type identified by t. If t isn't a primitive
// type, the size isn't known at code generation time, and must be resolved via
// the marshal.Marshallable interface.
@@ -191,8 +112,8 @@ func (g *interfaceGenerator) shiftDynamic(bufVar, name string) {
g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name)
}
-// marshalStructFieldScalar writes a single scalar field from a struct to a byte slice.
-func (g *interfaceGenerator) marshalStructFieldScalar(accessor, typ, bufVar string) {
+// marshalScalar writes a single scalar to a byte slice.
+func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) {
switch typ {
case "int8", "uint8", "byte":
g.emit("%s[0] = byte(%s)\n", bufVar, accessor)
@@ -215,9 +136,8 @@ func (g *interfaceGenerator) marshalStructFieldScalar(accessor, typ, bufVar stri
}
}
-// unmarshalStructFieldScalar reads a single scalar field from a struct, from a
-// byte slice.
-func (g *interfaceGenerator) unmarshalStructFieldScalar(accessor, typ, bufVar string) {
+// unmarshalScalar reads a single scalar from a byte slice.
+func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) {
switch typ {
case "byte":
g.emit("%s = %s[0]\n", accessor, bufVar)
@@ -243,580 +163,3 @@ func (g *interfaceGenerator) unmarshalStructFieldScalar(accessor, typ, bufVar st
g.recordPotentiallyNonPackedField(accessor)
}
}
-
-// marshalPrimitiveScalar writes a single primitive variable to a byte slice.
-func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) {
- switch typ {
- case "int8", "uint8", "byte":
- g.emit("%s[0] = byte(*%s)\n", bufVar, accessor)
- case "int16", "uint16":
- g.recordUsedImport("usermem")
- g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor)
- case "int32", "uint32":
- g.recordUsedImport("usermem")
- g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor)
- case "int64", "uint64":
- g.recordUsedImport("usermem")
- g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor)
- default:
- g.emit("inner := (*%s)(%s)\n", typ, accessor)
- g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
- }
-}
-
-// unmarshalPrimitiveScalar read a single primitive variable from a byte slice.
-func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) {
- switch typ {
- case "byte":
- g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar)
- case "int8", "uint8":
- g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar)
- case "int16", "uint16":
- g.recordUsedImport("usermem")
- g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar)
- case "int32", "uint32":
- g.recordUsedImport("usermem")
- g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar)
-
- case "int64", "uint64":
- g.recordUsedImport("usermem")
- g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar)
- default:
- g.emit("inner := (*%s)(%s)\n", typ, accessor)
- g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
- }
-}
-
-// areFieldsPackedExpression returns a go expression checking whether g.t's fields are
-// packed. Returns "", false if g.t has no fields that may be potentially
-// packed, otherwise returns <clause>, true, where <clause> is an expression
-// like "t.a.Packed() && t.b.Packed() && t.c.Packed()".
-func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
- if len(g.as) == 0 {
- return "", false
- }
-
- cs := make([]string, 0, len(g.as))
- for accessor, _ := range g.as {
- cs = append(cs, fmt.Sprintf("%s.Packed()", accessor))
- }
- return strings.Join(cs, " && "), true
-}
-
-func (g *interfaceGenerator) emitMarshallableForStruct() {
- // Is g.t a packed struct without consideing field types?
- thisPacked := true
- g.forEachField(func(f *ast.Field) {
- if f.Tag != nil {
- if f.Tag.Value == "`marshal:\"unaligned\"`" {
- if thisPacked {
- debugfAt(g.f.Position(g.t.Pos()),
- fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name))
- thisPacked = false
- }
- }
- }
- })
-
- g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
- g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
- g.inIndent(func() {
- primitiveSize := 0
- var dynamicSizeTerms []string
-
- g.forEachField(fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if size, dynamic := g.scalarSize(t); !dynamic {
- primitiveSize += size
- } else {
- g.recordUsedMarshallable(t.Name)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name))
- }
- },
- selector: func(n, tX, tSel *ast.Ident) {
- tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
- g.recordUsedImport(tX.Name)
- g.recordUsedMarshallable(tName)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
- },
- array: func(n, t *ast.Ident, len int) {
- if len < 1 {
- // Zero-length arrays should've been rejected by validate().
- panic("unreachable")
- }
- if size, dynamic := g.scalarSize(t); !dynamic {
- primitiveSize += size * len
- } else {
- g.recordUsedMarshallable(t.Name)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len))
- }
- },
- }.dispatch)
- g.emit("return %d", primitiveSize)
- if len(dynamicSizeTerms) > 0 {
- g.incIndent()
- }
- {
- for _, d := range dynamicSizeTerms {
- g.emitNoIndent(" +\n")
- g.emit(d)
- }
- }
- if len(dynamicSizeTerms) > 0 {
- g.decIndent()
- }
- })
- g.emit("\n}\n\n")
-
- g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
- g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.forEachField(fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("dst", len)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can referece here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
- }
- return
- }
- g.marshalStructFieldScalar(g.fieldAccessor(n), t.Name, "dst")
- },
- selector: func(n, tX, tSel *ast.Ident) {
- g.marshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
- },
- array: func(n, t *ast.Ident, size int) {
- if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("dst", len*size)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can reference here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
- }
- return
- }
-
- g.emit("for idx := 0; idx < %d; idx++ {\n", size)
- g.inIndent(func() {
- g.marshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
- })
- g.emit("}\n")
- },
- }.dispatch)
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
- g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.forEachField(fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("src", len)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can reference here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name)
- g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
- }
- return
- }
- g.unmarshalStructFieldScalar(g.fieldAccessor(n), t.Name, "src")
- },
- selector: func(n, tX, tSel *ast.Ident) {
- g.unmarshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
- },
- array: func(n, t *ast.Ident, size int) {
- if n.Name == "_" {
- g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("src", len*size)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can referece here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
- }
- return
- }
-
- g.emit("for idx := 0; idx < %d; idx++ {\n", size)
- g.inIndent(func() {
- g.unmarshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
- })
- g.emit("}\n")
- },
- }.dispatch)
- })
- g.emit("}\n\n")
-
- g.emit("// Packed implements marshal.Marshallable.Packed.\n")
- g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
- g.inIndent(func() {
- expr, fieldsMaybePacked := g.areFieldsPackedExpression()
- switch {
- case !thisPacked:
- g.emit("return false\n")
- case fieldsMaybePacked:
- g.emit("return %s\n", expr)
- default:
- g.emit("return true\n")
-
- }
- })
- g.emit("}\n\n")
-
- g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
- g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- if thisPacked {
- g.recordUsedImport("safecopy")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if %s {\n", cond)
- g.inIndent(func() {
- g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
- })
- g.emit("} else {\n")
- g.inIndent(func() {
- g.emit("%s.MarshalBytes(dst)\n", g.r)
- })
- g.emit("}\n")
- } else {
- g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
- }
- } else {
- g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName())
- g.emit("%s.MarshalBytes(dst)\n", g.r)
- }
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
- g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- if thisPacked {
- g.recordUsedImport("safecopy")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if %s {\n", cond)
- g.inIndent(func() {
- g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
- })
- g.emit("} else {\n")
- g.inIndent(func() {
- g.emit("%s.UnmarshalBytes(src)\n", g.r)
- })
- g.emit("}\n")
- } else {
- g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
- }
- } else {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
- g.emit("%s.UnmarshalBytes(src)\n", g.r)
- }
- })
- g.emit("}\n\n")
-
- g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
- g.recordUsedImport("marshal")
- g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
- g.inIndent(func() {
- fallback := func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
- g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
- g.emit("%s.MarshalBytes(buf)\n", g.r)
- g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
- g.emit("return err\n")
- }
- if thisPacked {
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if !%s {\n", cond)
- g.inIndent(fallback)
- g.emit("}\n\n")
- }
- // Fast serialization.
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the CopyOutBytes.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return err\n")
- } else {
- fallback()
- }
- })
- g.emit("}\n\n")
-
- g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
- g.recordUsedImport("marshal")
- g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
- g.inIndent(func() {
- fallback := func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
- g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
- g.emit("_, err := task.CopyInBytes(addr, buf)\n")
- g.emit("if err != nil {\n")
- g.inIndent(func() {
- g.emit("return err\n")
- })
- g.emit("}\n")
-
- g.emit("%s.UnmarshalBytes(buf)\n", g.r)
- g.emit("return nil\n")
- }
- if thisPacked {
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if !%s {\n", cond)
- g.inIndent(fallback)
- g.emit("}\n\n")
- }
- // Fast deserialization.
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("_, err := task.CopyInBytes(addr, buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the CopyInBytes.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return err\n")
- } else {
- fallback()
- }
- })
- g.emit("}\n\n")
-
- g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
- g.recordUsedImport("io")
- g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- fallback := func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
- g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r)
- g.emit("%s.MarshalBytes(buf)\n", g.r)
- g.emit("n, err := w.Write(buf)\n")
- g.emit("return int64(n), err\n")
- }
- if thisPacked {
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if !%s {\n", cond)
- g.inIndent(fallback)
- g.emit("}\n\n")
- }
- // Fast serialization.
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("len, err := w.Write(buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the Write.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return int64(len), err\n")
- } else {
- fallback()
- }
- })
- g.emit("}\n\n")
-}
-
-// emitMarshallableForPrimitiveNewtype outputs code to implement the
-// marshal.Marshallable interface for a newtype on a primitive. Primitive
-// newtypes are always packed, so we can omit the various fallbacks required for
-// non-packed structs.
-func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype() {
- g.recordUsedImport("io")
- g.recordUsedImport("marshal")
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("safecopy")
- g.recordUsedImport("unsafe")
- g.recordUsedImport("usermem")
-
- nt := g.t.Type.(*ast.Ident)
-
- g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
- g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
- g.inIndent(func() {
- if size, dynamic := g.scalarSize(nt); !dynamic {
- g.emit("return %d\n", size)
- } else {
- g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name)
- }
- })
- g.emit("}\n\n")
-
- g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
- g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.marshalPrimitiveScalar(g.r, nt.Name, "dst")
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
- g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName())
- })
- g.emit("}\n\n")
-
- g.emit("// Packed implements marshal.Marshallable.Packed.\n")
- g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("// Scalar newtypes are always packed.\n")
- g.emit("return true\n")
- })
- g.emit("}\n\n")
-
- g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
- g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
- g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
- })
- g.emit("}\n\n")
-
- g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
- g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
- g.inIndent(func() {
- // Fast serialization.
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the CopyOutBytes.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return err\n")
- })
- g.emit("}\n\n")
-
- g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
- g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("_, err := task.CopyInBytes(addr, buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the CopyInBytes.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return err\n")
- })
- g.emit("}\n\n")
-
- g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
- g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("len, err := w.Write(buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the Write.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return int64(len), err\n")
-
- })
- g.emit("}\n\n")
-
-}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
new file mode 100644
index 000000000..da36d9305
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
@@ -0,0 +1,183 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains the bits of the code generator specific to marshalling
+// newtypes on arrays.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+)
+
+func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType) {
+ if a.Len == nil {
+ g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
+ }
+
+ if _, ok := a.Len.(*ast.BasicLit); !ok {
+ g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don't use consts or expressions"))
+ }
+
+ if _, ok := a.Elt.(*ast.Ident); !ok {
+ g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
+ }
+
+ if arrayLen(a) <= 0 {
+ g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?"))
+ }
+}
+
+func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, len int) {
+ g.recordUsedImport("io")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ g.recordUsedImport("usermem")
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if size, dynamic := g.scalarSize(elt); !dynamic {
+ g.emit("return %d\n", size*len)
+ } else {
+ g.emit("return (*%s)(nil).SizeBytes() * %d\n", n.Name, len)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("for idx := 0; idx < %d; idx++ {\n", len)
+ g.inIndent(func() {
+ g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst")
+ })
+ g.emit("}\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("for idx := 0; idx < %d; idx++ {\n", len)
+ g.inIndent(func() {
+ g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src")
+ })
+ g.emit("}\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Array newtypes are always packed.\n")
+ g.emit("return true\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ // Fast serialization.
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyOutBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyInBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("len, err := w.Write(buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the Write.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return int64(len), err\n")
+
+ })
+ g.emit("}\n\n")
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
new file mode 100644
index 000000000..159397825
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
@@ -0,0 +1,229 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains the bits of the code generator specific to marshalling
+// newtypes on primitives.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+)
+
+// marshalPrimitiveScalar writes a single primitive variable to a byte
+// slice.
+func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) {
+ switch typ {
+ case "int8", "uint8", "byte":
+ g.emit("%s[0] = byte(*%s)\n", bufVar, accessor)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor)
+ default:
+ g.emit("// Explicilty cast to the underlying type before dispatching to\n")
+ g.emit("// MarshalBytes, so we don't recursively call %s.MarshalBytes\n", accessor)
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
+// unmarshalPrimitiveScalar read a single primitive variable from a byte slice.
+func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) {
+ switch typ {
+ case "byte":
+ g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar)
+ case "int8", "uint8":
+ g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar)
+ default:
+ g.emit("// Explicilty cast to the underlying type before dispatching to\n")
+ g.emit("// UnmarshalBytes, so we don't recursively call %s.UnmarshalBytes\n", accessor)
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
+func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) {
+ switch t.Name {
+ case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
+ // These are the only primitive types we're allow. Below, we provide
+ // suggestions for some disallowed types and reject them, then attempt
+ // to marshal any remaining types by invoking the marshal.Marshallable
+ // interface on them. If these types don't actually implement
+ // marshal.Marshallable, compilation of the generated code will fail
+ // with an appropriate error message.
+ return
+ case "int":
+ g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
+ case "uint":
+ g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
+ case "string":
+ g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
+ default:
+ debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
+ }
+}
+
+// emitMarshallableForPrimitiveNewtype outputs code to implement the
+// marshal.Marshallable interface for a newtype on a primitive. Primitive
+// newtypes are always packed, so we can omit the various fallbacks required for
+// non-packed structs.
+func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) {
+ g.recordUsedImport("io")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ g.recordUsedImport("usermem")
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if size, dynamic := g.scalarSize(nt); !dynamic {
+ g.emit("return %d\n", size)
+ } else {
+ g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.marshalPrimitiveScalar(g.r, nt.Name, "dst")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName())
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Scalar newtypes are always packed.\n")
+ g.emit("return true\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ // Fast serialization.
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyOutBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyInBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("len, err := w.Write(buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the Write.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return int64(len), err\n")
+
+ })
+ g.emit("}\n\n")
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
new file mode 100644
index 000000000..e66a38b2e
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
@@ -0,0 +1,450 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains the bits of the code generator specific to marshalling
+// structs.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+ "strings"
+)
+
+func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string {
+ return fmt.Sprintf("%s.%s", g.r, n.Name)
+}
+
+// areFieldsPackedExpression returns a go expression checking whether g.t's fields are
+// packed. Returns "", false if g.t has no fields that may be potentially
+// packed, otherwise returns <clause>, true, where <clause> is an expression
+// like "t.a.Packed() && t.b.Packed() && t.c.Packed()".
+func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
+ if len(g.as) == 0 {
+ return "", false
+ }
+
+ cs := make([]string, 0, len(g.as))
+ for accessor, _ := range g.as {
+ cs = append(cs, fmt.Sprintf("%s.Packed()", accessor))
+ }
+ return strings.Join(cs, " && "), true
+}
+
+// validateStruct ensures the type we're working with can be marshalled. These
+// checks are done ahead of time and in one place so we can make assumptions
+// later.
+func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType) {
+ forEachStructField(st, func(f *ast.Field) {
+ if len(f.Names) == 0 {
+ g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
+ }
+ })
+
+ forEachStructField(st, func(f *ast.Field) {
+ fieldDispatcher{
+ primitive: func(_, t *ast.Ident) {
+ g.validatePrimitiveNewtype(t)
+ },
+ selector: func(_, _, _ *ast.Ident) {
+ // No validation to perform on selector fields. However this
+ // callback must still be provided.
+ },
+ array: func(n, _ *ast.Ident, len int) {
+ g.validateArrayNewtype(n, f.Type.(*ast.ArrayType))
+ },
+ unhandled: func(_ *ast.Ident) {
+ g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
+ },
+ }.dispatch(f)
+ })
+}
+
+func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
+ // Is g.t a packed struct without consideing field types?
+ thisPacked := true
+ forEachStructField(st, func(f *ast.Field) {
+ if f.Tag != nil {
+ if f.Tag.Value == "`marshal:\"unaligned\"`" {
+ if thisPacked {
+ debugfAt(g.f.Position(g.t.Pos()),
+ fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name))
+ thisPacked = false
+ }
+ }
+ }
+ })
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ primitiveSize := 0
+ var dynamicSizeTerms []string
+
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ primitiveSize += size
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name))
+ }
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
+ g.recordUsedImport(tX.Name)
+ g.recordUsedMarshallable(tName)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
+ },
+ array: func(n, t *ast.Ident, len int) {
+ if len < 1 {
+ // Zero-length arrays should've been rejected by validate().
+ panic("unreachable")
+ }
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ primitiveSize += size * len
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len))
+ }
+ },
+ }.dispatch)
+ g.emit("return %d", primitiveSize)
+ if len(dynamicSizeTerms) > 0 {
+ g.incIndent()
+ }
+ {
+ for _, d := range dynamicSizeTerms {
+ g.emitNoIndent(" +\n")
+ g.emit(d)
+ }
+ }
+ if len(dynamicSizeTerms) > 0 {
+ g.decIndent()
+ }
+ })
+ g.emit("\n}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("dst", len)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can referece here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
+ }
+ return
+ }
+ g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
+ },
+ array: func(n, t *ast.Ident, size int) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("dst", len*size)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
+ }
+ return
+ }
+
+ g.emit("for idx := 0; idx < %d; idx++ {\n", size)
+ g.inIndent(func() {
+ g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("src", len)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name)
+ g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
+ }
+ return
+ }
+ g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
+ },
+ array: func(n, t *ast.Ident, size int) {
+ if n.Name == "_" {
+ g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("src", len*size)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can referece here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
+ }
+ return
+ }
+
+ g.emit("for idx := 0; idx < %d; idx++ {\n", size)
+ g.inIndent(func() {
+ g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ expr, fieldsMaybePacked := g.areFieldsPackedExpression()
+ switch {
+ case !thisPacked:
+ g.emit("return false\n")
+ case fieldsMaybePacked:
+ g.emit("return %s\n", expr)
+ default:
+ g.emit("return true\n")
+
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(func() {
+ g.emit("%s.MarshalBytes(dst)\n", g.r)
+ })
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ }
+ } else {
+ g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName())
+ g.emit("%s.MarshalBytes(dst)\n", g.r)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(func() {
+ g.emit("%s.UnmarshalBytes(src)\n", g.r)
+ })
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ }
+ } else {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("%s.UnmarshalBytes(src)\n", g.r)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
+ g.emit("%s.MarshalBytes(buf)\n", g.r)
+ g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("return err\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyOutBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
+ g.emit("_, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("if err != nil {\n")
+ g.inIndent(func() {
+ g.emit("return err\n")
+ })
+ g.emit("}\n")
+
+ g.emit("%s.UnmarshalBytes(buf)\n", g.r)
+ g.emit("return nil\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast deserialization.
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyInBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.recordUsedImport("io")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r)
+ g.emit("%s.MarshalBytes(buf)\n", g.r)
+ g.emit("n, err := w.Write(buf)\n")
+ g.emit("return int64(n), err\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("len, err := w.Write(buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the Write.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return int64(len), err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+}
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
index 8ba47eb67..fd992e44a 100644
--- a/tools/go_marshal/gomarshal/generator_tests.go
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -164,7 +164,7 @@ func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() {
g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n")
g.inIndent(func() {
- g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)")
+ g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)\n")
})
g.emit("}\n")
})
diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go
index e2bca4e7c..a0936e013 100644
--- a/tools/go_marshal/gomarshal/util.go
+++ b/tools/go_marshal/gomarshal/util.go
@@ -64,6 +64,12 @@ func kindString(e ast.Expr) string {
}
}
+func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) {
+ for _, field := range st.Fields.List {
+ fn(field)
+ }
+}
+
// fieldDispatcher is a collection of callbacks for handling different types of
// fields in a struct declaration.
type fieldDispatcher struct {
@@ -73,6 +79,25 @@ type fieldDispatcher struct {
unhandled func(n *ast.Ident)
}
+// Precondition: a must have a literal for the array length. Consts and
+// expressions are not allowed as array lengths, and should be rejected by the
+// caller.
+func arrayLen(a *ast.ArrayType) int {
+ if a.Len == nil {
+ // Probably a slice? Must be handled by caller.
+ panic("Nil array length in array type")
+ }
+ lenLit, ok := a.Len.(*ast.BasicLit)
+ if !ok {
+ panic("Array has non-literal for length")
+ }
+ len, err := strconv.Atoi(lenLit.Value)
+ if err != nil {
+ panic(fmt.Sprintf("Failed to parse array length '%s' as number: %v", lenLit.Value, err))
+ }
+ return len
+}
+
// Precondition: All dispatch callbacks that will be invoked must be
// provided. Embedded fields are not allowed, len(f.Names) >= 1.
func (fd fieldDispatcher) dispatch(f *ast.Field) {
@@ -96,22 +121,12 @@ func (fd fieldDispatcher) dispatch(f *ast.Field) {
case *ast.SelectorExpr:
fd.selector(name, v.X.(*ast.Ident), v.Sel)
case *ast.ArrayType:
- len := 0
- if v.Len != nil {
- // Non-literal array length is handled by generatorInterfaces.validate().
- if lenLit, ok := v.Len.(*ast.BasicLit); ok {
- var err error
- len, err = strconv.Atoi(lenLit.Value)
- if err != nil {
- panic(err)
- }
- }
- }
switch t := v.Elt.(type) {
case *ast.Ident:
- fd.array(name, t, len)
+ fd.array(name, t, arrayLen(v))
default:
- fd.array(name, nil, len)
+ // Should be handled with a better error message during validate.
+ panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t))
}
default:
fd.unhandled(name)
diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go
index 93229dedb..c829db6da 100644
--- a/tools/go_marshal/test/test.go
+++ b/tools/go_marshal/test/test.go
@@ -104,6 +104,11 @@ type Stat struct {
_ [3]int64
}
+// InetAddr is an example marshallable newtype on an array.
+//
+// +marshal
+type InetAddr [4]byte
+
// SignalSet is an example marshallable newtype on a primitive.
//
// +marshal
diff --git a/tools/go_mod.sh b/tools/go_mod.sh
new file mode 100755
index 000000000..84b779d6d
--- /dev/null
+++ b/tools/go_mod.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -eo pipefail
+
+# Build the :gopath target.
+bazel build //:gopath
+declare -r gopathdir="bazel-bin/gopath/src/gvisor.dev/gvisor/"
+
+# Copy go.mod and execute the command.
+cp -a go.mod go.sum "${gopathdir}"
+(cd "${gopathdir}" && go mod "$@")
+cp -a "${gopathdir}/go.mod" "${gopathdir}/go.sum" .
+
+# Cleanup the WORKSPACE file.
+bazel run //:gazelle -- update-repos -from_file=go.mod
diff --git a/tools/images/BUILD b/tools/images/BUILD
index fe11f08a3..66ffd02aa 100644
--- a/tools/images/BUILD
+++ b/tools/images/BUILD
@@ -9,7 +9,7 @@ package(
genrule(
name = "zone",
outs = ["zone.txt"],
- cmd = "gcloud config get-value compute/zone > $@",
+ cmd = "gcloud config get-value compute/zone > \"$@\"",
tags = [
"local",
"manual",
diff --git a/tools/images/ubuntu1604/10_core.sh b/tools/images/ubuntu1604/10_core.sh
index 46dda6bb1..cd518d6ac 100755
--- a/tools/images/ubuntu1604/10_core.sh
+++ b/tools/images/ubuntu1604/10_core.sh
@@ -17,7 +17,20 @@
set -xeo pipefail
# Install all essential build tools.
-apt-get update && apt-get -y install make git-core build-essential linux-headers-$(uname -r) pkg-config
+while true; do
+ if (apt-get update && apt-get install -y \
+ make \
+ git-core \
+ build-essential \
+ linux-headers-$(uname -r) \
+ pkg-config); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
# Install a recent go toolchain.
if ! [[ -d /usr/local/go ]]; then
diff --git a/tools/images/ubuntu1604/20_bazel.sh b/tools/images/ubuntu1604/20_bazel.sh
index b33e1656c..bb7afa676 100755
--- a/tools/images/ubuntu1604/20_bazel.sh
+++ b/tools/images/ubuntu1604/20_bazel.sh
@@ -19,7 +19,17 @@ set -xeo pipefail
declare -r BAZEL_VERSION=2.0.0
# Install bazel dependencies.
-apt-get update && apt-get install -y openjdk-8-jdk-headless unzip
+while true; do
+ if (apt-get update && apt-get install -y \
+ openjdk-8-jdk-headless \
+ unzip); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
# Use the release installer.
curl -L -o bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
diff --git a/tools/images/ubuntu1604/25_docker.sh b/tools/images/ubuntu1604/25_docker.sh
index 1d3defcd3..11eea2d72 100755
--- a/tools/images/ubuntu1604/25_docker.sh
+++ b/tools/images/ubuntu1604/25_docker.sh
@@ -15,12 +15,20 @@
# limitations under the License.
# Add dependencies.
-apt-get update && apt-get -y install \
- apt-transport-https \
- ca-certificates \
- curl \
- gnupg-agent \
- software-properties-common
+while true; do
+ if (apt-get update && apt-get install -y \
+ apt-transport-https \
+ ca-certificates \
+ curl \
+ gnupg-agent \
+ software-properties-common); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
# Install the key.
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | apt-key add -
@@ -32,4 +40,15 @@ add-apt-repository \
stable"
# Install docker.
-apt-get update && apt-get install -y docker-ce docker-ce-cli containerd.io
+while true; do
+ if (apt-get update && apt-get install -y \
+ docker-ce \
+ docker-ce-cli \
+ containerd.io); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
diff --git a/tools/images/ubuntu1604/30_containerd.sh b/tools/images/ubuntu1604/30_containerd.sh
index a7472bd1c..fb3699c12 100755
--- a/tools/images/ubuntu1604/30_containerd.sh
+++ b/tools/images/ubuntu1604/30_containerd.sh
@@ -34,7 +34,17 @@ install_helper() {
}
# Install dependencies for the crictl tests.
-apt-get install -y btrfs-tools libseccomp-dev
+while true; do
+ if (apt-get update && apt-get install -y \
+ btrfs-tools \
+ libseccomp-dev); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
# Install containerd & cri-tools.
GOPATH=$(mktemp -d --tmpdir gopathXXXXX)
diff --git a/tools/images/ubuntu1604/40_kokoro.sh b/tools/images/ubuntu1604/40_kokoro.sh
index 5f2dfc858..06a1e6c48 100755
--- a/tools/images/ubuntu1604/40_kokoro.sh
+++ b/tools/images/ubuntu1604/40_kokoro.sh
@@ -23,7 +23,22 @@ declare -r ssh_public_keys=(
)
# Install dependencies.
-apt-get update && apt-get install -y rsync coreutils python-psutil qemu-kvm python-pip python3-pip zip
+while true; do
+ if (apt-get update && apt-get install -y \
+ rsync \
+ coreutils \
+ python-psutil \
+ qemu-kvm \
+ python-pip \
+ python3-pip \
+ zip); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
# junitparser is used to merge junit xml files.
pip install junitparser
diff --git a/tools/installers/master.sh b/tools/installers/master.sh
index 52f9734a6..2c6001c6c 100755
--- a/tools/installers/master.sh
+++ b/tools/installers/master.sh
@@ -19,17 +19,16 @@ set -e
curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add -
add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main"
+
while true; do
- if apt-get update; then
- apt-get install -y runsc
+ if (apt-get update && apt-get install -y runsc); then
break
fi
result=$?
- # Check if apt update failed to aquire the file lock.
if [[ $result -ne 100 ]]; then
exit $result
fi
done
+
runsc install
service docker restart
-
diff --git a/tools/nogo.js b/tools/nogo.js
deleted file mode 100644
index fc0a4d1f0..000000000
--- a/tools/nogo.js
+++ /dev/null
@@ -1,7 +0,0 @@
-{
- "checkunsafe": {
- "exclude_files": {
- "/external/": "not subject to constraint"
- }
- }
-}
diff --git a/tools/nogo.json b/tools/nogo.json
new file mode 100644
index 000000000..ff369be6f
--- /dev/null
+++ b/tools/nogo.json
@@ -0,0 +1,95 @@
+{
+ "assign": {
+ "exclude_files": {
+ "/external/bazel_gazelle/walk/walk.go": "allowed: false positive"
+ }
+ },
+ "checkunsafe": {
+ "exclude_files": {
+ "/external/": "allowed: not subject to unsafe naming rules"
+ }
+ },
+ "copylocks": {
+ "exclude_files": {
+ ".*_state_autogen.go": "fix: m.Failf copies by value",
+ "/pkg/log/json.go": "fix: Emit passes lock by value: gvisor.dev/gvisor/pkg/log.JSONEmitter contains gvisor.dev/gvisor/pkg/log.Writer contains gvisor.dev/gvisor/pkg/sync.Mutex",
+ "/pkg/log/log_test.go": "fix: call of fmt.Printf copies lock value: gvisor.dev/gvisor/pkg/log.Writer contains gvisor.dev/gvisor/pkg/sync.Mutex",
+ "/pkg/sentry/fs/host/socket_test.go": "fix: call of t.Errorf copies lock value: gvisor.dev/gvisor/pkg/sentry/fs/host.ConnectedEndpoint contains gvisor.dev/gvisor/pkg/refs.AtomicRefCount contains gvisor.dev/gvisor/pkg/sync.Mutex",
+ "/pkg/sentry/fs/proc/sys_net.go": "fix: Truncate passes lock by value: gvisor.dev/gvisor/pkg/sentry/fs/proc.tcpMemInode contains gvisor.dev/gvisor/pkg/sentry/fs/fsutil.SimpleFileInode contains gvisor.dev/gvisor/pkg/sentry/fs/fsutil.InodeSimpleAttributes contains gvisor.dev/gvisor/pkg/sync.RWMutex",
+ "/pkg/sentry/fs/proc/sys_net.go": "fix: Truncate passes lock by value: gvisor.dev/gvisor/pkg/sentry/fs/proc.tcpSack contains gvisor.dev/gvisor/pkg/sentry/fs/fsutil.SimpleFileInode contains gvisor.dev/gvisor/pkg/sentry/fs/fsutil.InodeSimpleAttributes contains gvisor.dev/gvisor/pkg/sync.RWMutex",
+ "/pkg/sentry/fs/tty/slave.go": "fix: Truncate passes lock by value: gvisor.dev/gvisor/pkg/sentry/fs/tty.slaveInodeOperations contains gvisor.dev/gvisor/pkg/sentry/fs/fsutil.SimpleFileInode contains gvisor.dev/gvisor/pkg/sentry/fs/fsutil.InodeSimpleAttributes contains gvisor.dev/gvisor/pkg/sync.RWMutex",
+ "/pkg/sentry/kernel/time/time.go": "fix: Readiness passes lock by value: gvisor.dev/gvisor/pkg/sentry/kernel/time.ClockEventsQueue contains gvisor.dev/gvisor/pkg/waiter.Queue contains gvisor.dev/gvisor/pkg/sync.RWMutex",
+ "/pkg/sentry/kernel/syscalls_state.go": "fix: assignment copies lock value to *s: gvisor.dev/gvisor/pkg/sentry/kernel.SyscallTable contains gvisor.dev/gvisor/pkg/sentry/kernel.SyscallFlagsTable contains gvisor.dev/gvisor/pkg/sync.Mutex"
+ }
+ },
+ "lostcancel": {
+ "exclude_files": {
+ "/pkg/tcpip/network/arp/arp_test.go": "fix: the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak",
+ "/pkg/tcpip/stack/ndp_test.go": "fix: the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak",
+ "/pkg/tcpip/transport/udp/udp_test.go": "fix: the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak",
+ "/pkg/tcpip/transport/tcp/testing/context/context.go": "fix: the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak"
+ }
+ },
+ "nilness": {
+ "exclude_files": {
+ "/com_github_vishvananda_netlink/route_linux.go": "allowed: false positive",
+ "/external/bazel_gazelle/cmd/gazelle/.*": "allowed: false positive",
+ "/org_golang_x_tools/go/packages/golist.go": "allowed: runtime internals",
+ "/pkg/sentry/platform/kvm/kvm_test.go": "allowed: intentional"
+ }
+ },
+ "printf": {
+ "exclude_files": {
+ ".*_abi_autogen_test.go": "fix: Sprintf format has insufficient args",
+ "/pkg/segment/test/segment_test.go": "fix: Errorf format %d arg seg.Start is a func value, not called",
+ "/pkg/tcpip/tcpip_test.go": "fix: Error call has possible formatting directive %q",
+ "/pkg/tcpip/header/eth_test.go": "fix: Fatalf format %s reads arg #3, but call has 2 args",
+ "/pkg/tcpip/header/ndp_test.go": "fix: Errorf format %d reads arg #1, but call has 0 args",
+ "/pkg/eventchannel/event_test.go": "fix: Fatal call has possible formatting directive %v",
+ "/pkg/tcpip/stack/ndp.go": "fix: Fatalf format %s has arg protocolAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.ProtocolAddress",
+ "/pkg/sentry/fs/fdpipe/pipe_test.go": "fix: Errorf format %s has arg flags of wrong type gvisor.dev/gvisor/pkg/sentry/fs.FileFlags",
+ "/pkg/sentry/fs/fdpipe/pipe_test.go": "fix: Errorf format %d arg f.FD is a func value, not called",
+ "/pkg/tcpip/link/fdbased/endpoint.go": "fix: Sprintf format %v with arg p causes recursive String method call",
+ "/pkg/tcpip/transport/udp/udp_test.go": "fix: Fatalf format %s has arg h.srcAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.FullAddress",
+ "/pkg/tcpip/transport/tcp/tcp_test.go": "fix: Fatalf format %s has arg tcpTW of wrong type gvisor.dev/gvisor/pkg/tcpip.TCPTimeWaitTimeoutOption",
+ "/pkg/tcpip/transport/tcp/tcp_test.go": "fix: Errorf call needs 1 arg but has 2 args",
+ "/pkg/tcpip/stack/ndp_test.go": "fix: Errorf format %s reads arg #3, but call has 2 args",
+ "/pkg/tcpip/stack/ndp_test.go": "fix: Fatalf format %s reads arg #5, but call has 4 args",
+ "/pkg/tcpip/stack/stack_test.go": "fix: Fatalf format %s has arg protoAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.ProtocolAddress",
+ "/pkg/tcpip/stack/stack_test.go": "fix: Fatalf format %s has arg nic1ProtoAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.ProtocolAddress",
+ "/pkg/tcpip/stack/stack_test.go": "fix: Fatalf format %s has arg nic2ProtoAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.ProtocolAddress",
+ "/pkg/tcpip/stack/stack_test.go": "fix: Fatal call has possible formatting directive %t",
+ "/pkg/tcpip/stack/stack_test.go": "fix: Fatalf call has arguments but no formatting directives",
+ "/pkg/tcpip/link/fdbased/endpoint.go": "fix: Sprintf format %v with arg p causes recursive String method call",
+ "/pkg/sentry/fsimpl/tmpfs/stat_test.go": "fix: Errorf format %v reads arg #1, but call has 0 args",
+ "/runsc/container/test_app/test_app.go": "fix: Fatal call has possible formatting directive %q",
+ "/test/root/cgroup_test.go": "fix: Errorf format %s has arg gots of wrong type []int",
+ "/test/root/cgroup_test.go": "fix: Fatalf format %v reads arg #3, but call has 2 args",
+ "/test/runtimes/runner.go": "fix: Skip call has possible formatting directive %q",
+ "/test/runtimes/blacklist_test.go": "fix: Errorf format %q has arg blacklistFile of wrong type *string"
+ }
+ },
+ "structtag": {
+ "exclude_files": {
+ "/external/": "allowed: may use arbitrary tags"
+ }
+ },
+ "unsafeptr": {
+ "exclude_files": {
+ ".*_test.go": "allowed: exclude tests",
+ "/pkg/flipcall/flipcall_unsafe.go": "allowed: special case",
+ "/pkg/gohacks/gohacks_unsafe.go": "allowed: special case",
+ "/pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go": "allowed: special case",
+ "/pkg/sentry/platform/kvm/(bluepill|machine)_unsafe.go": "allowed: special case",
+ "/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go": "allowed: special case",
+ "/pkg/sentry/platform/safecopy/safecopy_unsafe.go": "allowed: special case",
+ "/pkg/sentry/vfs/mount_unsafe.go": "allowed: special case"
+ }
+ },
+ "unusedresult": {
+ "exclude_files": {
+ "/pkg/sentry/fsimpl/proc/task_net.go": "fix: result of fmt.Sprintf call not used",
+ "/pkg/sentry/fsimpl/proc/tasks_net.go": "fix: result of fmt.Sprintf call not used"
+ }
+ }
+}
diff --git a/tools/tag_release.sh b/tools/tag_release.sh
index f33b902d6..4dbfe420a 100755
--- a/tools/tag_release.sh
+++ b/tools/tag_release.sh
@@ -21,13 +21,19 @@
set -xeu
# Check arguments.
-if [ "$#" -ne 2 ]; then
- echo "usage: $0 <commit|revid> <release.rc>"
+if [ "$#" -ne 3 ]; then
+ echo "usage: $0 <commit|revid> <release.rc> <message-file>"
exit 1
fi
declare -r target_commit="$1"
declare -r release="$2"
+declare -r message_file="$3"
+
+if ! [[ -r "${message_file}" ]]; then
+ echo "error: message file '${message_file}' is not readable."
+ exit 1
+fi
closest_commit() {
while read line; do
@@ -64,6 +70,6 @@ fi
# Tag the given commit (annotated, to record the committer).
declare -r tag="release-${release}"
-(git tag -m "Release ${release}" -a "${tag}" "${commit}" && \
+(git tag -F "${message_file}" -a "${tag}" "${commit}" && \
git push origin tag "${tag}") || \
(git tag -d "${tag}" && false)
diff --git a/vdso/vdso.cc b/vdso/vdso.cc
index 8bb80a7a4..c2585d592 100644
--- a/vdso/vdso.cc
+++ b/vdso/vdso.cc
@@ -126,6 +126,10 @@ extern "C" int __kernel_clock_getres(clockid_t clock, struct timespec* res) {
case CLOCK_REALTIME:
case CLOCK_MONOTONIC:
case CLOCK_BOOTTIME: {
+ if (res == nullptr) {
+ return 0;
+ }
+
res->tv_sec = 0;
res->tv_nsec = 1;
break;